Skip to content

Commit cd97c11

Browse files
committed
Debugged splitting and combining of lazy fields
1 parent 476f7db commit cd97c11

File tree

8 files changed

+392
-137
lines changed

8 files changed

+392
-137
lines changed

pydra/design/base.py

Lines changed: 260 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import enum
66
from pathlib import Path
77
from copy import copy
8+
from typing_extensions import Self
89
import attrs.validators
910
from attrs.converters import default_if_none
1011
from fileformats.generic import File
@@ -21,16 +22,20 @@
2122
)
2223
from pydra.engine.core import Task, AuditFlag
2324

25+
2426
__all__ = [
2527
"Field",
2628
"Arg",
2729
"Out",
2830
"TaskSpec",
29-
"collate_with_helps",
31+
"OutputsSpec",
32+
"ensure_field_objects",
3033
"make_task_spec",
3134
"list_fields",
3235
]
3336

37+
RESERVED_OUTPUT_NAMES = ("split", "combine")
38+
3439

3540
class _Empty(enum.Enum):
3641

@@ -149,7 +154,73 @@ class Out(Field):
149154
pass
150155

151156

152-
OutputType = ty.TypeVar("OutputType")
157+
class OutputsSpec:
    """Base class for all output specifications

    Both methods delegate to ``self._node``, which is not defined by this class and
    is expected to be attached to the instance elsewhere before they are called.
    """

    def split(
        self,
        splitter: ty.Union[str, ty.List[str], ty.Tuple[str, ...], None] = None,
        /,
        overwrite: bool = False,
        cont_dim: ty.Optional[dict] = None,
        **inputs,
    ) -> Self:
        """
        Run this task parametrically over lists of split inputs.

        The call is forwarded unchanged to the ``split`` method of the underlying node.

        Parameters
        ----------
        splitter : str or list[str] or tuple[str] or None
            the fields which to split over. If splitting over multiple fields, lists of
            fields are interpreted as outer-products and tuples inner-products. If None,
            then the fields to split are taken from the keyword-arg names.
        overwrite : bool, optional
            whether to overwrite an existing split on the node, by default False
        cont_dim : dict, optional
            Container dimensions for specific inputs, used in the splitter.
            If input name is not in cont_dim, it is assumed that the input values has
            a container dimension of 1, so only the most outer dim will be used for
            splitting.
        **inputs
            fields to split over, passed through to the node as keyword arguments

        Returns
        -------
        self : Self
            a reference to the outputs object, so that calls can be chained
        """
        node = self._node
        node.split(splitter, overwrite=overwrite, cont_dim=cont_dim, **inputs)
        return self

    def combine(
        self,
        combiner: ty.Union[ty.List[str], str],
        overwrite: bool = False,
    ) -> Self:
        """
        Combine inputs parameterized by one or more previous tasks.

        The call is forwarded unchanged to the ``combine`` method of the underlying node.

        Parameters
        ----------
        combiner : list[str] or str
            the field or list of inputs to be combined (i.e. not left split) after the
            task has been run
        overwrite : bool
            whether to overwrite an existing combiner on the node, by default False

        Returns
        -------
        self : Self
            a reference to the outputs object, so that calls can be chained
        """
        node = self._node
        node.combine(combiner, overwrite=overwrite)
        return self
221+
222+
223+
OutputType = ty.TypeVar("OutputType", bound=OutputsSpec)
153224

154225

155226
class TaskSpec(ty.Generic[OutputType]):
@@ -197,13 +268,33 @@ def _check_for_unset_values(self):
197268
)
198269

199270

200-
def get_fields_from_class(
271+
def extract_fields_from_class(
201272
klass: type,
202273
arg_type: type[Arg],
203274
out_type: type[Out],
204275
auto_attribs: bool,
205276
) -> tuple[dict[str, Arg], dict[str, Out]]:
206-
"""Parse the input and output fields from a class"""
277+
"""Extract the input and output fields from an existing class
278+
279+
Parameters
280+
----------
281+
klass : type
282+
The class to extract the fields from
283+
arg_type : type
284+
The type of the input fields
285+
out_type : type
286+
The type of the output fields
287+
auto_attribs : bool
288+
Whether to assume that all attribute annotations should be interpreted as
289+
fields or not
290+
291+
Returns
292+
-------
293+
inputs : dict[str, Arg]
294+
The input fields extracted from the class
295+
outputs : dict[str, Out]
296+
The output fields extracted from the class
297+
"""
207298

208299
input_helps, _ = parse_doc_string(klass.__doc__)
209300

@@ -269,31 +360,50 @@ def make_task_spec(
269360
bases: ty.Sequence[type] = (),
270361
outputs_bases: ty.Sequence[type] = (),
271362
):
363+
"""Create a task specification class and its outputs specification class from the
364+
input and output fields provided to the decorator/function.
365+
366+
Modifies the class so that its attributes are converted from pydra fields to attrs fields
367+
and then calls `attrs.define` to create an attrs class (dataclass-like).
368+
369+
370+
Parameters
371+
----------
372+
task_type : type
373+
The type of the task to be created
374+
inputs : dict[str, Arg]
375+
The input fields of the task
376+
outputs : dict[str, Out]
377+
The output fields of the task
378+
klass : type, optional
379+
The class to be decorated, by default None
380+
name : str, optional
381+
The name of the class, by default the name of `klass` (required if `klass` is not provided)
382+
bases : ty.Sequence[type], optional
383+
The base classes for the task specification class, by default ()
384+
outputs_bases : ty.Sequence[type], optional
385+
The base classes for the outputs specification class, by default ()
386+
387+
Returns
388+
-------
389+
klass : type
390+
The class created using the attrs package
391+
"""
272392
if name is None and klass is not None:
273393
name = klass.__name__
274-
outputs_klass = type(
275-
"Outputs",
276-
tuple(outputs_bases),
277-
{
278-
o.name: attrs.field(
279-
converter=make_converter(o, f"{name}.Outputs"),
280-
metadata={PYDRA_ATTR_METADATA: o},
281-
**_get_default(o),
282-
)
283-
for o in outputs.values()
284-
},
285-
)
286-
outputs_klass.__annotations__.update((o.name, o.type) for o in outputs.values())
287-
outputs_klass = attrs.define(auto_attribs=False, kw_only=True)(outputs_klass)
288-
394+
outputs_klass = make_outputs_spec(outputs, outputs_bases, name)
289395
if klass is None or not issubclass(klass, TaskSpec):
290396
if name is None:
291397
raise ValueError("name must be provided if klass is not")
292398
bases = tuple(bases)
399+
# Ensure that TaskSpec is a base class
293400
if not any(issubclass(b, TaskSpec) for b in bases):
294401
bases = bases + (TaskSpec,)
402+
# If building from a decorated class (as opposed to dynamically from a function
403+
# or shell-template), add any base classes not already in the bases tuple
295404
if klass is not None:
296405
bases += tuple(c for c in klass.__mro__ if c not in bases + (object,))
406+
# Create a new class with the TaskSpec as a base class
297407
klass = types.new_class(
298408
name=name,
299409
bases=bases,
@@ -303,7 +413,7 @@ def make_task_spec(
303413
),
304414
)
305415
else:
306-
# Ensure that the class has it's own annotaitons dict so we can modify it without
416+
# Ensure that the class has its own annotations dict so we can modify it without
307417
# messing up other classes
308418
klass.__annotations__ = copy(klass.__annotations__)
309419
klass.Task = task_type
@@ -345,7 +455,53 @@ def make_task_spec(
345455
return attrs_klass
346456

347457

348-
def collate_with_helps(
458+
def make_outputs_spec(
    outputs: dict[str, Out], bases: ty.Sequence[type], spec_name: str
) -> type[OutputsSpec]:
    """Create an outputs specification class from the output fields provided to the
    decorator/function.

    Creates a new class with attrs fields and then calls `attrs.define` to create an
    attrs class (dataclass-like).

    Parameters
    ----------
    outputs : dict[str, Out]
        The output fields of the task
    bases : ty.Sequence[type]
        The base classes for the outputs specification class. `OutputsSpec` is
        appended if none of them already derives from it
    spec_name : str
        The name of the task specification class the outputs are for

    Returns
    -------
    klass : type
        The class created using the attrs package

    Raises
    ------
    ValueError
        If any of the output names is listed in `RESERVED_OUTPUT_NAMES`
    """
    # Normalise to a tuple up front: this keeps `outputs_bases` bound even when a base
    # already derives from OutputsSpec, and lets list-typed `bases` be extended
    outputs_bases = tuple(bases)
    if not any(issubclass(b, OutputsSpec) for b in outputs_bases):
        outputs_bases += (OutputsSpec,)
    if reserved_names := [n for n in outputs if n in RESERVED_OUTPUT_NAMES]:
        raise ValueError(
            f"{reserved_names} are reserved and cannot be used for output field names"
        )
    outputs_klass = type(
        spec_name + "Outputs",
        outputs_bases,
        {
            o.name: attrs.field(
                converter=make_converter(o, f"{spec_name}.Outputs"),
                metadata={PYDRA_ATTR_METADATA: o},
                **_get_default(o),
            )
            for o in outputs.values()
        },
    )
    # Annotations are added after creation since the namespace above only holds the
    # attrs fields; attrs.define is run with auto_attribs=False so only those are used
    outputs_klass.__annotations__.update((o.name, o.type) for o in outputs.values())
    outputs_klass = attrs.define(auto_attribs=False, kw_only=True)(outputs_klass)
    return outputs_klass
502+
503+
504+
def ensure_field_objects(
349505
arg_type: type[Arg],
350506
out_type: type[Out],
351507
doc_string: str | None = None,
@@ -354,7 +510,33 @@ def collate_with_helps(
354510
input_helps: dict[str, str] | None = None,
355511
output_helps: dict[str, str] | None = None,
356512
) -> tuple[dict[str, Arg], dict[str, Out]]:
357-
"""Assign help strings to the appropriate inputs and outputs"""
513+
"""Converts dicts containing input/output types into input/output field objects,
514+
assigning any help strings to the appropriate inputs and outputs
515+
516+
Parameters
517+
----------
518+
arg_type : type
519+
The type of the input fields
520+
out_type : type
521+
The type of the output fields
522+
doc_string : str, optional
523+
The docstring of the function or class
524+
inputs : dict[str, Arg | type], optional
525+
The inputs to the function or class
526+
outputs : dict[str, Out | type], optional
527+
The outputs of the function or class
528+
input_helps : dict[str, str], optional
529+
The help strings for the inputs
530+
output_helps : dict[str, str], optional
531+
The help strings for the outputs
532+
533+
Returns
534+
-------
535+
inputs : dict[str, Arg]
536+
The input fields with help strings added
537+
outputs : dict[str, Out]
538+
The output fields with help strings added
539+
"""
358540

359541
for input_name, arg in list(inputs.items()):
360542
if isinstance(arg, Arg):
@@ -403,7 +585,24 @@ def collate_with_helps(
403585

404586
def make_converter(
405587
field: Field, interface_name: str, field_type: ty.Type | None = None
406-
):
588+
) -> ty.Callable[..., ty.Any]:
589+
"""Makes an attrs converter for the field, combining type checking with any explicit
590+
converters
591+
592+
Parameters
593+
----------
594+
field : Field
595+
The field to make the converter for
596+
interface_name : str
597+
The name of the interface the field is part of
598+
field_type : type, optional
599+
The type of the field, by default None
600+
601+
Returns
602+
-------
603+
converter : callable
604+
The converter for the field
605+
"""
407606
if field_type is None:
408607
field_type = field.type
409608
checker_label = f"'{field.name}' field of {interface_name} interface"
@@ -425,7 +624,22 @@ def make_converter(
425624
return converter
426625

427626

428-
def make_validator(field: Field, interface_name: str):
627+
def make_validator(field: Field, interface_name: str) -> ty.Callable[..., None] | None:
628+
"""Makes an attrs validator for the field, combining allowed values and any explicit
629+
validators
630+
631+
Parameters
632+
----------
633+
field : Field
634+
The field to make the validator for
635+
interface_name : str
636+
The name of the interface the field is part of
637+
638+
Returns
639+
-------
640+
validator : callable
641+
The validator for the field
642+
"""
429643
validators = []
430644
if field.allowed_values:
431645
validators.append(allowed_values_validator)
@@ -458,7 +672,28 @@ def extract_function_inputs_and_outputs(
458672
outputs: list[str | Out] | dict[str, Out | type] | type | None = None,
459673
) -> tuple[dict[str, type | Arg], dict[str, type | Out]]:
460674
"""Extract input output types and output names from the function source if they
461-
aren't explicitly"""
675+
aren't explicitly provided
676+
677+
Parameters
678+
----------
679+
function : callable
680+
The function to extract the inputs and outputs from
681+
arg_type : type
682+
The type of the input fields
683+
out_type : type
684+
The type of the output fields
685+
inputs : list[str | Arg] | dict[str, Arg | type] | None
686+
The inputs to the function
687+
outputs : list[str | Out] | dict[str, Out | type] | type | None
688+
The outputs of the function
689+
690+
Returns
691+
-------
692+
inputs : dict[str, Arg]
693+
The input fields extracted from the function
694+
outputs : dict[str, Out]
695+
The output fields extracted from the function
696+
"""
462697
# if undefined_symbols := get_undefined_symbols(
463698
# function, exclude_signature_type_hints=True, ignore_decorator=True
464699
# ):

0 commit comments

Comments
 (0)