Skip to content

Commit d657546

Browse files
committed
fix: apply review comments from @effigies, add parameterized tests
1 parent 1d12dd3 commit d657546

File tree

2 files changed

+35
-78
lines changed

2 files changed

+35
-78
lines changed

niworkflows/interfaces/nibabel.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,6 @@ def _run_interface(self, runtime):
101101

102102
class _SplitSeriesInputSpec(BaseInterfaceInputSpec):
103103
in_file = File(exists=True, mandatory=True, desc="input 4d image")
104-
allow_3D = traits.Bool(
105-
False, usedefault=True, desc="do not fail if a 3D volume is passed in"
106-
)
107104

108105

109106
class _SplitSeriesOutputSpec(TraitedSpec):
@@ -117,34 +114,20 @@ class SplitSeries(SimpleInterface):
117114
output_spec = _SplitSeriesOutputSpec
118115

119116
def _run_interface(self, runtime):
120-
filenii = nb.squeeze_image(nb.load(self.inputs.in_file))
121-
filenii = filenii.__class__(
122-
np.squeeze(filenii.dataobj), filenii.affine, filenii.header
123-
)
124-
ndim = filenii.dataobj.ndim
125-
if ndim != 4:
126-
if self.inputs.allow_3D and ndim == 3:
127-
out_file = str(
128-
Path(
129-
fname_presuffix(self.inputs.in_file, suffix=f"_idx-000")
130-
).absolute()
131-
)
132-
self._results["out_files"] = out_file
133-
filenii.to_filename(out_file)
134-
return runtime
135-
raise RuntimeError(
136-
f"Input image <{self.inputs.in_file}> is {ndim}D "
137-
f"({'x'.join(['%d' % s for s in filenii.shape])})."
138-
)
117+
in_file = self.inputs.in_file
118+
img = nb.load(in_file)
119+
extra_dims = tuple(dim for dim in img.shape[3:] if dim > 1) or (1,)
120+
if len(extra_dims) != 1:
121+
raise ValueError(f"Invalid shape {'x'.join(str(s) for s in img.shape)}")
122+
img = img.__class__(img.dataobj.reshape(img.shape[:3] + extra_dims),
123+
img.affine, img.header)
139124

140-
files_3d = nb.four_to_three(filenii)
141125
self._results["out_files"] = []
142-
in_file = self.inputs.in_file
143-
for i, file_3d in enumerate(files_3d):
126+
for i, img_3d in enumerate(nb.four_to_three(img)):
144127
out_file = str(
145128
Path(fname_presuffix(in_file, suffix=f"_idx-{i:03}")).absolute()
146129
)
147-
file_3d.to_filename(out_file)
130+
img_3d.to_filename(out_file)
148131
self._results["out_files"].append(out_file)
149132

150133
return runtime

niworkflows/interfaces/tests/test_nibabel.py

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -75,61 +75,35 @@ def test_ApplyMask(tmp_path):
7575
ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()
7676

7777

78-
def test_SplitSeries(tmp_path):
78+
@pytest.mark.parametrize("shape,exp_n", [
79+
((20, 20, 20, 15), 15),
80+
((20, 20, 20), 1),
81+
((20, 20, 20, 1), 1),
82+
((20, 20, 20, 1, 3), 3),
83+
((20, 20, 20, 3, 1), 3),
84+
((20, 20, 20, 1, 3, 3), -1),
85+
((20, 1, 20, 15), 15),
86+
((20, 1, 20), 1),
87+
((20, 1, 20, 1), 1),
88+
((20, 1, 20, 1, 3), 3),
89+
((20, 1, 20, 3, 1), 3),
90+
((20, 1, 20, 1, 3, 3), -1),
91+
])
92+
def test_SplitSeries(tmp_path, shape, exp_n):
7993
"""Test 4-to-3 NIfTI split interface."""
8094
os.chdir(tmp_path)
8195

82-
# Test the 4D
83-
data = np.ones((20, 20, 20, 15), dtype=float)
84-
in_file = tmp_path / "input4D.nii.gz"
85-
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
86-
87-
split = SplitSeries(in_file=str(in_file)).run()
88-
assert len(split.outputs.out_files) == 15
89-
90-
# Test the 3D
91-
in_file = tmp_path / "input3D.nii.gz"
92-
nb.Nifti1Image(np.ones((20, 20, 20), dtype=float), np.eye(4), None).to_filename(
93-
str(in_file)
94-
)
95-
96-
with pytest.raises(RuntimeError):
97-
SplitSeries(in_file=str(in_file)).run()
98-
99-
split = SplitSeries(in_file=str(in_file), allow_3D=True).run()
100-
assert isinstance(split.outputs.out_files, str)
101-
102-
# Test the 3D
103-
in_file = tmp_path / "input3D.nii.gz"
104-
nb.Nifti1Image(np.ones((20, 20, 20, 1), dtype=float), np.eye(4), None).to_filename(
105-
str(in_file)
106-
)
107-
108-
with pytest.raises(RuntimeError):
109-
SplitSeries(in_file=str(in_file)).run()
110-
111-
split = SplitSeries(in_file=str(in_file), allow_3D=True).run()
112-
assert isinstance(split.outputs.out_files, str)
113-
114-
# Test the 5D
115-
in_file = tmp_path / "input5D.nii.gz"
116-
nb.Nifti1Image(
117-
np.ones((20, 20, 20, 2, 2), dtype=float), np.eye(4), None
118-
).to_filename(str(in_file))
119-
120-
with pytest.raises(RuntimeError):
121-
SplitSeries(in_file=str(in_file)).run()
122-
123-
with pytest.raises(RuntimeError):
124-
SplitSeries(in_file=str(in_file), allow_3D=True).run()
125-
126-
# Test splitting ANTs warpfields
127-
data = np.ones((20, 20, 20, 1, 3), dtype=float)
128-
in_file = tmp_path / "warpfield.nii.gz"
129-
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
130-
131-
split = SplitSeries(in_file=str(in_file)).run()
132-
assert len(split.outputs.out_files) == 3
96+
in_file = str(tmp_path / "input.nii.gz")
97+
nb.Nifti1Image(np.ones(shape, dtype=float), np.eye(4), None).to_filename(in_file)
98+
99+
_interface = SplitSeries(in_file=in_file)
100+
if exp_n > 0:
101+
split = _interface.run()
102+
n = int(isinstance(split.outputs.out_files, str)) or len(split.outputs.out_files)
103+
assert n == exp_n
104+
else:
105+
with pytest.raises(ValueError):
106+
_interface.run()
133107

134108

135109
def test_MergeSeries(tmp_path):

0 commit comments

Comments
 (0)