13
13
"""
14
14
15
15
from datetime import datetime
16
+ from nipype .utils .misc import flatten , unflatten
16
17
try :
17
18
from collections import OrderedDict
18
19
except ImportError :
@@ -2027,7 +2028,7 @@ class MapNode(Node):
2027
2028
2028
2029
"""
2029
2030
2030
- def __init__ (self , interface , iterfield , name , serial = False , ** kwargs ):
2031
+ def __init__ (self , interface , iterfield , name , serial = False , nested = False , ** kwargs ):
2031
2032
"""
2032
2033
2033
2034
Parameters
@@ -2043,6 +2044,9 @@ def __init__(self, interface, iterfield, name, serial=False, **kwargs):
2043
2044
node specific name
2044
2045
serial : boolean
2045
2046
flag to enforce executing the jobs of the mapnode in a serial manner rather than parallel
2047
+ nested : boolea
2048
+ support for nested lists, if set the input list will be flattened before running, and the
2049
+ nested list structure of the outputs will be resored
2046
2050
See Node docstring for additional keyword arguments.
2047
2051
"""
2048
2052
@@ -2051,6 +2055,7 @@ def __init__(self, interface, iterfield, name, serial=False, **kwargs):
2051
2055
if isinstance (iterfield , six .string_types ):
2052
2056
iterfield = [iterfield ]
2053
2057
self .iterfield = iterfield
2058
+ self .nested = nested
2054
2059
self ._inputs = self ._create_dynamic_traits (self ._interface .inputs ,
2055
2060
fields = self .iterfield )
2056
2061
self ._inputs .on_trait_change (self ._set_mapnode_input )
@@ -2066,7 +2071,10 @@ def _create_dynamic_traits(self, basetraits, fields=None, nitems=None):
2066
2071
for name , spec in basetraits .items ():
2067
2072
if name in fields and ((nitems is None ) or (nitems > 1 )):
2068
2073
logger .debug ('adding multipath trait: %s' % name )
2069
- output .add_trait (name , InputMultiPath (spec .trait_type ))
2074
+ if self .nested :
2075
+ output .add_trait (name , InputMultiPath (traits .Any ()))
2076
+ else :
2077
+ output .add_trait (name , InputMultiPath (spec .trait_type ))
2070
2078
else :
2071
2079
output .add_trait (name , traits .Trait (spec ))
2072
2080
setattr (output , name , Undefined )
@@ -2110,7 +2118,10 @@ def _get_hashval(self):
2110
2118
self ._interface .inputs .traits ()[name ].trait_type ))
2111
2119
logger .debug ('setting hashinput %s-> %s' %
2112
2120
(name , getattr (self ._inputs , name )))
2113
- setattr (hashinputs , name , getattr (self ._inputs , name ))
2121
+ if self .nested :
2122
+ setattr (hashinputs , name , flatten (getattr (self ._inputs , name )))
2123
+ else :
2124
+ setattr (hashinputs , name , getattr (self ._inputs , name ))
2114
2125
hashed_inputs , hashvalue = hashinputs .get_hashval (
2115
2126
hash_method = self .config ['execution' ]['hash_method' ])
2116
2127
rm_extra = self .config ['execution' ]['remove_unnecessary_outputs' ]
@@ -2137,7 +2148,10 @@ def outputs(self):
2137
2148
def _make_nodes (self , cwd = None ):
2138
2149
if cwd is None :
2139
2150
cwd = self .output_dir ()
2140
- nitems = len (filename_to_list (getattr (self .inputs , self .iterfield [0 ])))
2151
+ if self .nested :
2152
+ nitems = len (flatten (filename_to_list (getattr (self .inputs , self .iterfield [0 ]))))
2153
+ else :
2154
+ nitems = len (filename_to_list (getattr (self .inputs , self .iterfield [0 ])))
2141
2155
for i in range (nitems ):
2142
2156
nodename = '_' + self .name + str (i )
2143
2157
node = Node (deepcopy (self ._interface ), name = nodename )
@@ -2147,7 +2161,10 @@ def _make_nodes(self, cwd=None):
2147
2161
node ._interface .inputs .set (
2148
2162
** deepcopy (self ._interface .inputs .get ()))
2149
2163
for field in self .iterfield :
2150
- fieldvals = filename_to_list (getattr (self .inputs , field ))
2164
+ if self .nested :
2165
+ fieldvals = flatten (filename_to_list (getattr (self .inputs , field )))
2166
+ else :
2167
+ fieldvals = filename_to_list (getattr (self .inputs , field ))
2151
2168
logger .debug ('setting input %d %s %s' % (i , field ,
2152
2169
fieldvals [i ]))
2153
2170
setattr (node .inputs , field ,
@@ -2199,6 +2216,14 @@ def _collate_results(self, nodes):
2199
2216
defined_vals = [isdefined (val ) for val in values ]
2200
2217
if any (defined_vals ) and self ._result .outputs :
2201
2218
setattr (self ._result .outputs , key , values )
2219
+
2220
+ if self .nested :
2221
+ for key , _ in self .outputs .items ():
2222
+ values = getattr (self ._result .outputs , key )
2223
+ if isdefined (values ):
2224
+ values = unflatten (values , filename_to_list (getattr (self .inputs , self .iterfield [0 ])))
2225
+ setattr (self ._result .outputs , key , values )
2226
+
2202
2227
if returncode and any ([code is not None for code in returncode ]):
2203
2228
msg = []
2204
2229
for i , code in enumerate (returncode ):
@@ -2249,7 +2274,10 @@ def num_subnodes(self):
2249
2274
if self ._serial :
2250
2275
return 1
2251
2276
else :
2252
- return len (filename_to_list (getattr (self .inputs , self .iterfield [0 ])))
2277
+ if self .nested :
2278
+ return len (filename_to_list (flatten (getattr (self .inputs , self .iterfield [0 ]))))
2279
+ else :
2280
+ return len (filename_to_list (getattr (self .inputs , self .iterfield [0 ])))
2253
2281
2254
2282
def _get_inputs (self ):
2255
2283
old_inputs = self ._inputs .get ()
@@ -2289,8 +2317,12 @@ def _run_interface(self, execute=True, updatehash=False):
2289
2317
os .chdir (cwd )
2290
2318
self ._check_iterfield ()
2291
2319
if execute :
2292
- nitems = len (filename_to_list (getattr (self .inputs ,
2293
- self .iterfield [0 ])))
2320
+ if self .nested :
2321
+ nitems = len (filename_to_list (flatten (getattr (self .inputs ,
2322
+ self .iterfield [0 ]))))
2323
+ else :
2324
+ nitems = len (filename_to_list (getattr (self .inputs ,
2325
+ self .iterfield [0 ])))
2294
2326
nodenames = ['_' + self .name + str (i ) for i in range (nitems )]
2295
2327
# map-reduce formulation
2296
2328
self ._collate_results (self ._node_runner (self ._make_nodes (cwd ),
0 commit comments