Skip to content

Commit 784539e

Browse files
committed
ENH: Add workflow splicer module
1 parent d2fda2a commit 784539e

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed

niworkflows/engine/splicer.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Workflow splicing operations."""
2+
3+
import typing as ty
4+
5+
from nipype.pipeline import Workflow
6+
from nipype.pipeline.engine.base import EngineBase
7+
8+
9+
def tag(tag: str) -> 'EngineBase':
10+
"""
11+
Decorator to set a tag on an `init_...wf` function.
12+
13+
This is used to mark nodes or workflows for replacement in the splicing process.
14+
"""
15+
16+
def _decorator(func, *args, **kwargs) -> ty.Callable:
17+
def _tag() -> 'EngineBase':
18+
node = func(*args, **kwargs)
19+
node._tag = tag
20+
return node
21+
22+
return _tag
23+
24+
return _decorator
25+
26+
27+
def splice_workflow(
28+
root_wf: Workflow,
29+
replacements: dict[str, 'EngineBase'],
30+
*,
31+
write_graph: bool = False,
32+
debug: bool = False,
33+
):
34+
"""
35+
Splice a workflow's tagged nodes / workflows and replace connections with alternatives.
36+
37+
Requires that the workflow has been tagged with a `_tag` attribute.
38+
"""
39+
if write_graph:
40+
root_wf.write_graph('pre-slice.dot', format='png', graph2use='colored')
41+
42+
substitutions = _get_substitutions(root_wf, replacements)
43+
print(f'Substitutions: {substitutions}')
44+
_splice_components(root_wf, substitutions, debug=debug)
45+
46+
if write_graph:
47+
root_wf.write_graph('post-slice.dot', format='png', graph2use='colored')
48+
return root_wf
49+
50+
51+
def _get_substitutions(
52+
wf: Workflow,
53+
replacements: dict[str, 'EngineBase'],
54+
) -> dict['EngineBase', 'EngineBase']:
55+
""" "Query tags in workflow, and return a list of substitutions to make"""
56+
substitutions = {}
57+
tagged_wfs = _fetch_tags(wf)
58+
for tag in tagged_wfs:
59+
if tag in replacements:
60+
substitutions[tagged_wfs[tag]] = replacements[tag]
61+
return substitutions
62+
63+
64+
def _fetch_tags(wf: Workflow) -> dict[str, 'EngineBase']:
65+
"""Query all nodes in a workflow and return a dictionary of tags and nodes."""
66+
tagged = {}
67+
print(f'Querying {wf}')
68+
for node in wf._graph.nodes:
69+
if hasattr(node, '_tag'):
70+
tagged[node._tag] = node
71+
if isinstance(node, Workflow):
72+
inner_tags = _fetch_tags(node)
73+
tagged.update(inner_tags)
74+
return tagged
75+
76+
77+
def _splice_components(
78+
workflow: Workflow,
79+
substitutions: dict['EngineBase', 'EngineBase'],
80+
debug: bool = False,
81+
) -> tuple[list, list]:
82+
"""Query all connections and return a list of removals and additions to be made."""
83+
edge_removals = []
84+
edge_connects = []
85+
node_removals = set()
86+
node_adds = set()
87+
_expanded_workflows = set()
88+
89+
to_replace = [x.fullname for x in substitutions]
90+
91+
for src, dst in workflow._graph.edges: # will not expand workflows, but needs to
92+
if dst.fullname in to_replace:
93+
edge_data = workflow._graph.get_edge_data(src, dst)
94+
alt_dst = substitutions[dst]
95+
alt_dst._hierarchy = dst._hierarchy
96+
97+
edge_removals.append((src, dst))
98+
node_removals.add(dst)
99+
node_adds.add(alt_dst)
100+
edge_connects.append((src, alt_dst, edge_data))
101+
elif src.fullname in to_replace:
102+
edge_data = workflow._graph.get_edge_data(src, dst)
103+
alt_src = substitutions[src]
104+
alt_src._hierarchy = src._hierarchy
105+
106+
edge_removals.append((src, dst))
107+
node_removals.add(src)
108+
node_adds.add(alt_src)
109+
edge_connects.append((alt_src, dst, edge_data))
110+
elif isinstance(dst, Workflow) and dst not in _expanded_workflows:
111+
_expanded_workflows.add(dst)
112+
_splice_components(dst, substitutions, debug=debug)
113+
elif isinstance(src, Workflow) and src not in _expanded_workflows:
114+
_expanded_workflows.add(src)
115+
_splice_components(src, substitutions, debug=debug)
116+
117+
if debug:
118+
print(f'Workflow: {workflow}')
119+
print(f'- Removing: {edge_removals}')
120+
print(f'+ Adding: {edge_connects}')
121+
122+
workflow._graph.remove_edges_from(edge_removals)
123+
workflow.remove_nodes(node_removals)
124+
workflow.add_nodes(node_adds)
125+
workflow._graph.add_edges_from(edge_connects)

0 commit comments

Comments
 (0)