Skip to content

Commit 03ebb6d

Browse files
committed
fix: added few bugfixes and regression tests
1 parent 7263d03 commit 03ebb6d

File tree

2 files changed

+92
-11
lines changed

2 files changed

+92
-11
lines changed

niworkflows/interfaces/nibabel.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""Nibabel-based interfaces."""
4+
from pathlib import Path
45
import numpy as np
56
import nibabel as nb
67
from nipype import logging
@@ -98,12 +99,11 @@ class _FourToThreeInputSpec(BaseInterfaceInputSpec):
9899

99100
class _FourToThreeOutputSpec(TraitedSpec):
100101
out_files = OutputMultiObject(File(exists=True),
101-
desc='output list of 3d images')
102+
desc='output list of 3d images')
102103

103104

104105
class SplitSeries(SimpleInterface):
105-
"""Split a 4D dataset along the last dimension
106-
into a series of 3D volumes."""
106+
"""Split a 4D dataset along the last dimension into a series of 3D volumes."""
107107

108108
input_spec = _FourToThreeInputSpec
109109
output_spec = _FourToThreeOutputSpec
@@ -119,7 +119,8 @@ def _run_interface(self, runtime):
119119
self._results['out_files'] = out_file
120120
filenii.to_filename(out_file)
121121
return runtime
122-
raise RuntimeError(f"Input image image is {ndim}D.")
122+
raise RuntimeError(
123+
f"Input image image is {ndim}D ({'x'.join(['%d' % s for s in filenii.shape])}).")
123124

124125
files_3d = nb.four_to_three(filenii)
125126
self._results['out_files'] = []
@@ -137,7 +138,8 @@ def _run_interface(self, runtime):
137138
class _MergeSeriesInputSpec(BaseInterfaceInputSpec):
138139
in_files = InputMultiObject(File(exists=True, mandatory=True,
139140
desc='input list of 3d images'))
140-
allow_4D = traits.Bool(True, usedefault=True, desc='whether 4D images are allowed to be concatenated')
141+
allow_4D = traits.Bool(True, usedefault=True,
142+
desc='whether 4D images are allowed to be concatenated')
141143

142144

143145
class _MergeSeriesOutputSpec(TraitedSpec):
@@ -153,15 +155,18 @@ class MergeSeries(SimpleInterface):
153155
def _run_interface(self, runtime):
154156
nii_list = []
155157
for f in self.inputs.in_files:
156-
filenii = nb.load(f)
157-
filenii = nb.squeeze_image(filenii)
158-
if filenii.dataobj.ndim == 3:
158+
filenii = nb.squeeze_image(nb.load(f))
159+
ndim = filenii.dataobj.ndim
160+
if ndim == 3:
159161
nii_list.append(filenii)
160-
elif self.inputs.allow_4D and filenii.dataobj.ndim == 4:
162+
continue
163+
elif self.inputs.allow_4D and ndim == 4:
161164
nii_list += nb.four_to_three(filenii)
165+
continue
162166
else:
163167
raise ValueError("Input image has an incorrect number of dimensions"
164-
f" ({filenii.dataobj.ndim}).")
168+
f" ({ndim}).")
169+
165170
img_4d = nb.concat_images(nii_list)
166171
out_file = fname_presuffix(self.inputs.in_files[0], suffix="_merged")
167172
img_4d.to_filename(out_file)

niworkflows/interfaces/tests/test_nibabel.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import nibabel as nb
55
import pytest
66

7-
from ..nibabel import Binarize, ApplyMask
7+
from ..nibabel import Binarize, ApplyMask, SplitSeries, MergeSeries
88

99

1010
def test_Binarize(tmp_path):
@@ -69,3 +69,79 @@ def test_ApplyMask(tmp_path):
6969
ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
7070
with pytest.raises(ValueError):
7171
ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()
72+
73+
74+
def test_SplitSeries(tmp_path):
75+
"""Test 4-to-3 NIfTI split interface."""
76+
os.chdir(str(tmp_path))
77+
78+
# Test the 4D
79+
data = np.ones((20, 20, 20, 15), dtype=float)
80+
in_file = tmp_path / 'input4D.nii.gz'
81+
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
82+
83+
split = SplitSeries(in_file=str(in_file)).run()
84+
assert len(split.outputs.out_files) == 15
85+
86+
# Test the 3D
87+
data = np.ones((20, 20, 20), dtype=float)
88+
in_file = tmp_path / 'input3D.nii.gz'
89+
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
90+
91+
with pytest.raises(RuntimeError):
92+
SplitSeries(in_file=str(in_file)).run()
93+
94+
split = SplitSeries(in_file=str(in_file), accept_3D=True).run()
95+
assert isinstance(split.outputs.out_files, str)
96+
97+
# Test the 3D
98+
data = np.ones((20, 20, 20, 1), dtype=float)
99+
in_file = tmp_path / 'input3D.nii.gz'
100+
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
101+
102+
with pytest.raises(RuntimeError):
103+
SplitSeries(in_file=str(in_file)).run()
104+
105+
split = SplitSeries(in_file=str(in_file), accept_3D=True).run()
106+
assert isinstance(split.outputs.out_files, str)
107+
108+
# Test the 5D
109+
data = np.ones((20, 20, 20, 2, 2), dtype=float)
110+
in_file = tmp_path / 'input5D.nii.gz'
111+
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
112+
113+
with pytest.raises(RuntimeError):
114+
SplitSeries(in_file=str(in_file)).run()
115+
116+
with pytest.raises(RuntimeError):
117+
SplitSeries(in_file=str(in_file), accept_3D=True).run()
118+
119+
# Test splitting ANTs warpfields
120+
data = np.ones((20, 20, 20, 1, 3), dtype=float)
121+
in_file = tmp_path / 'warpfield.nii.gz'
122+
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
123+
124+
split = SplitSeries(in_file=str(in_file)).run()
125+
assert len(split.outputs.out_files) == 3
126+
127+
def test_MergeSeries(tmp_path):
128+
"""Test 3-to-4 NIfTI concatenation interface."""
129+
os.chdir(str(tmp_path))
130+
131+
data = np.ones((20, 20, 20), dtype=float)
132+
in_file = tmp_path / 'input3D.nii.gz'
133+
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
134+
135+
merge = MergeSeries(in_files=[str(in_file)] * 5).run()
136+
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)
137+
138+
in_4D = tmp_path / 'input4D.nii.gz'
139+
nb.Nifti1Image(
140+
np.ones((20, 20, 20, 4), dtype=float), np.eye(4), None
141+
).to_filename(str(in_4D))
142+
143+
merge = MergeSeries(in_files=[str(in_file)] + [str(in_4D)]).run()
144+
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)
145+
146+
with pytest.raises(ValueError):
147+
MergeSeries(in_files=[str(in_file)] + [str(in_4D)], allow_4D=False).run()

0 commit comments

Comments
 (0)