Skip to content

Commit ad6a89d

Browse files
committed
TST: Add splice test
1 parent 784539e commit ad6a89d

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
from nipype import Node, Workflow
5+
from nipype.interfaces.base import BaseInterfaceInputSpec, SimpleInterface, TraitedSpec, traits
6+
from nipype.interfaces.utility import IdentityInterface
7+
8+
from ..splicer import splice_workflow
9+
10+
11+
class _NullInterfaceInputSpec(BaseInterfaceInputSpec):
12+
in1 = traits.Int(default=0, usedefault=True, desc='Input 1')
13+
in2 = traits.Int(default=0, usedefault=True, desc='Input 2')
14+
15+
16+
class _NullInterfaceOutputSpec(TraitedSpec):
17+
out1 = traits.Int(desc='Output 1')
18+
out2 = traits.Int(desc='Output 2')
19+
20+
21+
class NullInterface(SimpleInterface):
22+
"""
23+
A simple interface that does nothing.
24+
"""
25+
26+
input_spec = _NullInterfaceInputSpec
27+
output_spec = _NullInterfaceOutputSpec
28+
29+
def _run_interface(self, runtime):
30+
self._results['out1'] = self.inputs.in1
31+
self._results['out2'] = self.inputs.in2
32+
return runtime
33+
34+
35+
def _create_nested_null_wf(name: str, tag: str | None = None):
36+
wf = Workflow(name=name)
37+
if tag:
38+
wf._tag = tag
39+
40+
inputnode = Node(IdentityInterface(fields=['in1', 'in2']), name='inputnode')
41+
outputnode = Node(IdentityInterface(fields=['out1', 'out2']), name='outputnode')
42+
43+
n1 = Node(NullInterface(), name='null1')
44+
n2_wf = _create_null_wf('nested_wf', tag='nested')
45+
n3 = Node(NullInterface(), name='null3')
46+
47+
wf.connect([
48+
(inputnode, n1, [
49+
('in1', 'in1'),
50+
('in2', 'in2'),
51+
]),
52+
(n1, n2_wf, [('out1', 'inputnode.in1')]),
53+
(n2_wf, n3, [('outputnode.out1', 'in1')]),
54+
(n3, outputnode, [
55+
('out1', 'out1'),
56+
('out2', 'out2'),
57+
]),
58+
]) # fmt:skip
59+
return wf
60+
61+
62+
def _create_null_wf(name: str, tag: str | None = None):
63+
wf = Workflow(name=name)
64+
if tag:
65+
wf._tag = tag
66+
67+
inputnode = Node(IdentityInterface(fields=['in1', 'in2']), name='inputnode')
68+
outputnode = Node(IdentityInterface(fields=['out1', 'out2']), name='outputnode')
69+
70+
n1 = Node(NullInterface(), name='null1')
71+
n2 = Node(NullInterface(), name='null2')
72+
n3 = Node(NullInterface(), name='null3')
73+
74+
wf.connect([
75+
(inputnode, n1, [
76+
('in1', 'in1'),
77+
('in2', 'in2'),
78+
]),
79+
(n1, n2, [('out1', 'in1')]),
80+
(n2, n3, [('out1', 'in1')]),
81+
(n3, outputnode, [
82+
('out1', 'out1'),
83+
('out2', 'out2'),
84+
]),
85+
]) # fmt:skip
86+
return wf
87+
88+
89+
@pytest.fixture
90+
def wf0(tmp_path) -> Workflow:
91+
"""
92+
Create a tagged workflow.
93+
"""
94+
wf = Workflow(name='root', base_dir=tmp_path)
95+
wf._tag = 'root'
96+
97+
inputnode = Node(IdentityInterface(fields=['in1', 'in2']), name='inputnode')
98+
inputnode.inputs.in1 = 1
99+
inputnode.inputs.in2 = 2
100+
outputnode = Node(IdentityInterface(fields=['out1', 'out2']), name='outputnode')
101+
102+
a_in = Node(IdentityInterface(fields=['in1', 'in2']), name='a_in')
103+
a_wf = _create_null_wf('a_wf', tag='a')
104+
a_out = Node(IdentityInterface(fields=['out1', 'out2']), name='a_out')
105+
106+
b_in = Node(IdentityInterface(fields=['in1', 'in2']), name='b_in')
107+
b_wf = _create_nested_null_wf('b_wf', tag='b')
108+
b_out = Node(IdentityInterface(fields=['in1', 'out2']), name='b_out')
109+
110+
wf.connect([
111+
(inputnode, a_in, [
112+
('in1', 'in1'),
113+
('in2', 'in2'),
114+
]),
115+
(a_in, a_wf, [
116+
('in1', 'inputnode.in1'),
117+
('in2', 'inputnode.in2'),
118+
]),
119+
(a_wf, a_out, [
120+
('outputnode.out1', 'out1'),
121+
('outputnode.out2', 'out2'),
122+
]),
123+
(a_out, b_in, [
124+
('out1', 'in1'),
125+
('out2', 'in2'),
126+
]),
127+
(b_in, b_wf, [
128+
('in1', 'inputnode.in1'),
129+
('in2', 'inputnode.in2'),
130+
]),
131+
(b_wf, b_out, [
132+
('outputnode.out1', 'out1'),
133+
('outputnode.out2', 'out2'),
134+
]),
135+
(a_out, outputnode, [
136+
('out1', 'out1'),
137+
]),
138+
(b_out, outputnode, [
139+
('out2', 'out2'),
140+
]),
141+
]) # fmt:skip
142+
return wf
143+
144+
145+
def test_splice(wf0):
146+
replacements = {
147+
'a': _create_null_wf('a2_wf', tag='a'),
148+
'nested': _create_null_wf('nested2_wf', tag='nested'),
149+
'c': _create_null_wf('c_wf', tag='c'),
150+
}
151+
wf = splice_workflow(wf0, replacements, write_graph=True, debug=True)
152+
153+
assert wf.get_node('a2_wf')
154+
assert wf.get_node('b_wf').get_node('nested2_wf')
155+
assert wf.get_node('c_wf') is None

0 commit comments

Comments
 (0)