Skip to content

Commit 6570949

Browse files
committed
fix: ensure io classes derived from IOBase but not in nipype allow appropriate connections
1 parent f290b7e commit 6570949

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

nipype/pipeline/engine/tests/test_engine.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,3 +734,37 @@ def test_deep_nested_write_graph_runs(tmpdir):
734734
os.remove('graph_detailed.dot')
735735
except OSError:
736736
pass
737+
738+
739+
def test_io_subclass():
740+
"""Ensure any io subclass allows dynamic traits"""
741+
from nipype.interfaces.io import IOBase
742+
from nipype.interfaces.base import DynamicTraitedSpec
743+
744+
class TestKV(IOBase):
745+
_always_run = True
746+
output_spec = DynamicTraitedSpec
747+
748+
def _list_outputs(self):
749+
outputs = {}
750+
outputs['test'] = 1
751+
outputs['foo'] = 'bar'
752+
return outputs
753+
754+
wf = pe.Workflow('testkv')
755+
756+
def testx2(test):
757+
return test * 2
758+
759+
kvnode = pe.Node(TestKV(), name='testkv')
760+
from nipype.interfaces.utility import Function
761+
func = pe.Node(
762+
Function(input_names=['test'], output_names=['test2'], function=testx2),
763+
name='func')
764+
exception_not_raised = True
765+
try:
766+
wf.connect(kvnode, 'test', func, 'test')
767+
except Exception as e:
768+
if 'Module testkv has no output called test' in e:
769+
exception_not_raised = False
770+
assert exception_not_raised

nipype/pipeline/engine/workflows.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,16 @@ def connect(self, *args, **kwargs):
200200
connected.
201201
""" % (srcnode, source, destnode, dest, dest, destnode))
202202
if not (hasattr(destnode, '_interface') and
203-
'.io' in str(destnode._interface.__class__)):
203+
('.io' in str(destnode._interface.__class__) or
204+
any(['.io' in str(val) for val in
205+
destnode._interface.__class__.__bases__]))
206+
):
204207
if not destnode._check_inputs(dest):
205208
not_found.append(['in', destnode.name, dest])
206209
if not (hasattr(srcnode, '_interface') and
207-
'.io' in str(srcnode._interface.__class__)):
210+
('.io' in str(srcnode._interface.__class__)
211+
or any(['.io' in str(val) for val in
212+
srcnode._interface.__class__.__bases__]))):
208213
if isinstance(source, tuple):
209214
# handles the case that source is specified
210215
# with a function

0 commit comments

Comments
 (0)