55import enum
66from pathlib import Path
77from copy import copy
8+ from typing_extensions import Self
89import attrs .validators
910from attrs .converters import default_if_none
1011from fileformats .generic import File
2122)
2223from pydra .engine .core import Task , AuditFlag
2324
25+
2426__all__ = [
2527 "Field" ,
2628 "Arg" ,
2729 "Out" ,
2830 "TaskSpec" ,
29- "collate_with_helps" ,
31+ "OutputsSpec" ,
32+ "ensure_field_objects" ,
3033 "make_task_spec" ,
3134 "list_fields" ,
3235]
3336
37+ RESERVED_OUTPUT_NAMES = ("split" , "combine" )
38+
3439
3540class _Empty (enum .Enum ):
3641
@@ -149,7 +154,73 @@ class Out(Field):
149154 pass
150155
151156
152- OutputType = ty .TypeVar ("OutputType" )
157+ class OutputsSpec :
158+ """Base class for all output specifications"""
159+
160+ def split (
161+ self ,
162+ splitter : ty .Union [str , ty .List [str ], ty .Tuple [str , ...], None ] = None ,
163+ / ,
164+ overwrite : bool = False ,
165+ cont_dim : ty .Optional [dict ] = None ,
166+ ** inputs ,
167+ ) -> Self :
168+ """
169+ Run this task parametrically over lists of split inputs.
170+
171+ Parameters
172+ ----------
173+ splitter : str or list[str] or tuple[str] or None
174+ the fields which to split over. If splitting over multiple fields, lists of
175+ fields are interpreted as outer-products and tuples inner-products. If None,
176+ then the fields to split are taken from the keyword-arg names.
177+ overwrite : bool, optional
178+ whether to overwrite an existing split on the node, by default False
179+ cont_dim : dict, optional
180+ Container dimensions for specific inputs, used in the splitter.
181+ If input name is not in cont_dim, it is assumed that the input values has
182+ a container dimension of 1, so only the most outer dim will be used for splitting.
183+ **inputs
184+ fields to split over, will automatically be wrapped in a StateArray object
185+ and passed to the node inputs
186+
187+ Returns
188+ -------
189+ self : TaskBase
190+ a reference to the task
191+ """
192+ self ._node .split (splitter , overwrite = overwrite , cont_dim = cont_dim , ** inputs )
193+ return self
194+
195+ def combine (
196+ self ,
197+ combiner : ty .Union [ty .List [str ], str ],
198+ overwrite : bool = False , # **kwargs
199+ ) -> Self :
200+ """
201+ Combine inputs parameterized by one or more previous tasks.
202+
203+ Parameters
204+ ----------
205+ combiner : list[str] or str
206+ the field or list of inputs to be combined (i.e. not left split) after the
207+ task has been run
208+ overwrite : bool
209+ whether to overwrite an existing combiner on the node
210+ **kwargs : dict[str, Any]
211+ values for the task that will be "combined" before they are provided to the
212+ node
213+
214+ Returns
215+ -------
216+ self : Self
217+ a reference to the outputs object
218+ """
219+ self ._node .combine (combiner , overwrite = overwrite )
220+ return self
221+
222+
223+ OutputType = ty .TypeVar ("OutputType" , bound = OutputsSpec )
153224
154225
155226class TaskSpec (ty .Generic [OutputType ]):
@@ -197,13 +268,33 @@ def _check_for_unset_values(self):
197268 )
198269
199270
200- def get_fields_from_class (
271+ def extract_fields_from_class (
201272 klass : type ,
202273 arg_type : type [Arg ],
203274 out_type : type [Out ],
204275 auto_attribs : bool ,
205276) -> tuple [dict [str , Arg ], dict [str , Out ]]:
206- """Parse the input and output fields from a class"""
277+ """Extract the input and output fields from an existing class
278+
279+ Parameters
280+ ----------
281+ klass : type
282+ The class to extract the fields from
283+ arg_type : type
284+ The type of the input fields
285+ out_type : type
286+ The type of the output fields
287+ auto_attribs : bool
288+ Whether to assume that all attribute annotations should be interpreted as
289+ fields or not
290+
291+ Returns
292+ -------
293+ inputs : dict[str, Arg]
294+ The input fields extracted from the class
295+ outputs : dict[str, Out]
296+ The output fields extracted from the class
297+ """
207298
208299 input_helps , _ = parse_doc_string (klass .__doc__ )
209300
@@ -269,31 +360,50 @@ def make_task_spec(
269360 bases : ty .Sequence [type ] = (),
270361 outputs_bases : ty .Sequence [type ] = (),
271362):
363+ """Create a task specification class and its outputs specification class from the
364+ input and output fields provided to the decorator/function.
365+
366+ Modifies the class so that its attributes are converted from pydra fields to attrs fields
367+ and then calls `attrs.define` to create an attrs class (dataclass-like).
368+ on
369+
370+ Parameters
371+ ----------
372+ task_type : type
373+ The type of the task to be created
374+ inputs : dict[str, Arg]
375+ The input fields of the task
376+ outputs : dict[str, Out]
377+ The output fields of the task
378+ klass : type, optional
379+ The class to be decorated, by default None
380+ name : str, optional
381+ The name of the class, by default
382+ bases : ty.Sequence[type], optional
383+ The base classes for the task specification class, by default ()
384+ outputs_bases : ty.Sequence[type], optional
385+ The base classes for the outputs specification class, by default ()
386+
387+ Returns
388+ -------
389+ klass : type
390+ The class created using the attrs package
391+ """
272392 if name is None and klass is not None :
273393 name = klass .__name__
274- outputs_klass = type (
275- "Outputs" ,
276- tuple (outputs_bases ),
277- {
278- o .name : attrs .field (
279- converter = make_converter (o , f"{ name } .Outputs" ),
280- metadata = {PYDRA_ATTR_METADATA : o },
281- ** _get_default (o ),
282- )
283- for o in outputs .values ()
284- },
285- )
286- outputs_klass .__annotations__ .update ((o .name , o .type ) for o in outputs .values ())
287- outputs_klass = attrs .define (auto_attribs = False , kw_only = True )(outputs_klass )
288-
394+ outputs_klass = make_outputs_spec (outputs , outputs_bases , name )
289395 if klass is None or not issubclass (klass , TaskSpec ):
290396 if name is None :
291397 raise ValueError ("name must be provided if klass is not" )
292398 bases = tuple (bases )
399+ # Ensure that TaskSpec is a base class
293400 if not any (issubclass (b , TaskSpec ) for b in bases ):
294401 bases = bases + (TaskSpec ,)
402+ # If building from a decorated class (as opposed to dynamically from a function
403+ # or shell-template), add any base classes not already in the bases tuple
295404 if klass is not None :
296405 bases += tuple (c for c in klass .__mro__ if c not in bases + (object ,))
406+ # Create a new class with the TaskSpec as a base class
297407 klass = types .new_class (
298408 name = name ,
299409 bases = bases ,
@@ -303,7 +413,7 @@ def make_task_spec(
303413 ),
304414 )
305415 else :
306- # Ensure that the class has it's own annotaitons dict so we can modify it without
416+ # Ensure that the class has it's own annotations dict so we can modify it without
307417 # messing up other classes
308418 klass .__annotations__ = copy (klass .__annotations__ )
309419 klass .Task = task_type
@@ -345,7 +455,53 @@ def make_task_spec(
345455 return attrs_klass
346456
347457
348- def collate_with_helps (
458+ def make_outputs_spec (
459+ outputs : dict [str , Out ], bases : ty .Sequence [type ], spec_name : str
460+ ) -> type [OutputsSpec ]:
461+ """Create an outputs specification class and its outputs specification class from the
462+ output fields provided to the decorator/function.
463+
464+ Creates a new class with attrs fields and then calls `attrs.define` to create an
465+ attrs class (dataclass-like).
466+
467+ Parameters
468+ ----------
469+ outputs : dict[str, Out]
470+ The output fields of the task
471+ bases : ty.Sequence[type], optional
472+ The base classes for the outputs specification class, by default ()
473+ spec_name : str
474+ The name of the task specification class the outputs are for
475+
476+ Returns
477+ -------
478+ klass : type
479+ The class created using the attrs package
480+ """
481+ if not any (issubclass (b , OutputsSpec ) for b in bases ):
482+ outputs_bases = bases + (OutputsSpec ,)
483+ if reserved_names := [n for n in outputs if n in RESERVED_OUTPUT_NAMES ]:
484+ raise ValueError (
485+ f"{ reserved_names } are reserved and cannot be used for output field names"
486+ )
487+ outputs_klass = type (
488+ spec_name + "Outputs" ,
489+ tuple (outputs_bases ),
490+ {
491+ o .name : attrs .field (
492+ converter = make_converter (o , f"{ spec_name } .Outputs" ),
493+ metadata = {PYDRA_ATTR_METADATA : o },
494+ ** _get_default (o ),
495+ )
496+ for o in outputs .values ()
497+ },
498+ )
499+ outputs_klass .__annotations__ .update ((o .name , o .type ) for o in outputs .values ())
500+ outputs_klass = attrs .define (auto_attribs = False , kw_only = True )(outputs_klass )
501+ return outputs_klass
502+
503+
504+ def ensure_field_objects (
349505 arg_type : type [Arg ],
350506 out_type : type [Out ],
351507 doc_string : str | None = None ,
@@ -354,7 +510,33 @@ def collate_with_helps(
354510 input_helps : dict [str , str ] | None = None ,
355511 output_helps : dict [str , str ] | None = None ,
356512) -> tuple [dict [str , Arg ], dict [str , Out ]]:
357- """Assign help strings to the appropriate inputs and outputs"""
513+ """Converts dicts containing input/output types into input/output, including any
514+ help strings to the appropriate inputs and outputs
515+
516+ Parameters
517+ ----------
518+ arg_type : type
519+ The type of the input fields
520+ out_type : type
521+ The type of the output fields
522+ doc_string : str, optional
523+ The docstring of the function or class
524+ inputs : dict[str, Arg | type], optional
525+ The inputs to the function or class
526+ outputs : dict[str, Out | type], optional
527+ The outputs of the function or class
528+ input_helps : dict[str, str], optional
529+ The help strings for the inputs
530+ output_helps : dict[str, str], optional
531+ The help strings for the outputs
532+
533+ Returns
534+ -------
535+ inputs : dict[str, Arg]
536+ The input fields with help strings added
537+ outputs : dict[str, Out]
538+ The output fields with help strings added
539+ """
358540
359541 for input_name , arg in list (inputs .items ()):
360542 if isinstance (arg , Arg ):
@@ -403,7 +585,24 @@ def collate_with_helps(
403585
404586def make_converter (
405587 field : Field , interface_name : str , field_type : ty .Type | None = None
406- ):
588+ ) -> ty .Callable [..., ty .Any ]:
589+ """Makes an attrs converter for the field, combining type checking with any explicit
590+ converters
591+
592+ Parameters
593+ ----------
594+ field : Field
595+ The field to make the converter for
596+ interface_name : str
597+ The name of the interface the field is part of
598+ field_type : type, optional
599+ The type of the field, by default None
600+
601+ Returns
602+ -------
603+ converter : callable
604+ The converter for the field
605+ """
407606 if field_type is None :
408607 field_type = field .type
409608 checker_label = f"'{ field .name } ' field of { interface_name } interface"
@@ -425,7 +624,22 @@ def make_converter(
425624 return converter
426625
427626
428- def make_validator (field : Field , interface_name : str ):
627+ def make_validator (field : Field , interface_name : str ) -> ty .Callable [..., None ] | None :
628+ """Makes an attrs validator for the field, combining allowed values and any explicit
629+ validators
630+
631+ Parameters
632+ ----------
633+ field : Field
634+ The field to make the validator for
635+ interface_name : str
636+ The name of the interface the field is part of
637+
638+ Returns
639+ -------
640+ validator : callable
641+ The validator for the field
642+ """
429643 validators = []
430644 if field .allowed_values :
431645 validators .append (allowed_values_validator )
@@ -458,7 +672,28 @@ def extract_function_inputs_and_outputs(
458672 outputs : list [str | Out ] | dict [str , Out | type ] | type | None = None ,
459673) -> tuple [dict [str , type | Arg ], dict [str , type | Out ]]:
460674 """Extract input output types and output names from the function source if they
461- aren't explicitly"""
675+ aren't explicitly
676+
677+ Parameters
678+ ----------
679+ function : callable
680+ The function to extract the inputs and outputs from
681+ arg_type : type
682+ The type of the input fields
683+ out_type : type
684+ The type of the output fields
685+ inputs : list[str | Arg] | dict[str, Arg | type] | None
686+ The inputs to the function
687+ outputs : list[str | Out] | dict[str, Out | type] | type | None
688+ The outputs of the function
689+
690+ Returns
691+ -------
692+ inputs : dict[str, Arg]
693+ The input fields extracted from the function
694+ outputs : dict[str, Out]
695+ The output fields extracted from the function
696+ """
462697 # if undefined_symbols := get_undefined_symbols(
463698 # function, exclude_signature_type_hints=True, ignore_decorator=True
464699 # ):
0 commit comments