Skip to content

Commit f95bc73

Browse files
committed
Add supported for nested lists in iterfields.
1 parent 7aaf5d4 commit f95bc73

File tree

4 files changed

+80
-10
lines changed

4 files changed

+80
-10
lines changed

doc/users/mapnode_and_iterables.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ is almost the same as running
9292
9393
It is a rarely used feature, but you can sometimes find it useful.
9494

95+
In more advanced applications it is useful to be able to iterate over items
96+
of nested lists (for example [[1,2],[3,4]]). MapNode allows you to do this
97+
with the "nested=True" parameter. Outputs will preserve the same nested
98+
structure as the inputs.
99+
95100
Iterables
96101
=========
97102

nipype/algorithms/modelgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,11 @@ def gen_info(run_event_files):
165165

166166

167167
class SpecifyModelInputSpec(BaseInterfaceInputSpec):
168-
subject_info = InputMultiPath(Bunch, mandatory=True, xor=['event_files'],
168+
subject_info = InputMultiPath(Bunch, mandatory=True, xor=['subject_info', 'event_files'],
169169
desc=("Bunch or List(Bunch) subject specific condition information. "
170170
"see :ref:`SpecifyModel` or SpecifyModel.__doc__ for details"))
171171
event_files = InputMultiPath(traits.List(File(exists=True)), mandatory=True,
172-
xor=['subject_info'],
172+
xor=['subject_info', 'event_files'],
173173
desc=('list of event description files 1, 2 or 3 column format '
174174
'corresponding to onsets, durations and amplitudes'))
175175
realignment_parameters = InputMultiPath(File(exists=True),

nipype/pipeline/engine.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414

1515
from datetime import datetime
16+
from nipype.utils.misc import flatten, unflatten
1617
try:
1718
from collections import OrderedDict
1819
except ImportError:
@@ -2027,7 +2028,7 @@ class MapNode(Node):
20272028
20282029
"""
20292030

2030-
def __init__(self, interface, iterfield, name, serial=False, **kwargs):
2031+
def __init__(self, interface, iterfield, name, serial=False, nested=False, **kwargs):
20312032
"""
20322033
20332034
Parameters
@@ -2043,6 +2044,9 @@ def __init__(self, interface, iterfield, name, serial=False, **kwargs):
20432044
node specific name
20442045
serial : boolean
20452046
flag to enforce executing the jobs of the mapnode in a serial manner rather than parallel
2047+
nested : boolea
2048+
support for nested lists, if set the input list will be flattened before running, and the
2049+
nested list structure of the outputs will be resored
20462050
See Node docstring for additional keyword arguments.
20472051
"""
20482052

@@ -2051,6 +2055,7 @@ def __init__(self, interface, iterfield, name, serial=False, **kwargs):
20512055
if isinstance(iterfield, six.string_types):
20522056
iterfield = [iterfield]
20532057
self.iterfield = iterfield
2058+
self.nested = nested
20542059
self._inputs = self._create_dynamic_traits(self._interface.inputs,
20552060
fields=self.iterfield)
20562061
self._inputs.on_trait_change(self._set_mapnode_input)
@@ -2066,7 +2071,10 @@ def _create_dynamic_traits(self, basetraits, fields=None, nitems=None):
20662071
for name, spec in basetraits.items():
20672072
if name in fields and ((nitems is None) or (nitems > 1)):
20682073
logger.debug('adding multipath trait: %s' % name)
2069-
output.add_trait(name, InputMultiPath(spec.trait_type))
2074+
if self.nested:
2075+
output.add_trait(name, InputMultiPath(traits.Any()))
2076+
else:
2077+
output.add_trait(name, InputMultiPath(spec.trait_type))
20702078
else:
20712079
output.add_trait(name, traits.Trait(spec))
20722080
setattr(output, name, Undefined)
@@ -2110,7 +2118,10 @@ def _get_hashval(self):
21102118
self._interface.inputs.traits()[name].trait_type))
21112119
logger.debug('setting hashinput %s-> %s' %
21122120
(name, getattr(self._inputs, name)))
2113-
setattr(hashinputs, name, getattr(self._inputs, name))
2121+
if self.nested:
2122+
setattr(hashinputs, name, flatten(getattr(self._inputs, name)))
2123+
else:
2124+
setattr(hashinputs, name, getattr(self._inputs, name))
21142125
hashed_inputs, hashvalue = hashinputs.get_hashval(
21152126
hash_method=self.config['execution']['hash_method'])
21162127
rm_extra = self.config['execution']['remove_unnecessary_outputs']
@@ -2137,7 +2148,10 @@ def outputs(self):
21372148
def _make_nodes(self, cwd=None):
21382149
if cwd is None:
21392150
cwd = self.output_dir()
2140-
nitems = len(filename_to_list(getattr(self.inputs, self.iterfield[0])))
2151+
if self.nested:
2152+
nitems = len(flatten(filename_to_list(getattr(self.inputs, self.iterfield[0]))))
2153+
else:
2154+
nitems = len(filename_to_list(getattr(self.inputs, self.iterfield[0])))
21412155
for i in range(nitems):
21422156
nodename = '_' + self.name + str(i)
21432157
node = Node(deepcopy(self._interface), name=nodename)
@@ -2147,7 +2161,10 @@ def _make_nodes(self, cwd=None):
21472161
node._interface.inputs.set(
21482162
**deepcopy(self._interface.inputs.get()))
21492163
for field in self.iterfield:
2150-
fieldvals = filename_to_list(getattr(self.inputs, field))
2164+
if self.nested:
2165+
fieldvals = flatten(filename_to_list(getattr(self.inputs, field)))
2166+
else:
2167+
fieldvals = filename_to_list(getattr(self.inputs, field))
21512168
logger.debug('setting input %d %s %s' % (i, field,
21522169
fieldvals[i]))
21532170
setattr(node.inputs, field,
@@ -2199,6 +2216,14 @@ def _collate_results(self, nodes):
21992216
defined_vals = [isdefined(val) for val in values]
22002217
if any(defined_vals) and self._result.outputs:
22012218
setattr(self._result.outputs, key, values)
2219+
2220+
if self.nested:
2221+
for key, _ in self.outputs.items():
2222+
values = getattr(self._result.outputs, key)
2223+
if isdefined(values):
2224+
values = unflatten(values, filename_to_list(getattr(self.inputs, self.iterfield[0])))
2225+
setattr(self._result.outputs, key, values)
2226+
22022227
if returncode and any([code is not None for code in returncode]):
22032228
msg = []
22042229
for i, code in enumerate(returncode):
@@ -2249,7 +2274,10 @@ def num_subnodes(self):
22492274
if self._serial :
22502275
return 1
22512276
else:
2252-
return len(filename_to_list(getattr(self.inputs, self.iterfield[0])))
2277+
if self.nested:
2278+
return len(filename_to_list(flatten(getattr(self.inputs, self.iterfield[0]))))
2279+
else:
2280+
return len(filename_to_list(getattr(self.inputs, self.iterfield[0])))
22532281

22542282
def _get_inputs(self):
22552283
old_inputs = self._inputs.get()
@@ -2289,8 +2317,12 @@ def _run_interface(self, execute=True, updatehash=False):
22892317
os.chdir(cwd)
22902318
self._check_iterfield()
22912319
if execute:
2292-
nitems = len(filename_to_list(getattr(self.inputs,
2293-
self.iterfield[0])))
2320+
if self.nested:
2321+
nitems = len(filename_to_list(flatten(getattr(self.inputs,
2322+
self.iterfield[0]))))
2323+
else:
2324+
nitems = len(filename_to_list(getattr(self.inputs,
2325+
self.iterfield[0])))
22942326
nodenames = ['_' + self.name + str(i) for i in range(nitems)]
22952327
# map-reduce formulation
22962328
self._collate_results(self._node_runner(self._make_nodes(cwd),

nipype/pipeline/tests/test_engine.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,39 @@ def test_mapnode_iterfield_check():
470470
yield assert_raises, ValueError, mod1._check_iterfield
471471

472472

473+
def test_mapnode_nested():
474+
cwd = os.getcwd()
475+
wd = mkdtemp()
476+
os.chdir(wd)
477+
from nipype import MapNode, Function
478+
def func1(in1):
479+
return in1 + 1
480+
n1 = MapNode(Function(input_names=['in1'],
481+
output_names=['out'],
482+
function=func1),
483+
iterfield=['in1'],
484+
nested=True,
485+
name='n1')
486+
n1.inputs.in1 = [[1,[2]],3,[4,5]]
487+
n1.run()
488+
print n1.get_output('out')
489+
yield assert_equal, n1.get_output('out'), [[2,[3]],4,[5,6]]
490+
491+
n2 = MapNode(Function(input_names=['in1'],
492+
output_names=['out'],
493+
function=func1),
494+
iterfield=['in1'],
495+
nested=False,
496+
name='n1')
497+
n2.inputs.in1 = [[1,[2]],3,[4,5]]
498+
error_raised = False
499+
try:
500+
n2.run()
501+
except Exception, e:
502+
pe.logger.info('Exception: %s' % str(e))
503+
error_raised = True
504+
yield assert_true, error_raised
505+
473506
def test_node_hash():
474507
cwd = os.getcwd()
475508
wd = mkdtemp()

0 commit comments

Comments
 (0)