Skip to content

Commit c55c945

Browse files
authored
Merge pull request #938 from mgxd/enh/workflow-splicer
ENH: Workflow splicer module
2 parents d2fda2a + 336f998 commit c55c945

File tree

3 files changed

+291
-0
lines changed

3 files changed

+291
-0
lines changed

niworkflows/engine/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
The fmriprep reporting engine for visual assessment
55
"""
66

7+
from .splicer import splice_workflow, tag
78
from .workflows import LiterateWorkflow as Workflow

niworkflows/engine/splicer.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""Workflow splicing operations."""
2+
3+
import logging
4+
import typing as ty
5+
6+
import nipype.pipeline.engine as pe
7+
from nipype.pipeline.engine.base import EngineBase
8+
9+
10+
def tag(tag: str) -> ty.Callable:
11+
"""
12+
Decorator to set a tag on an `init_...wf` function.
13+
14+
This is used to mark nodes or workflows for replacement in the splicing process.
15+
"""
16+
17+
def _decorator(func, *args, **kwargs) -> ty.Callable:
18+
def _tag() -> EngineBase:
19+
node = func(*args, **kwargs)
20+
node._tag = tag
21+
return node
22+
23+
return _tag
24+
25+
return _decorator
26+
27+
28+
def splice_workflow(
29+
root_wf: pe.Workflow,
30+
replacements: dict[str, EngineBase],
31+
*,
32+
write_graph: bool = False,
33+
debug: bool = False,
34+
):
35+
"""
36+
Splice a workflow's tagged nodes / workflows and replace connections with alternatives.
37+
38+
Requires that the workflow has been tagged with a `_tag` attribute.
39+
"""
40+
if write_graph:
41+
root_wf.write_graph('pre-slice.dot', format='png', graph2use='colored')
42+
43+
substitutions = _get_substitutions(root_wf, replacements)
44+
_splice_components(root_wf, substitutions, debug=debug)
45+
46+
if write_graph:
47+
root_wf.write_graph('post-slice.dot', format='png', graph2use='colored')
48+
return root_wf
49+
50+
51+
def _get_substitutions(
52+
workflow: pe.Workflow,
53+
replacements: dict[str, EngineBase],
54+
) -> dict[EngineBase, EngineBase]:
55+
""" "Query tags in workflow, and return a list of substitutions to make"""
56+
substitutions = {}
57+
tagged_wfs = _fetch_tags(workflow)
58+
for tag in tagged_wfs:
59+
if tag in replacements:
60+
substitutions[tagged_wfs[tag]] = replacements[tag]
61+
return substitutions
62+
63+
64+
def _fetch_tags(wf: pe.Workflow) -> dict[str, EngineBase]:
65+
"""Query all nodes in a workflow and return a dictionary of tags and nodes."""
66+
tagged = {}
67+
for node in wf._graph.nodes:
68+
if hasattr(node, '_tag'):
69+
tagged[node._tag] = node
70+
if isinstance(node, pe.Workflow):
71+
inner_tags = _fetch_tags(node)
72+
tagged.update(inner_tags)
73+
return tagged
74+
75+
76+
def _splice_components(
77+
workflow: pe.Workflow,
78+
substitutions: dict[EngineBase, EngineBase],
79+
debug: bool = False,
80+
) -> tuple[list, list]:
81+
"""Query all connections and return a list of removals and additions to be made."""
82+
edge_removals = []
83+
edge_connects = []
84+
node_removals = set()
85+
node_adds = set()
86+
_expanded_workflows = set()
87+
88+
to_replace = [x.fullname for x in substitutions]
89+
90+
for src, dst in workflow._graph.edges:
91+
if dst.fullname in to_replace:
92+
edge_data = workflow._graph.get_edge_data(src, dst)
93+
alt_dst = substitutions[dst]
94+
alt_dst._hierarchy = dst._hierarchy
95+
96+
edge_removals.append((src, dst))
97+
node_removals.add(dst)
98+
node_adds.add(alt_dst)
99+
edge_connects.append((src, alt_dst, edge_data))
100+
elif src.fullname in to_replace:
101+
edge_data = workflow._graph.get_edge_data(src, dst)
102+
alt_src = substitutions[src]
103+
alt_src._hierarchy = src._hierarchy
104+
105+
edge_removals.append((src, dst))
106+
node_removals.add(src)
107+
node_adds.add(alt_src)
108+
edge_connects.append((alt_src, dst, edge_data))
109+
elif isinstance(dst, pe.Workflow) and dst not in _expanded_workflows:
110+
_expanded_workflows.add(dst)
111+
_splice_components(dst, substitutions, debug=debug)
112+
elif isinstance(src, pe.Workflow) and src not in _expanded_workflows:
113+
_expanded_workflows.add(src)
114+
_splice_components(src, substitutions, debug=debug)
115+
116+
logger = logging.getLogger('nipype.workflow')
117+
logger.debug(
118+
'Workflow: %s, \n- edge_removals: %s, \n+ edge_connects: %s',
119+
workflow,
120+
edge_removals,
121+
edge_connects,
122+
)
123+
124+
workflow._graph.remove_edges_from(edge_removals)
125+
workflow.remove_nodes(node_removals)
126+
workflow.add_nodes(node_adds)
127+
workflow._graph.add_edges_from(edge_connects)
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
from nipype import Node, Workflow
5+
from nipype.interfaces.base import BaseInterfaceInputSpec, SimpleInterface, TraitedSpec, traits
6+
from nipype.interfaces.utility import IdentityInterface
7+
8+
from ..splicer import splice_workflow, tag
9+
10+
11+
class _NullInterfaceInputSpec(BaseInterfaceInputSpec):
12+
in1 = traits.Int(default=0, usedefault=True, desc='Input 1')
13+
in2 = traits.Int(default=0, usedefault=True, desc='Input 2')
14+
15+
16+
class _NullInterfaceOutputSpec(TraitedSpec):
17+
out1 = traits.Int(desc='Output 1')
18+
out2 = traits.Int(desc='Output 2')
19+
20+
21+
class NullInterface(SimpleInterface):
22+
"""
23+
A simple interface that does nothing.
24+
"""
25+
26+
input_spec = _NullInterfaceInputSpec
27+
output_spec = _NullInterfaceOutputSpec
28+
29+
def _run_interface(self, runtime):
30+
self._results['out1'] = self.inputs.in1
31+
self._results['out2'] = self.inputs.in2
32+
return runtime
33+
34+
35+
def _create_nested_null_wf(name: str, tag: str | None = None):
36+
wf = Workflow(name=name)
37+
if tag:
38+
wf._tag = tag
39+
40+
inputnode = Node(IdentityInterface(fields=['in1', 'in2']), name='inputnode')
41+
outputnode = Node(IdentityInterface(fields=['out1', 'out2']), name='outputnode')
42+
43+
n1 = Node(NullInterface(), name='null1')
44+
n2_wf = _create_null_wf('nested_wf', tag='nested')
45+
n3 = Node(NullInterface(), name='null3')
46+
47+
wf.connect([
48+
(inputnode, n1, [
49+
('in1', 'in1'),
50+
('in2', 'in2'),
51+
]),
52+
(n1, n2_wf, [('out1', 'inputnode.in1')]),
53+
(n2_wf, n3, [('outputnode.out1', 'in1')]),
54+
(n3, outputnode, [
55+
('out1', 'out1'),
56+
('out2', 'out2'),
57+
]),
58+
]) # fmt:skip
59+
return wf
60+
61+
62+
def _create_null_wf(name: str, tag: str | None = None):
63+
wf = Workflow(name=name)
64+
if tag:
65+
wf._tag = tag
66+
67+
inputnode = Node(IdentityInterface(fields=['in1', 'in2']), name='inputnode')
68+
outputnode = Node(IdentityInterface(fields=['out1', 'out2']), name='outputnode')
69+
70+
n1 = Node(NullInterface(), name='null1')
71+
n2 = Node(NullInterface(), name='null2')
72+
n3 = Node(NullInterface(), name='null3')
73+
74+
wf.connect([
75+
(inputnode, n1, [
76+
('in1', 'in1'),
77+
('in2', 'in2'),
78+
]),
79+
(n1, n2, [('out1', 'in1')]),
80+
(n2, n3, [('out1', 'in1')]),
81+
(n3, outputnode, [
82+
('out1', 'out1'),
83+
('out2', 'out2'),
84+
]),
85+
]) # fmt:skip
86+
return wf
87+
88+
89+
@pytest.fixture
90+
def wf0(tmp_path) -> Workflow:
91+
"""
92+
Create a tagged workflow.
93+
"""
94+
wf = Workflow(name='root', base_dir=tmp_path)
95+
wf._tag = 'root'
96+
97+
inputnode = Node(IdentityInterface(fields=['in1', 'in2']), name='inputnode')
98+
inputnode.inputs.in1 = 1
99+
inputnode.inputs.in2 = 2
100+
outputnode = Node(IdentityInterface(fields=['out1', 'out2']), name='outputnode')
101+
102+
a_in = Node(IdentityInterface(fields=['in1', 'in2']), name='a_in')
103+
a_wf = _create_null_wf('a_wf', tag='a')
104+
a_out = Node(IdentityInterface(fields=['out1', 'out2']), name='a_out')
105+
106+
b_in = Node(IdentityInterface(fields=['in1', 'in2']), name='b_in')
107+
b_wf = _create_nested_null_wf('b_wf', tag='b')
108+
b_out = Node(IdentityInterface(fields=['in1', 'out2']), name='b_out')
109+
110+
wf.connect([
111+
(inputnode, a_in, [
112+
('in1', 'in1'),
113+
('in2', 'in2'),
114+
]),
115+
(a_in, a_wf, [
116+
('in1', 'inputnode.in1'),
117+
('in2', 'inputnode.in2'),
118+
]),
119+
(a_wf, a_out, [
120+
('outputnode.out1', 'out1'),
121+
('outputnode.out2', 'out2'),
122+
]),
123+
(a_out, b_in, [
124+
('out1', 'in1'),
125+
('out2', 'in2'),
126+
]),
127+
(b_in, b_wf, [
128+
('in1', 'inputnode.in1'),
129+
('in2', 'inputnode.in2'),
130+
]),
131+
(b_wf, b_out, [
132+
('outputnode.out1', 'out1'),
133+
('outputnode.out2', 'out2'),
134+
]),
135+
(a_out, outputnode, [
136+
('out1', 'out1'),
137+
]),
138+
(b_out, outputnode, [
139+
('out2', 'out2'),
140+
]),
141+
]) # fmt:skip
142+
return wf
143+
144+
145+
def test_splice(wf0):
146+
replacements = {
147+
'a': _create_null_wf('a2_wf', tag='a'),
148+
'nested': _create_null_wf('nested2_wf', tag='nested'),
149+
'c': _create_null_wf('c_wf', tag='c'),
150+
}
151+
wf = splice_workflow(wf0, replacements)
152+
153+
assert wf.get_node('a2_wf')
154+
assert wf.get_node('b_wf').get_node('nested2_wf')
155+
assert wf.get_node('c_wf') is None
156+
157+
158+
def test_tag():
159+
@tag('foo')
160+
def init_workflow():
161+
return Workflow(name='foo')
162+
163+
assert init_workflow()._tag == 'foo'

0 commit comments

Comments
 (0)