1212import black .report
1313import attrs
1414import yaml
15- from fileformats .core import from_mime , FileSet
15+ from fileformats .core import from_mime , FileSet , Field
1616from .utils import (
1717 UsedSymbols ,
1818 split_source_into_statements ,
@@ -114,6 +114,8 @@ def type_repr_(t):
114114 )
115115 if t in (ty .Any , ty .Union , ty .List , ty .Tuple ):
116116 return f"ty.{ t .__name__ } "
117+ elif issubclass (t , Field ):
118+ return t .primative .__name__
117119 elif issubclass (t , FileSet ):
118120 return t .__name__
119121 else :
@@ -407,7 +409,7 @@ def exported_outputs(self):
407409 return (o for o in self .outputs .values () if o .export )
408410
409411 def get_input (
410- self , field_name : str , node_name : ty .Optional [str ] = None
412+ self , field_name : str , node_name : ty .Optional [str ] = None , create : bool = False
411413 ) -> WorkflowInput :
412414 """
413415 Returns the name of the input field in the workflow for the given node and field
@@ -416,17 +418,21 @@ def get_input(
416418 try :
417419 return self ._input_mapping [(node_name , field_name )]
418420 except KeyError :
419- inpt_name = (
420- field_name
421- if node_name is None or node_name == self .input_node
422- else f"{ node_name } _{ field_name } "
423- )
421+ if node_name is None or node_name == self .input_node :
422+ inpt_name = field_name
423+ elif create :
424+ inpt_name = f"{ node_name } _{ field_name } "
425+ else :
426+ raise KeyError (
427+ f"Unrecognised output corresponding to { node_name } :{ field_name } field, "
428+ "set create=True to auto-create"
429+ )
424430 inpt = WorkflowInput (name = inpt_name , field = field_name , node_name = node_name )
425431 self .inputs [inpt_name ] = self ._input_mapping [(node_name , field_name )] = inpt
426432 return inpt
427433
428434 def get_output (
429- self , field_name : str , node_name : ty .Optional [str ] = None
435+ self , field_name : str , node_name : ty .Optional [str ] = None , create : bool = False
430436 ) -> WorkflowOutput :
431437 """
432438 Returns the name of the input field in the workflow for the given node and field
@@ -435,11 +441,15 @@ def get_output(
435441 try :
436442 return self ._output_mapping [(node_name , field_name )]
437443 except KeyError :
438- outpt_name = (
439- field_name
440- if node_name is None or node_name == self .input_node
441- else f"{ node_name } _{ field_name } "
442- )
444+ if node_name is None or node_name == self .output_node :
445+ outpt_name = field_name
446+ elif create :
447+ outpt_name = f"{ node_name } _{ field_name } "
448+ else :
449+ raise KeyError (
450+ f"Unrecognised output corresponding to { node_name } :{ field_name } field, "
451+ "set create=True to auto-create"
452+ )
443453 outpt = WorkflowOutput (
444454 name = outpt_name , field = field_name , node_name = node_name
445455 )
@@ -923,11 +933,11 @@ def prepare_connections(self):
923933 for node in nodes :
924934 if isinstance (node , AddNestedWorkflowStatement ):
925935 exported_inputs .update (
926- (i .name , self .get_input (i .name , node_name ))
936+ (i .name , self .get_input (i .name , node_name , create = True ))
927937 for i in node .nested_workflow .exported_inputs
928938 )
929939 exported_outputs .update (
930- (o .name , self .get_output (o .name , node_name ))
940+ (o .name , self .get_output (o .name , node_name , create = True ))
931941 for o in node .nested_workflow .exported_outputs
932942 )
933943 for inpt_name , exp_inpt in exported_inputs :
@@ -957,16 +967,20 @@ def prepare_connections(self):
957967 self .parsed_statements .append (conn_stmt )
958968 while self ._unprocessed_connections :
959969 conn = self ._unprocessed_connections .pop ()
960- if conn . wf_in :
961- self .get_input (conn .source_out ). out_conns . append ( conn )
962- else :
970+ try :
971+ inpt = self .get_input (conn .source_out , node_name = conn . source_name )
972+ except KeyError :
963973 for src_node in self .nodes [conn .source_name ]:
964974 src_node .add_output_connection (conn )
965- if conn .wf_out :
966- self .get_output (conn .target_in ).in_conns .append (conn )
967975 else :
976+ inpt .out_conns .append (conn )
977+ try :
978+ outpt = self .get_output (conn .target_in , node_name = conn .target_name )
979+ except KeyError :
968980 for tgt_node in self .nodes [conn .target_name ]:
969981 tgt_node .add_input_connection (conn )
982+ else :
983+ outpt .in_conns .append (conn )
970984
971985 def _parse_statements (self , func_body : str ) -> ty .Tuple [
972986 ty .List [
0 commit comments