Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions niworkflows/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
The fmriprep reporting engine for visual assessment
"""

from .splicer import splice_workflow, tag
from .workflows import LiterateWorkflow as Workflow
123 changes: 123 additions & 0 deletions niworkflows/engine/splicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Workflow splicing operations."""

import typing as ty

from nipype.pipeline import Workflow
from nipype.pipeline.engine.base import EngineBase


def tag(tag: str) -> 'EngineBase':
"""
Decorator to set a tag on an `init_...wf` function.

This is used to mark nodes or workflows for replacement in the splicing process.
"""

def _decorator(func, *args, **kwargs) -> ty.Callable:
def _tag() -> 'EngineBase':
node = func(*args, **kwargs)
node._tag = tag
return node

Check warning on line 20 in niworkflows/engine/splicer.py

View check run for this annotation

Codecov / codecov/patch

niworkflows/engine/splicer.py#L16-L20

Added lines #L16 - L20 were not covered by tests

return _tag

Check warning on line 22 in niworkflows/engine/splicer.py

View check run for this annotation

Codecov / codecov/patch

niworkflows/engine/splicer.py#L22

Added line #L22 was not covered by tests

return _decorator

Check warning on line 24 in niworkflows/engine/splicer.py

View check run for this annotation

Codecov / codecov/patch

niworkflows/engine/splicer.py#L24

Added line #L24 was not covered by tests


def splice_workflow(
root_wf: Workflow,
replacements: dict[str, 'EngineBase'],
*,
write_graph: bool = False,
debug: bool = False,
):
"""
Splice a workflow's tagged nodes / workflows and replace connections with alternatives.

Requires that the workflow has been tagged with a `_tag` attribute.
"""
if write_graph:
root_wf.write_graph('pre-slice.dot', format='png', graph2use='colored')

Check warning on line 40 in niworkflows/engine/splicer.py

View check run for this annotation

Codecov / codecov/patch

niworkflows/engine/splicer.py#L40

Added line #L40 was not covered by tests

substitutions = _get_substitutions(root_wf, replacements)
_splice_components(root_wf, substitutions, debug=debug)

if write_graph:
root_wf.write_graph('post-slice.dot', format='png', graph2use='colored')

Check warning on line 46 in niworkflows/engine/splicer.py

View check run for this annotation

Codecov / codecov/patch

niworkflows/engine/splicer.py#L46

Added line #L46 was not covered by tests
return root_wf


def _get_substitutions(
wf: Workflow,
replacements: dict[str, 'EngineBase'],
) -> dict['EngineBase', 'EngineBase']:
""" "Query tags in workflow, and return a list of substitutions to make"""
substitutions = {}
tagged_wfs = _fetch_tags(wf)
for tag in tagged_wfs:
if tag in replacements:
substitutions[tagged_wfs[tag]] = replacements[tag]
return substitutions


def _fetch_tags(wf: Workflow) -> dict[str, 'EngineBase']:
"""Query all nodes in a workflow and return a dictionary of tags and nodes."""
tagged = {}
for node in wf._graph.nodes:
if hasattr(node, '_tag'):
tagged[node._tag] = node
if isinstance(node, Workflow):
inner_tags = _fetch_tags(node)
tagged.update(inner_tags)
return tagged


def _splice_components(
workflow: Workflow,
substitutions: dict['EngineBase', 'EngineBase'],
debug: bool = False,
) -> tuple[list, list]:
"""Query all connections and return a list of removals and additions to be made."""
edge_removals = []
edge_connects = []
node_removals = set()
node_adds = set()
_expanded_workflows = set()

to_replace = [x.fullname for x in substitutions]

for src, dst in workflow._graph.edges: # will not expand workflows, but needs to
if dst.fullname in to_replace:
edge_data = workflow._graph.get_edge_data(src, dst)
alt_dst = substitutions[dst]
alt_dst._hierarchy = dst._hierarchy

edge_removals.append((src, dst))
node_removals.add(dst)
node_adds.add(alt_dst)
edge_connects.append((src, alt_dst, edge_data))
elif src.fullname in to_replace:
edge_data = workflow._graph.get_edge_data(src, dst)
alt_src = substitutions[src]
alt_src._hierarchy = src._hierarchy

edge_removals.append((src, dst))
node_removals.add(src)
node_adds.add(alt_src)
edge_connects.append((alt_src, dst, edge_data))
elif isinstance(dst, Workflow) and dst not in _expanded_workflows:
_expanded_workflows.add(dst)
_splice_components(dst, substitutions, debug=debug)
elif isinstance(src, Workflow) and src not in _expanded_workflows:
_expanded_workflows.add(src)
_splice_components(src, substitutions, debug=debug)

Check warning on line 113 in niworkflows/engine/splicer.py

View check run for this annotation

Codecov / codecov/patch

niworkflows/engine/splicer.py#L112-L113

Added lines #L112 - L113 were not covered by tests

if debug:
print(f'Workflow: {workflow}')
print(f'- Removing: {edge_removals}')
print(f'+ Adding: {edge_connects}')

Check warning on line 118 in niworkflows/engine/splicer.py

View check run for this annotation

Codecov / codecov/patch

niworkflows/engine/splicer.py#L116-L118

Added lines #L116 - L118 were not covered by tests

workflow._graph.remove_edges_from(edge_removals)
workflow.remove_nodes(node_removals)
workflow.add_nodes(node_adds)
workflow._graph.add_edges_from(edge_connects)
155 changes: 155 additions & 0 deletions niworkflows/engine/tests/test_splicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from __future__ import annotations

import pytest
from nipype import Node, Workflow
from nipype.interfaces.base import BaseInterfaceInputSpec, SimpleInterface, TraitedSpec, traits
from nipype.interfaces.utility import IdentityInterface

from ..splicer import splice_workflow


class _NullInterfaceInputSpec(BaseInterfaceInputSpec):
in1 = traits.Int(default=0, usedefault=True, desc='Input 1')
in2 = traits.Int(default=0, usedefault=True, desc='Input 2')


class _NullInterfaceOutputSpec(TraitedSpec):
out1 = traits.Int(desc='Output 1')
out2 = traits.Int(desc='Output 2')


class NullInterface(SimpleInterface):
"""
A simple interface that does nothing.
"""

input_spec = _NullInterfaceInputSpec
output_spec = _NullInterfaceOutputSpec

def _run_interface(self, runtime):
self._results['out1'] = self.inputs.in1
self._results['out2'] = self.inputs.in2
return runtime

Check warning on line 32 in niworkflows/engine/tests/test_splicer.py

View check run for this annotation

Codecov / codecov/patch

niworkflows/engine/tests/test_splicer.py#L30-L32

Added lines #L30 - L32 were not covered by tests


def _create_nested_null_wf(name: str, tag: str | None = None):
wf = Workflow(name=name)
if tag:
wf._tag = tag

inputnode = Node(IdentityInterface(fields=['in1', 'in2']), name='inputnode')
outputnode = Node(IdentityInterface(fields=['out1', 'out2']), name='outputnode')

n1 = Node(NullInterface(), name='null1')
n2_wf = _create_null_wf('nested_wf', tag='nested')
n3 = Node(NullInterface(), name='null3')

wf.connect([
(inputnode, n1, [
('in1', 'in1'),
('in2', 'in2'),
]),
(n1, n2_wf, [('out1', 'inputnode.in1')]),
(n2_wf, n3, [('outputnode.out1', 'in1')]),
(n3, outputnode, [
('out1', 'out1'),
('out2', 'out2'),
]),
]) # fmt:skip
return wf


def _create_null_wf(name: str, tag: str | None = None):
wf = Workflow(name=name)
if tag:
wf._tag = tag

inputnode = Node(IdentityInterface(fields=['in1', 'in2']), name='inputnode')
outputnode = Node(IdentityInterface(fields=['out1', 'out2']), name='outputnode')

n1 = Node(NullInterface(), name='null1')
n2 = Node(NullInterface(), name='null2')
n3 = Node(NullInterface(), name='null3')

wf.connect([
(inputnode, n1, [
('in1', 'in1'),
('in2', 'in2'),
]),
(n1, n2, [('out1', 'in1')]),
(n2, n3, [('out1', 'in1')]),
(n3, outputnode, [
('out1', 'out1'),
('out2', 'out2'),
]),
]) # fmt:skip
return wf


@pytest.fixture
def wf0(tmp_path) -> Workflow:
"""
Create a tagged workflow.
"""
wf = Workflow(name='root', base_dir=tmp_path)
wf._tag = 'root'

inputnode = Node(IdentityInterface(fields=['in1', 'in2']), name='inputnode')
inputnode.inputs.in1 = 1
inputnode.inputs.in2 = 2
outputnode = Node(IdentityInterface(fields=['out1', 'out2']), name='outputnode')

a_in = Node(IdentityInterface(fields=['in1', 'in2']), name='a_in')
a_wf = _create_null_wf('a_wf', tag='a')
a_out = Node(IdentityInterface(fields=['out1', 'out2']), name='a_out')

b_in = Node(IdentityInterface(fields=['in1', 'in2']), name='b_in')
b_wf = _create_nested_null_wf('b_wf', tag='b')
b_out = Node(IdentityInterface(fields=['in1', 'out2']), name='b_out')

wf.connect([
(inputnode, a_in, [
('in1', 'in1'),
('in2', 'in2'),
]),
(a_in, a_wf, [
('in1', 'inputnode.in1'),
('in2', 'inputnode.in2'),
]),
(a_wf, a_out, [
('outputnode.out1', 'out1'),
('outputnode.out2', 'out2'),
]),
(a_out, b_in, [
('out1', 'in1'),
('out2', 'in2'),
]),
(b_in, b_wf, [
('in1', 'inputnode.in1'),
('in2', 'inputnode.in2'),
]),
(b_wf, b_out, [
('outputnode.out1', 'out1'),
('outputnode.out2', 'out2'),
]),
(a_out, outputnode, [
('out1', 'out1'),
]),
(b_out, outputnode, [
('out2', 'out2'),
]),
]) # fmt:skip
return wf


def test_splice(wf0):
replacements = {
'a': _create_null_wf('a2_wf', tag='a'),
'nested': _create_null_wf('nested2_wf', tag='nested'),
'c': _create_null_wf('c_wf', tag='c'),
}
wf = splice_workflow(wf0, replacements)

assert wf.get_node('a2_wf')
assert wf.get_node('b_wf').get_node('nested2_wf')
assert wf.get_node('c_wf') is None
Loading