36
36
WorkflowInitStatement ,
37
37
AssignmentStatement ,
38
38
OtherStatement ,
39
+ DynamicField ,
39
40
)
40
41
import nipype2pydra .package
41
42
@@ -94,8 +95,7 @@ class WorkflowInterfaceField:
94
95
factory = list ,
95
96
metadata = {
96
97
"help" : (
97
- "node-name/field-name pairs of other fields that are to be routed to "
98
- "from other node fields to this input/output" ,
98
+ "node-name/field-name pairs of additional fields that this input/output replaces" ,
99
99
)
100
100
},
101
101
)
@@ -159,6 +159,11 @@ def __hash__(self):
159
159
@attrs .define
160
160
class WorkflowInput (WorkflowInterfaceField ):
161
161
162
+ connections : ty .Tuple [ty .Tuple [str , str ]] = attrs .field (
163
+ converter = lambda lst : tuple (sorted (tuple (t ) for t in lst )),
164
+ factory = list ,
165
+ metadata = {"help" : ("Explicit connections to be made from this input field" ,)},
166
+ )
162
167
out_conns : ty .List [ConnectionStatement ] = attrs .field (
163
168
factory = list ,
164
169
eq = False ,
@@ -170,9 +175,7 @@ class WorkflowInput(WorkflowInterfaceField):
170
175
)
171
176
},
172
177
)
173
-
174
178
include : bool = attrs .field (
175
- default = False ,
176
179
eq = False ,
177
180
hash = False ,
178
181
metadata = {
@@ -183,13 +186,22 @@ class WorkflowInput(WorkflowInterfaceField):
183
186
},
184
187
)
185
188
189
+ @include .default
190
+ def _include_default (self ) -> bool :
191
+ return bool (self .connections )
192
+
186
193
def __hash__ (self ):
187
194
return super ().__hash__ ()
188
195
189
196
190
197
@attrs .define
191
198
class WorkflowOutput (WorkflowInterfaceField ):
192
199
200
+ connection : ty .Tuple [str , str ] = attrs .field (
201
+ converter = tuple ,
202
+ factory = list ,
203
+ metadata = {"help" : ("Explicit connection to be made to this output field" ,)},
204
+ )
193
205
in_conns : ty .List [ConnectionStatement ] = attrs .field (
194
206
factory = list ,
195
207
eq = False ,
@@ -413,6 +425,12 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
413
425
"""
414
426
Returns the name of the input field in the workflow for the given node and field
415
427
escaped by the prefix of the node if present"""
428
+ if isinstance (conn .source_out , DynamicField ):
429
+ logger .warning (
430
+ f"Not able to connect inputs from { conn .source_name } :{ conn .source_out } ->"
431
+ f"{ conn .target_name } :{ conn .target_in } properly due to adynamic-field "
432
+ "just connecting to source input for now"
433
+ )
416
434
try :
417
435
return self .make_input (
418
436
field_name = conn .source_out ,
@@ -1036,7 +1054,7 @@ def prepare_connections(self):
1036
1054
# append to parsed statements so set_output can be set
1037
1055
self .parsed_statements .append (conn_stmt )
1038
1056
while self ._unprocessed_connections :
1039
- conn = self ._unprocessed_connections .pop ()
1057
+ conn = self ._unprocessed_connections .pop (0 )
1040
1058
try :
1041
1059
inpt = self .get_input_from_conn (conn )
1042
1060
except KeyError :
@@ -1056,6 +1074,47 @@ def prepare_connections(self):
1056
1074
conn .target_in = outpt .name
1057
1075
outpt .in_conns .append (conn )
1058
1076
1077
+ # Overwrite connections with explict connections
1078
+ for inpt in list (self .inputs .values ()):
1079
+ for target_name , target_in in inpt .connections :
1080
+ conn = ConnectionStatement (
1081
+ indent = " " ,
1082
+ source_name = None ,
1083
+ source_out = inpt .name ,
1084
+ target_name = target_name ,
1085
+ target_in = target_in ,
1086
+ workflow_converter = self ,
1087
+ )
1088
+ for tgt_node in self .nodes [conn .target_name ]:
1089
+ try :
1090
+ existing_conn = next (
1091
+ c for c in tgt_node .in_conns if c .target_in == target_in
1092
+ )
1093
+ except StopIteration :
1094
+ pass
1095
+ else :
1096
+ tgt_node .in_conns .remove (existing_conn )
1097
+ self .inputs [existing_conn .source_out ].out_conns .remove (
1098
+ existing_conn
1099
+ )
1100
+ inpt .out_conns .append (conn )
1101
+ tgt_node .add_input_connection (conn )
1102
+
1103
+ for outpt in list (self .outputs .values ()):
1104
+ if outpt .connection :
1105
+ source_name , source_out = outpt .connection
1106
+ conn = ConnectionStatement (
1107
+ indent = " " ,
1108
+ source_name = source_name ,
1109
+ source_out = source_out ,
1110
+ target_name = None ,
1111
+ target_in = outpt .name ,
1112
+ workflow_converter = self ,
1113
+ )
1114
+ for src_node in self .nodes [conn .source_name ]:
1115
+ src_node .add_output_connection (conn )
1116
+ outpt .in_conns .append (conn )
1117
+
1059
1118
def _parse_statements (self , func_body : str ) -> ty .Tuple [
1060
1119
ty .List [
1061
1120
ty .Union [
0 commit comments