1313import attrs
1414import yaml
1515from fileformats .core import from_mime , FileSet , Field
16+ from fileformats .core .exceptions import FormatRecognitionError
1617from .utils import (
1718 UsedSymbols ,
1819 split_source_into_statements ,
@@ -51,6 +52,15 @@ def convert_node_prefixes(
5152 return {n : v if v is not None else "" for n , v in nodes_it }
5253
5354
55+ def convert_type (tp : ty .Union [str , type ]) -> type :
56+ if not isinstance (tp , str ):
57+ return tp
58+ try :
59+ return from_mime (tp )
60+ except FormatRecognitionError :
61+ return eval (tp )
62+
63+
5464@attrs .define
5565class WorkflowInterfaceField :
5666
@@ -73,7 +83,7 @@ class WorkflowInterfaceField:
7383 )
7484 type : type = attrs .field (
7585 default = ty .Any ,
76- converter = lambda t : from_mime ( t ) if isinstance ( t , str ) else t ,
86+ converter = convert_type ,
7787 metadata = {
7888 "help" : "The type of the input/output of the converted workflow" ,
7989 },
@@ -117,6 +127,8 @@ def type_repr_(t):
117127 return t .primitive .__name__
118128 elif issubclass (t , FileSet ):
119129 return t .__name__
130+ elif t .__module__ == "builtins" :
131+ return t .__name__
120132 else :
121133 return f"{ t .__module__ } .{ t .__name__ } "
122134
@@ -154,6 +166,18 @@ class WorkflowInput(WorkflowInterfaceField):
154166 },
155167 )
156168
169+ include : bool = attrs .field (
170+ default = False ,
171+ eq = False ,
172+ hash = False ,
173+ metadata = {
174+ "help" : (
175+ "Whether the input is required for the workflow once the unused nodes "
176+ "have been filtered out"
177+ )
178+ },
179+ )
180+
157181 def __hash__ (self ):
158182 return super ().__hash__ ()
159183
@@ -321,6 +345,9 @@ class WorkflowConverter:
321345 _unprocessed_connections : ty .List [ConnectionStatement ] = attrs .field (
322346 factory = list , repr = False
323347 )
348+ used_inputs : ty .Optional [ty .Set [WorkflowInput ]] = attrs .field (
349+ default = None , repr = False
350+ )
324351
325352 def __attrs_post_init__ (self ):
326353 if self .workflow_variable is None :
@@ -383,7 +410,9 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
383410 escaped by the prefix of the node if present"""
384411 try :
385412 return self .make_input (
386- field_name = conn .target_in , node_name = conn .target_name , input_node_only = None
413+ field_name = conn .source_out ,
414+ node_name = conn .source_name ,
415+ input_node_only = None ,
387416 )
388417 except KeyError :
389418 pass
@@ -394,7 +423,7 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
394423 f"Could not find output corresponding to '{ conn .source_out } ' input"
395424 )
396425 return self .make_input (
397- field_name = conn .target_in , node_name = conn .target_name , input_node_only = True
426+ field_name = conn .source_out , node_name = conn .source_name , input_node_only = True
398427 )
399428
400429 def get_output_from_conn (self , conn : ConnectionStatement ) -> WorkflowOutput :
@@ -437,7 +466,7 @@ def make_input(
437466 if i .node_name == node_name and i .field == field_name
438467 ]
439468 if len (matching ) > 1 :
440- raise KeyError (
469+ raise RuntimeError (
441470 f"Multiple inputs found for '{ field_name } ' field in "
442471 f"'{ node_name } ' node in '{ self .name } ' workflow"
443472 )
@@ -481,7 +510,7 @@ def make_output(
481510 if o .node_name == node_name and o .field == field_name
482511 ]
483512 if len (matching ) > 1 :
484- raise KeyError (
513+ raise RuntimeError (
485514 f"Multiple outputs found for '{ field_name } ' field in "
486515 f"'{ node_name } ' node in '{ self .name } ' workflow: "
487516 + ", " .join (str (m ) for m in matching )
@@ -569,74 +598,6 @@ def add_connection_from_output(self, out_conn: ConnectionStatement):
569598 """Add a connection to an input of the workflow, adding the input if not present"""
570599 self ._add_output_conn (out_conn , "from" )
571600
572- # def _add_input_conn(self, conn: ConnectionStatement, direction: str = "in"):
573- # """Add an incoming connection to an input of the workflow, adding the input
574- # if not present"""
575- # if direction == "in":
576- # node_name = conn.target_name
577- # field_name = str(conn.target_in)
578- # else:
579- # node_name = conn.source_name
580- # field_name = str(conn.source_out)
581- # try:
582- # inpt = self._input_mapping[(node_name, field_name)]
583- # except KeyError:
584- # if node_name == self.input_node:
585- # inpt = WorkflowInput(
586- # name=field_name,
587- # node_name=self.input_node,
588- # field=field_name,
589- # )
590- # elif direction == "in":
591- # name = conn.source_out
592- # if conn.source_name != conn.workflow_converter.input_node:
593- # name = f"{conn.source_name}_{name}"
594- # inpt = WorkflowInput(
595- # name=name,
596- # node_name=self.input_node,
597- # field=field_name,
598- # )
599- # else:
600- # raise KeyError(
601- # f"Could not find input corresponding to '{field_name}' field in "
602- # f"'{conn.target_name}' node in '{self.name}' workflow"
603- # )
604- # self._input_mapping[(node_name, field_name)] = inpt
605- # self.inputs[field_name] = inpt
606-
607- # inpt.in_conns.append(conn)
608-
609- # def _add_output_conn(self, conn: ConnectionStatement, direction="in"):
610- # if direction == "from":
611- # node_name = conn.source_name
612- # field_name = str(conn.source_out)
613- # else:
614- # node_name = conn.target_name
615- # field_name = str(conn.target_in)
616- # try:
617- # outpt = self._output_mapping[(node_name, field_name)]
618- # except KeyError:
619- # if node_name == self.output_node:
620- # outpt = WorkflowOutput(
621- # name=field_name,
622- # node_name=self.output_node,
623- # field=field_name,
624- # )
625- # elif direction == "out":
626- # outpt = WorkflowOutput(
627- # name=field_name,
628- # node_name=self.output_node,
629- # field=field_name,
630- # )
631- # else:
632- # raise KeyError(
633- # f"Could not foutd output correspondoutg to '{field_name}' field out "
634- # f"'{conn.target_name}' node out '{self.name}' workflow"
635- # )
636- # self._output_mapping[(node_name, field_name)] = outpt
637- # self.outputs[field_name] = outpt
638- # outpt.out_conns.append(conn)
639-
640601 @cached_property
641602 def used_symbols (self ) -> UsedSymbols :
642603 return UsedSymbols .find (
@@ -651,13 +612,16 @@ def used_symbols(self) -> UsedSymbols:
651612 translations = self .package .all_import_translations ,
652613 )
653614
654- @cached_property
615+ @property
655616 def used_configs (self ) -> ty .List [str ]:
656617 return self ._converted_code [1 ]
657618
658- @cached_property
619+ @property
659620 def converted_code (self ) -> ty .List [str ]:
660- return self ._converted_code [0 ]
621+ try :
622+ return self ._converted_code [0 ]
623+ except AttributeError as e :
624+ raise RuntimeError ("caught AttributeError" ) from e
661625
662626 @cached_property
663627 def input_output_imports (self ) -> ty .List [ImportStatement ]:
@@ -667,10 +631,6 @@ def input_output_imports(self) -> ty.List[ImportStatement]:
667631 stmts .append (ImportStatement .from_object (tp ))
668632 return ImportStatement .collate (stmts )
669633
670- @cached_property
671- def inline_imports (self ) -> ty .List [str ]:
672- return [s for s in self .converted_code if isinstance (s , ImportStatement )]
673-
674634 @cached_property
675635 def func_src (self ):
676636 return inspect .getsource (self .nipype_function )
@@ -824,6 +784,10 @@ def _converted_code(self) -> ty.Tuple[str, ty.List[str]]:
824784 the names of the used configs
825785 """
826786
787+ for nested_workflow in self .nested_workflows .values ():
788+ # processing nested workflows first so we know which inputs are required
789+ nested_workflow ._converted_code
790+
827791 declaration , func_args , post = extract_args (self .func_src )
828792 return_types = post [1 :].split (":" , 1 )[0 ] # Get the return type
829793
@@ -846,14 +810,26 @@ def add_nonstd_types(tp):
846810
847811 while conn_stack :
848812 conn = conn_stack .pop ()
849- # Will only be included if connected from inputs to outputs, still coerces to
850- # false but
851- conn .include = 0
813+ # Will only be included if connected from inputs to outputs. If included
814+ # from input->output traversal nodes and conns are flagged as include=None,
815+ # because this coerces to False but is differentiable from False when we
816+ # come to do the traversal in the other direction
817+ conn .include = None
852818 if conn .target_name :
853819 sibling_target_nodes = self .nodes [conn .target_name ]
820+ exclude = True
854821 for target_node in sibling_target_nodes :
855- target_node .include = 0
856- conn_stack .extend (target_node .out_conns )
822+ # Check to see if the input is required, so we can change its include
823+ # flag back to false if not
824+ if (
825+ not isinstance (target_node , AddNestedWorkflowStatement )
826+ or target_node .nested_workflow .inputs [conn .target_in ].include
827+ ):
828+ target_node .include = None
829+ conn_stack .extend (target_node .out_conns )
830+ exclude = False
831+ if exclude :
832+ conn .include = False
857833
858834 # Walk through the graph backwards from the outputs and trim any unnecessary
859835 # connections
@@ -864,20 +840,26 @@ def add_nonstd_types(tp):
864840
865841 nonstd_types .discard (ty .Any )
866842
843+ self .used_inputs = set ()
844+
867845 while conn_stack :
868846 conn = conn_stack .pop ()
869- if (
870- conn .include == 0
871- ): # if included forward from inputs and backwards from outputs
847+ # if included forward from inputs and backwards from outputs
848+ if conn .include is None :
872849 conn .include = True
850+ else :
851+ continue
873852 if conn .source_name :
874853 sibling_source_nodes = self .nodes [conn .source_name ]
875854 for source_node in sibling_source_nodes :
876- if (
877- source_node .include == 0
878- ): # if included forward from inputs and backwards from outputs
855+ # if included forward from inputs and backwards from outputs
856+ if source_node .include is None :
879857 source_node .include = True
880858 conn_stack .extend (source_node .in_conns )
859+ else :
860+ inpt = self .inputs [conn .source_out ]
861+ inpt .include = True
862+ self .used_inputs .add (inpt )
881863
882864 preamble = ""
883865 statements = copy (self .parsed_statements )
@@ -901,7 +883,7 @@ def add_nonstd_types(tp):
901883 self .package .find_and_replace_config_params (code_str , nested_configs )
902884 )
903885
904- inputs_sig = [f"{ i } =attrs.NOTHING" for i in self .inputs ]
886+ inputs_sig = [f"{ i . name } =attrs.NOTHING" for i in self .used_inputs ]
905887
906888 # construct code string with modified signature
907889 signature = (
0 commit comments