@@ -99,13 +99,27 @@ def _list_outputs(self):
99
99
class MergeInputSpec (DynamicTraitedSpec , BaseInterfaceInputSpec ):
100
100
axis = traits .Enum ('vstack' , 'hstack' , usedefault = True ,
101
101
desc = 'direction in which to merge, hstack requires same number of elements in each input' )
102
- no_flatten = traits .Bool (False , usedefault = True , desc = 'append to outlist instead of extending in vstack mode' )
102
+ no_flatten = traits .Bool (False , usedefault = True ,
103
+ desc = 'append to outlist instead of extending in vstack mode' )
103
104
104
105
105
106
class MergeOutputSpec (TraitedSpec ):
106
107
out = traits .List (desc = 'Merged output' )
107
108
108
109
110
+ def _ravel (in_val ):
111
+ if not isinstance (in_val , list ):
112
+ return in_val
113
+ flat_list = []
114
+ for val in in_val :
115
+ raveled_val = _ravel (val )
116
+ if isinstance (raveled_val , list ):
117
+ flat_list .extend (raveled_val )
118
+ else :
119
+ flat_list .append (raveled_val )
120
+ return flat_list
121
+
122
+
109
123
class Merge (IOBase ):
110
124
"""Basic interface class to merge inputs into a single list
111
125
@@ -124,22 +138,26 @@ class Merge(IOBase):
124
138
[1, 2, 5, 3]
125
139
126
140
>>> merge = Merge() # Or Merge(1)
127
- >>> merge.inputs.in_lists = [1, [2, 5], 3]
141
+ >>> merge.inputs.in1 = [1, [2, 5], 3]
128
142
>>> out = merge.run()
129
143
>>> out.outputs.out
130
144
[1, 2, 5, 3]
131
145
146
+ >>> merge = Merge() # Or Merge(1)
147
+ >>> merge.inputs.in1 = [1, [2, 5], 3]
148
+ >>> merge.inputs.no_flatten = True
149
+ >>> out = merge.run()
150
+ >>> out.outputs.out
151
+ [[1, [2, 5], 3]]
132
152
"""
133
153
input_spec = MergeInputSpec
134
154
output_spec = MergeOutputSpec
135
155
136
156
def __init__ (self , numinputs = 1 , ** inputs ):
137
157
super (Merge , self ).__init__ (** inputs )
138
158
self ._numinputs = numinputs
139
- if numinputs > 1 :
159
+ if numinputs >= 1 :
140
160
input_names = ['in%d' % (i + 1 ) for i in range (numinputs )]
141
- elif numinputs == 1 :
142
- input_names = ['in_lists' ]
143
161
else :
144
162
input_names = []
145
163
add_traits (self .inputs , input_names )
@@ -150,8 +168,6 @@ def _list_outputs(self):
150
168
151
169
if self ._numinputs < 1 :
152
170
return outputs
153
- elif self ._numinputs == 1 :
154
- values = self .inputs .in_lists
155
171
else :
156
172
getval = lambda idx : getattr (self .inputs , 'in%d' % (idx + 1 ))
157
173
values = [getval (idx ) for idx in range (self ._numinputs )
@@ -160,7 +176,7 @@ def _list_outputs(self):
160
176
if self .inputs .axis == 'vstack' :
161
177
for value in values :
162
178
if isinstance (value , list ) and not self .inputs .no_flatten :
163
- out .extend (value )
179
+ out .extend (_ravel ( value ) )
164
180
else :
165
181
out .append (value )
166
182
else :
0 commit comments