Skip to content

Commit aa6e77e

Browse files
committed
Merge pull request #955 from oesteban/enh/niuSplitSqueeze
Squeeze option in utility.Split
2 parents 35c9f46 + 95f00a7 commit aa6e77e

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

nipype/interfaces/tests/test_auto_Split.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ def test_Split_inputs():
1010
),
1111
splits=dict(mandatory=True,
1212
),
13+
squeeze=dict(usedefault=True,
14+
),
1315
)
1416
inputs = Split.input_spec()
1517

nipype/interfaces/tests/test_utility.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,28 @@ def test_function_with_imports():
113113
finally:
114114
os.chdir(origdir)
115115
shutil.rmtree(tempdir)
116+
117+
118+
def test_split():
119+
tempdir = os.path.realpath(mkdtemp())
120+
origdir = os.getcwd()
121+
os.chdir(tempdir)
122+
123+
try:
124+
node = pe.Node(utility.Split(inlist=range(4),
125+
splits=[1, 3]),
126+
name='split_squeeze')
127+
res = node.run()
128+
yield assert_equal, res.outputs.out1, [0]
129+
yield assert_equal, res.outputs.out2, [1, 2, 3]
130+
131+
node = pe.Node(utility.Split(inlist=range(4),
132+
splits=[1, 3],
133+
squeeze=True),
134+
name='split_squeeze')
135+
res = node.run()
136+
yield assert_equal, res.outputs.out1, 0
137+
yield assert_equal, res.outputs.out2, [1, 2, 3]
138+
finally:
139+
os.chdir(origdir)
140+
shutil.rmtree(tempdir)

nipype/interfaces/utility.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ class SplitInputSpec(BaseInterfaceInputSpec):
252252
desc='list of values to split')
253253
splits = traits.List(traits.Int, mandatory=True,
254254
desc='Number of outputs in each split - should add to number of inputs')
255+
squeeze = traits.Bool(False, usedefault=True,
256+
desc='unfold one-element splits removing the list')
255257

256258

257259
class Split(IOBase):
@@ -290,7 +292,10 @@ def _list_outputs(self):
290292
splits.extend(self.inputs.splits)
291293
splits = np.cumsum(splits)
292294
for i in range(len(splits) - 1):
293-
outputs['out%d' % (i + 1)] = np.array(self.inputs.inlist)[splits[i]:splits[i + 1]].tolist()
295+
val = np.array(self.inputs.inlist)[splits[i]:splits[i + 1]].tolist()
296+
if self.inputs.squeeze and len(val) == 1:
297+
val = val[0]
298+
outputs['out%d' % (i + 1)] = val
294299
return outputs
295300

296301

0 commit comments

Comments
 (0)