Skip to content

Commit a3b1bc5

Browse files
committed
fix: closes #1920 to revert to original behavior for input names
1 parent d68b929 commit a3b1bc5

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

nipype/interfaces/utility/base.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,27 @@ def _list_outputs(self):
9999
class MergeInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
100100
axis = traits.Enum('vstack', 'hstack', usedefault=True,
101101
desc='direction in which to merge, hstack requires same number of elements in each input')
102-
no_flatten = traits.Bool(False, usedefault=True, desc='append to outlist instead of extending in vstack mode')
102+
no_flatten = traits.Bool(False, usedefault=True,
103+
desc='append to outlist instead of extending in vstack mode')
103104

104105

105106
class MergeOutputSpec(TraitedSpec):
106107
out = traits.List(desc='Merged output')
107108

108109

110+
def _ravel(in_val):
111+
if not isinstance(in_val, list):
112+
return in_val
113+
flat_list = []
114+
for val in in_val:
115+
raveled_val = _ravel(val)
116+
if isinstance(raveled_val, list):
117+
flat_list.extend(raveled_val)
118+
else:
119+
flat_list.append(raveled_val)
120+
return flat_list
121+
122+
109123
class Merge(IOBase):
110124
"""Basic interface class to merge inputs into a single list
111125
@@ -124,22 +138,26 @@ class Merge(IOBase):
124138
[1, 2, 5, 3]
125139
126140
>>> merge = Merge() # Or Merge(1)
127-
>>> merge.inputs.in_lists = [1, [2, 5], 3]
141+
>>> merge.inputs.in1 = [1, [2, 5], 3]
128142
>>> out = merge.run()
129143
>>> out.outputs.out
130144
[1, 2, 5, 3]
131145
146+
>>> merge = Merge() # Or Merge(1)
147+
>>> merge.inputs.in1 = [1, [2, 5], 3]
148+
>>> merge.inputs.no_flatten = True
149+
>>> out = merge.run()
150+
>>> out.outputs.out
151+
[[1, [2, 5], 3]]
132152
"""
133153
input_spec = MergeInputSpec
134154
output_spec = MergeOutputSpec
135155

136156
def __init__(self, numinputs=1, **inputs):
137157
super(Merge, self).__init__(**inputs)
138158
self._numinputs = numinputs
139-
if numinputs > 1:
159+
if numinputs >= 1:
140160
input_names = ['in%d' % (i + 1) for i in range(numinputs)]
141-
elif numinputs == 1:
142-
input_names = ['in_lists']
143161
else:
144162
input_names = []
145163
add_traits(self.inputs, input_names)
@@ -150,8 +168,6 @@ def _list_outputs(self):
150168

151169
if self._numinputs < 1:
152170
return outputs
153-
elif self._numinputs == 1:
154-
values = self.inputs.in_lists
155171
else:
156172
getval = lambda idx: getattr(self.inputs, 'in%d' % (idx + 1))
157173
values = [getval(idx) for idx in range(self._numinputs)
@@ -160,7 +176,7 @@ def _list_outputs(self):
160176
if self.inputs.axis == 'vstack':
161177
for value in values:
162178
if isinstance(value, list) and not self.inputs.no_flatten:
163-
out.extend(value)
179+
out.extend(_ravel(value))
164180
else:
165181
out.append(value)
166182
else:

0 commit comments

Comments
 (0)