Skip to content

Commit 0cac4cd

Browse files
committed
added option to explicitly route connections from/to explicit inputs and outputs
1 parent dc3214f commit 0cac4cd

File tree

5 files changed

+71
-13
lines changed

5 files changed

+71
-13
lines changed

example-specs/pkg-gen/niworkflows.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ niworkflows:
55
- niworkflows.interfaces.bids.ReadSidecarJSON
66
- niworkflows.interfaces.fixes.FixHeaderApplyTransforms
77
- niworkflows.interfaces.fixes.FixN4BiasFieldCorrection
8+
- niworkflows.interfaces.fixes.FixHeaderRegistration
89
- niworkflows.interfaces.header.SanitizeImage
910
- niworkflows.interfaces.images.RobustAverage
1011
- niworkflows.interfaces.morphology.BinaryDilation

nipype2pydra/package.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,10 +422,10 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True):
422422
intra_pkg_modules[conv.nipype_module_name].add(conv.nipype_object)
423423
collect_intra_pkg_objects(conv.used_symbols)
424424

425-
for converter in tqdm(
425+
for workflow in tqdm(
426426
workflows_to_include, "converting workflows from Nipype to Pydra syntax"
427427
):
428-
all_used = converter.write(
428+
all_used = workflow.write(
429429
package_root,
430430
already_converted=already_converted,
431431
)

nipype2pydra/pkg_gen/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def generate_callables(self, nipype_interface) -> str:
403403
if output_name not in INBUILT_NIPYPE_TRAIT_NAMES:
404404
callables_str += (
405405
f"def {output_name}_callable(output_dir, inputs, stdout, stderr):\n"
406-
" parsed_inputs = {}"
406+
" parsed_inputs = {}\n"
407407
" outputs = _list_outputs(output_dir=output_dir, inputs=inputs, stdout=stdout, stderr=stderr, parsed_inputs=parsed_inputs)\n"
408408
' return outputs["' + output_name + '"]\n\n'
409409
)
@@ -422,7 +422,7 @@ def generate_callables(self, nipype_interface) -> str:
422422
callables_str, fast=False, mode=black.FileMode()
423423
)
424424
except black.parsing.InvalidInput as e:
425-
with open(Path("~/Desktop/gen-code.py").expanduser(), "w") as f:
425+
with open(Path("~/unparsable-gen-code.py").expanduser(), "w") as f:
426426
f.write(callables_str)
427427
raise RuntimeError(
428428
f"Black could not parse generated code: {e}\n\n{callables_str}"
@@ -1026,9 +1026,7 @@ def process_method(
10261026
)
10271027
if hasattr(nipype_interface, "_cmd"):
10281028
body = body.replace("self.cmd", f'"{nipype_interface._cmd}"')
1029-
body = re.sub(
1030-
r"getattr\(self\.inputs, (\w+), None\)", r"inputs.get(\1)", body
1031-
)
1029+
body = re.sub(r"getattr\(self\.inputs, (\w+), None\)", r"inputs.get(\1)", body)
10321030
body = re.sub(r"getattr\(self\.inputs, (\w+)\)", r"inputs[\1]", body)
10331031
if attrs_as_parsed_inputs:
10341032
body = re.sub(

nipype2pydra/statements/workflow_build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class DynamicField(VarField):
106106
callable: ty.Callable = attrs.field()
107107

108108
def __repr__(self):
109-
return f"DelayedVarField({self.varname}, callable={self.callable})"
109+
return f"DynamicField({self.varname}, callable={self.callable})"
110110

111111

112112
@attrs.define

nipype2pydra/workflow.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
WorkflowInitStatement,
3737
AssignmentStatement,
3838
OtherStatement,
39+
DynamicField,
3940
)
4041
import nipype2pydra.package
4142

@@ -94,8 +95,7 @@ class WorkflowInterfaceField:
9495
factory=list,
9596
metadata={
9697
"help": (
97-
"node-name/field-name pairs of other fields that are to be routed to "
98-
"from other node fields to this input/output",
98+
"node-name/field-name pairs of additional fields that this input/output replaces",
9999
)
100100
},
101101
)
@@ -159,6 +159,11 @@ def __hash__(self):
159159
@attrs.define
160160
class WorkflowInput(WorkflowInterfaceField):
161161

162+
connections: ty.Tuple[ty.Tuple[str, str]] = attrs.field(
163+
converter=lambda lst: tuple(sorted(tuple(t) for t in lst)),
164+
factory=list,
165+
metadata={"help": ("Explicit connections to be made from this input field",)},
166+
)
162167
out_conns: ty.List[ConnectionStatement] = attrs.field(
163168
factory=list,
164169
eq=False,
@@ -170,9 +175,7 @@ class WorkflowInput(WorkflowInterfaceField):
170175
)
171176
},
172177
)
173-
174178
include: bool = attrs.field(
175-
default=False,
176179
eq=False,
177180
hash=False,
178181
metadata={
@@ -183,13 +186,22 @@ class WorkflowInput(WorkflowInterfaceField):
183186
},
184187
)
185188

189+
@include.default
190+
def _include_default(self) -> bool:
191+
return bool(self.connections)
192+
186193
def __hash__(self):
187194
return super().__hash__()
188195

189196

190197
@attrs.define
191198
class WorkflowOutput(WorkflowInterfaceField):
192199

200+
connection: ty.Tuple[str, str] = attrs.field(
201+
converter=tuple,
202+
factory=list,
203+
metadata={"help": ("Explicit connection to be made to this output field",)},
204+
)
193205
in_conns: ty.List[ConnectionStatement] = attrs.field(
194206
factory=list,
195207
eq=False,
@@ -413,6 +425,12 @@ def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput:
413425
"""
414426
Returns the name of the input field in the workflow for the given node and field
415427
escaped by the prefix of the node if present"""
428+
if isinstance(conn.source_out, DynamicField):
429+
logger.warning(
430+
f"Not able to connect inputs from {conn.source_name}:{conn.source_out}->"
431+
f"{conn.target_name}:{conn.target_in} properly due to adynamic-field "
432+
"just connecting to source input for now"
433+
)
416434
try:
417435
return self.make_input(
418436
field_name=conn.source_out,
@@ -1036,7 +1054,7 @@ def prepare_connections(self):
10361054
# append to parsed statements so set_output can be set
10371055
self.parsed_statements.append(conn_stmt)
10381056
while self._unprocessed_connections:
1039-
conn = self._unprocessed_connections.pop()
1057+
conn = self._unprocessed_connections.pop(0)
10401058
try:
10411059
inpt = self.get_input_from_conn(conn)
10421060
except KeyError:
@@ -1056,6 +1074,47 @@ def prepare_connections(self):
10561074
conn.target_in = outpt.name
10571075
outpt.in_conns.append(conn)
10581076

1077+
# Overwrite connections with explict connections
1078+
for inpt in list(self.inputs.values()):
1079+
for target_name, target_in in inpt.connections:
1080+
conn = ConnectionStatement(
1081+
indent=" ",
1082+
source_name=None,
1083+
source_out=inpt.name,
1084+
target_name=target_name,
1085+
target_in=target_in,
1086+
workflow_converter=self,
1087+
)
1088+
for tgt_node in self.nodes[conn.target_name]:
1089+
try:
1090+
existing_conn = next(
1091+
c for c in tgt_node.in_conns if c.target_in == target_in
1092+
)
1093+
except StopIteration:
1094+
pass
1095+
else:
1096+
tgt_node.in_conns.remove(existing_conn)
1097+
self.inputs[existing_conn.source_out].out_conns.remove(
1098+
existing_conn
1099+
)
1100+
inpt.out_conns.append(conn)
1101+
tgt_node.add_input_connection(conn)
1102+
1103+
for outpt in list(self.outputs.values()):
1104+
if outpt.connection:
1105+
source_name, source_out = outpt.connection
1106+
conn = ConnectionStatement(
1107+
indent=" ",
1108+
source_name=source_name,
1109+
source_out=source_out,
1110+
target_name=None,
1111+
target_in=outpt.name,
1112+
workflow_converter=self,
1113+
)
1114+
for src_node in self.nodes[conn.source_name]:
1115+
src_node.add_output_connection(conn)
1116+
outpt.in_conns.append(conn)
1117+
10591118
def _parse_statements(self, func_body: str) -> ty.Tuple[
10601119
ty.List[
10611120
ty.Union[

0 commit comments

Comments
 (0)