55import  enum 
66from  pathlib  import  Path 
77from  copy  import  copy 
8- from  typing_extensions  import  Self 
98import  attrs .validators 
109from  attrs .converters  import  default_if_none 
1110from  fileformats .generic  import  File 
1211from  pydra .utils .typing  import  TypeParser , is_optional , is_fileset_or_union 
13- 
14- # from pydra.utils.misc import get_undefined_symbols 
15- from  pydra .engine .helpers  import  from_list_if_single , ensure_list 
16- from  pydra .engine .specs  import  (
17-     LazyField ,
12+ from  pydra .engine .helpers  import  (
13+     from_list_if_single ,
14+     ensure_list ,
15+     PYDRA_ATTR_METADATA ,
16+     list_fields ,
17+ )
18+ from  pydra .utils .typing  import  (
1819    MultiInputObj ,
1920    MultiInputFile ,
2021    MultiOutputObj ,
2122    MultiOutputFile ,
2223)
23- from  pydra .engine .core  import  Task ,  AuditFlag 
24+ from  pydra .engine .workflow . lazy  import  LazyField 
2425
2526
27+ if  ty .TYPE_CHECKING :
28+     from  pydra .engine .specs  import  OutputsSpec 
29+     from  pydra .engine .core  import  Task 
30+ 
2631__all__  =  [
2732    "Field" ,
2833    "Arg" ,
2934    "Out" ,
30-     "TaskSpec" ,
31-     "OutputsSpec" ,
3235    "ensure_field_objects" ,
3336    "make_task_spec" ,
34-     "list_fields" ,
3537]
3638
3739RESERVED_OUTPUT_NAMES  =  ("split" , "combine" )
@@ -154,120 +156,6 @@ class Out(Field):
154156    pass 
155157
156158
157- class  OutputsSpec :
158-     """Base class for all output specifications""" 
159- 
160-     def  split (
161-         self ,
162-         splitter : ty .Union [str , ty .List [str ], ty .Tuple [str , ...], None ] =  None ,
163-         / ,
164-         overwrite : bool  =  False ,
165-         cont_dim : ty .Optional [dict ] =  None ,
166-         ** inputs ,
167-     ) ->  Self :
168-         """ 
169-         Run this task parametrically over lists of split inputs. 
170- 
171-         Parameters 
172-         ---------- 
173-         splitter : str or list[str] or tuple[str] or None 
174-             the fields which to split over. If splitting over multiple fields, lists of 
175-             fields are interpreted as outer-products and tuples inner-products. If None, 
176-             then the fields to split are taken from the keyword-arg names. 
177-         overwrite : bool, optional 
178-             whether to overwrite an existing split on the node, by default False 
179-         cont_dim : dict, optional 
180-             Container dimensions for specific inputs, used in the splitter. 
181-             If input name is not in cont_dim, it is assumed that the input values has 
182-             a container dimension of 1, so only the most outer dim will be used for splitting. 
183-         **inputs 
184-             fields to split over, will automatically be wrapped in a StateArray object 
185-             and passed to the node inputs 
186- 
187-         Returns 
188-         ------- 
189-         self : TaskBase 
190-             a reference to the task 
191-         """ 
192-         self ._node .split (splitter , overwrite = overwrite , cont_dim = cont_dim , ** inputs )
193-         return  self 
194- 
195-     def  combine (
196-         self ,
197-         combiner : ty .Union [ty .List [str ], str ],
198-         overwrite : bool  =  False ,  # **kwargs 
199-     ) ->  Self :
200-         """ 
201-         Combine inputs parameterized by one or more previous tasks. 
202- 
203-         Parameters 
204-         ---------- 
205-         combiner : list[str] or str 
206-             the field or list of inputs to be combined (i.e. not left split) after the 
207-             task has been run 
208-         overwrite : bool 
209-             whether to overwrite an existing combiner on the node 
210-         **kwargs : dict[str, Any] 
211-             values for the task that will be "combined" before they are provided to the 
212-             node 
213- 
214-         Returns 
215-         ------- 
216-         self : Self 
217-             a reference to the outputs object 
218-         """ 
219-         self ._node .combine (combiner , overwrite = overwrite )
220-         return  self 
221- 
222- 
223- OutputType  =  ty .TypeVar ("OutputType" , bound = OutputsSpec )
224- 
225- 
226- class  TaskSpec (ty .Generic [OutputType ]):
227-     """Base class for all task specifications""" 
228- 
229-     Task : ty .Type [Task ]
230- 
231-     def  __call__ (
232-         self ,
233-         name : str  |  None  =  None ,
234-         audit_flags : AuditFlag  =  AuditFlag .NONE ,
235-         cache_dir = None ,
236-         cache_locations = None ,
237-         inputs : ty .Text  |  File  |  dict [str , ty .Any ] |  None  =  None ,
238-         cont_dim = None ,
239-         messenger_args = None ,
240-         messengers = None ,
241-         rerun = False ,
242-         ** kwargs ,
243-     ):
244-         self ._check_for_unset_values ()
245-         task  =  self .Task (
246-             self ,
247-             name = name ,
248-             audit_flags = audit_flags ,
249-             cache_dir = cache_dir ,
250-             cache_locations = cache_locations ,
251-             inputs = inputs ,
252-             cont_dim = cont_dim ,
253-             messenger_args = messenger_args ,
254-             messengers = messengers ,
255-             rerun = rerun ,
256-         )
257-         return  task (** kwargs )
258- 
259-     def  _check_for_unset_values (self ):
260-         if  unset  :=  [
261-             k 
262-             for  k , v  in  attrs .asdict (self , recurse = False ).items ()
263-             if  v  is  attrs .NOTHING 
264-         ]:
265-             raise  ValueError (
266-                 f"The following values { unset }   in the { self !r}   interface need to be set " 
267-                 "before the workflow can be constructed" 
268-             )
269- 
270- 
271159def  extract_fields_from_class (
272160    klass : type ,
273161    arg_type : type [Arg ],
@@ -352,7 +240,7 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
352240
353241
354242def  make_task_spec (
355-     task_type : type [Task ],
243+     task_type : type [" Task"  ],
356244    inputs : dict [str , Arg ],
357245    outputs : dict [str , Out ],
358246    klass : type  |  None  =  None ,
@@ -389,6 +277,8 @@ def make_task_spec(
389277    klass : type 
390278        The class created using the attrs package 
391279    """ 
280+     from  pydra .engine .specs  import  TaskSpec 
281+ 
392282    if  name  is  None  and  klass  is  not   None :
393283        name  =  klass .__name__ 
394284    outputs_klass  =  make_outputs_spec (outputs , outputs_bases , name )
@@ -457,7 +347,7 @@ def make_task_spec(
457347
458348def  make_outputs_spec (
459349    outputs : dict [str , Out ], bases : ty .Sequence [type ], spec_name : str 
460- ) ->  type [OutputsSpec ]:
350+ ) ->  type [" OutputsSpec"  ]:
461351    """Create an outputs specification class and its outputs specification class from the 
462352    output fields provided to the decorator/function. 
463353
@@ -478,6 +368,8 @@ def make_outputs_spec(
478368    klass : type 
479369        The class created using the attrs package 
480370    """ 
371+     from  pydra .engine .specs  import  OutputsSpec 
372+ 
481373    if  not  any (issubclass (b , OutputsSpec ) for  b  in  bases ):
482374        outputs_bases  =  bases  +  (OutputsSpec ,)
483375    if  reserved_names  :=  [n  for  n  in  outputs  if  n  in  RESERVED_OUTPUT_NAMES ]:
@@ -880,16 +772,6 @@ def split_block(string: str) -> ty.Generator[str, None, None]:
880772        yield  block .strip ()
881773
882774
883- def  list_fields (interface : TaskSpec ) ->  list [Field ]:
884-     if  not  attrs .has (interface ):
885-         return  []
886-     return  [
887-         f .metadata [PYDRA_ATTR_METADATA ]
888-         for  f  in  attrs .fields (interface )
889-         if  PYDRA_ATTR_METADATA  in  f .metadata 
890-     ]
891- 
892- 
893775def  check_explicit_fields_are_none (klass , inputs , outputs ):
894776    if  inputs  is  not   None :
895777        raise  ValueError (
@@ -918,5 +800,3 @@ def nothing_factory():
918800
919801
920802white_space_re  =  re .compile (r"\s+" )
921- 
922- PYDRA_ATTR_METADATA  =  "__PYDRA_METADATA__" 
0 commit comments