@@ -177,24 +177,21 @@ def targets(self):
177177
178178 @property
179179 def wf_in (self ):
180- if self .source_name is None :
180+ try :
181+ self .workflow_converter .get_input_from_conn (self )
182+ except KeyError :
183+ return False
184+ else :
181185 return True
182- for inpt in self .workflow_converter .inputs .values ():
183- if self .target_name == inpt .node_name and str (self .target_in ) == inpt .field :
184- return True
185- return False
186186
187187 @property
188188 def wf_out (self ):
189- if self .target_name is None :
189+ try :
190+ self .workflow_converter .get_output_from_conn (self )
191+ except KeyError :
192+ return False
193+ else :
190194 return True
191- for output in self .workflow_converter .outputs .values ():
192- if (
193- self .source_name == output .node_name
194- and str (self .source_out ) == output .field
195- ):
196- return True
197- return False
198195
199196 @cached_property
200197 def conditional (self ):
@@ -215,29 +212,11 @@ def workflow_variable(self):
215212
216213 @property
217214 def wf_in_name (self ):
218- if not self .wf_in :
219- raise ValueError (
220- f"Cannot get wf_in_name for { self } as it is not a workflow input"
221- )
222- if self .source_name is None :
223- return (
224- self .source_out
225- if not isinstance (self .source_out , DynamicField )
226- else self .source_out .varname
227- )
228- return self .workflow_converter .get_input (self .target_in , self .target_name ).name
215+ return self .workflow_converter .get_input_from_conn (self ).name
229216
230217 @property
231218 def wf_out_name (self ):
232- if not self .wf_out :
233- raise ValueError (
234- f"Cannot get wf_out_name for { self } as it is not a workflow output"
235- )
236- if self .target_name is None :
237- return self .target_in
238- return self .workflow_converter .get_output (
239- self .source_out , self .source_name
240- ).name
219+ return self .workflow_converter .get_output_from_conn (self ).name
241220
242221 def __str__ (self ):
243222 if not self .include :
@@ -274,7 +253,7 @@ def __str__(self):
274253 # to add an "identity" node to pass it through
275254 intf_name = f"{ base_task_name } _identity"
276255 code_str += (
277- f"{ self .indent } @pydra.mark.task\n "
256+ f"\n { self .indent } @pydra.mark.task\n "
278257 f"{ self .indent } def { intf_name } ({ self .wf_in_name } : ty.Any) -> ty.Any:\n "
279258 f"{ self .indent } return { self .wf_in_name } \n \n "
280259 f"{ self .indent } { self .workflow_variable } .add("
@@ -669,11 +648,29 @@ def add_input_connection(self, conn: ConnectionStatement):
669648 else :
670649 target_in = conn .target_in
671650 target_name = None
672- if target_name == self .nested_workflow .input_node :
651+ # Check for replacements for the given target field
652+ replacements = [
653+ i
654+ for i in self .nested_workflow .inputs .values ()
655+ if any (n == target_name and f == target_in for n , f in i .replaces )
656+ ]
657+ if len (replacements ) > 1 :
658+ raise ValueError (
659+ f"Multiple inputs found for replacements of '{ target_in } ' "
660+ f"field in '{ target_name } ' node in '{ self .name } ' workflow: "
661+ + ", " .join (str (m ) for m in replacements )
662+ )
663+ elif len (replacements ) == 1 :
664+ nested_input = replacements [0 ]
673665 target_name = None
674- nested_input = self .nested_workflow .get_input (
675- target_in , node_name = target_name , create = True
676- )
666+ else :
667+ # If no replacements, create an input for the nested workflow
668+ if target_name == self .nested_workflow .input_node :
669+ target_name = None
670+ nested_input = self .nested_workflow .make_input (
671+ target_in ,
672+ node_name = target_name ,
673+ )
677674 conn .target_in = nested_input .name
678675 super ().add_input_connection (conn )
679676 if target_name :
@@ -716,11 +713,26 @@ def add_output_connection(self, conn: ConnectionStatement):
716713 else :
717714 source_out = conn .source_out
718715 source_name = None
719- if source_name == self .nested_workflow .output_node :
716+ replacements = [
717+ o
718+ for o in self .nested_workflow .outputs .values ()
719+ if any (n == source_name and f == source_out for n , f in o .replaces )
720+ ]
721+ if len (replacements ) > 1 :
722+ raise KeyError (
723+ f"Multiple outputs found for replacements of '{ source_out } ' "
724+ f"field in '{ source_name } ' node in '{ self .name } ' workflow: "
725+ + ", " .join (str (m ) for m in replacements )
726+ )
727+ elif len (replacements ) == 1 :
728+ nested_output = replacements [0 ]
720729 source_name = None
721- nested_output = self .nested_workflow .get_output (
722- source_out , node_name = source_name , create = True
723- )
730+ else :
731+ if source_name == self .nested_workflow .output_node :
732+ source_name = None
733+ nested_output = self .nested_workflow .make_output (
734+ source_out , node_name = source_name
735+ )
724736 conn .source_out = nested_output .name
725737 super ().add_output_connection (conn )
726738 if source_name :
@@ -759,7 +771,7 @@ def __str__(self):
759771 parts = self .attribute .split ("." )
760772 nested_node_name = parts [2 ]
761773 attribute_name = parts [3 ]
762- target_in = nested_wf .get_input (attribute_name , nested_node_name ).name
774+ target_in = nested_wf .make_input (attribute_name , nested_node_name ).name
763775 attribute = "." .join (parts [:2 ] + [target_in ] + parts [4 :])
764776 workflow_variable = self .nodes [0 ].workflow_variable
765777 assert (n .workflow_variable == workflow_variable for n in self .nodes )
@@ -782,6 +794,10 @@ def matches(cls, stmt, node_names: ty.List[str]) -> bool:
782794 return False
783795 return bool (cls .match_re (node_names ).match (stmt ))
784796
797+ @property
798+ def conditional (self ):
799+ return len (self .indent ) != 4
800+
785801 @classmethod
786802 def parse (
787803 cls , statement : str , workflow_converter : "WorkflowConverter"
0 commit comments