Skip to content

Commit 0b23506

Browse files
committed
fix: squeeze image with np.squeeze / change input name for consistency
1 parent f7ee490 commit 0b23506

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

niworkflows/interfaces/nibabel.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _run_interface(self, runtime):
9494

9595
class _FourToThreeInputSpec(BaseInterfaceInputSpec):
9696
in_file = File(exists=True, mandatory=True, desc='input 4d image')
97-
accept_3D = traits.Bool(False, usedefault=True, desc='do not fail if a 3D volume is passed in')
97+
allow_3D = traits.Bool(False, usedefault=True, desc='do not fail if a 3D volume is passed in')
9898

9999

100100
class _FourToThreeOutputSpec(TraitedSpec):
@@ -110,17 +110,22 @@ class SplitSeries(SimpleInterface):
110110

111111
def _run_interface(self, runtime):
112112
filenii = nb.squeeze_image(nb.load(self.inputs.in_file))
113+
filenii = filenii.__class__(
114+
np.squeeze(filenii.dataobj), filenii.affine, filenii.header
115+
)
113116
ndim = filenii.dataobj.ndim
114117
if ndim != 4:
115-
if self.inputs.accept_3D and ndim == 3:
118+
if self.inputs.allow_3D and ndim == 3:
116119
out_file = str(
117120
Path(fname_presuffix(self.inputs.in_file, suffix=f"_idx-000")).absolute()
118121
)
119122
self._results['out_files'] = out_file
120123
filenii.to_filename(out_file)
121124
return runtime
122125
raise RuntimeError(
123-
f"Input image image is {ndim}D ({'x'.join(['%d' % s for s in filenii.shape])}).")
126+
f"Input image <{self.inputs.in_file}> is {ndim}D "
127+
f"({'x'.join(['%d' % s for s in filenii.shape])})."
128+
)
124129

125130
files_3d = nb.four_to_three(filenii)
126131
self._results['out_files'] = []

niworkflows/interfaces/tests/test_nibabel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_SplitSeries(tmp_path):
9191
with pytest.raises(RuntimeError):
9292
SplitSeries(in_file=str(in_file)).run()
9393

94-
split = SplitSeries(in_file=str(in_file), accept_3D=True).run()
94+
split = SplitSeries(in_file=str(in_file), allow_3D=True).run()
9595
assert isinstance(split.outputs.out_files, str)
9696

9797
# Test the 3D
@@ -102,7 +102,7 @@ def test_SplitSeries(tmp_path):
102102
with pytest.raises(RuntimeError):
103103
SplitSeries(in_file=str(in_file)).run()
104104

105-
split = SplitSeries(in_file=str(in_file), accept_3D=True).run()
105+
split = SplitSeries(in_file=str(in_file), allow_3D=True).run()
106106
assert isinstance(split.outputs.out_files, str)
107107

108108
# Test the 5D
@@ -114,7 +114,7 @@ def test_SplitSeries(tmp_path):
114114
SplitSeries(in_file=str(in_file)).run()
115115

116116
with pytest.raises(RuntimeError):
117-
SplitSeries(in_file=str(in_file), accept_3D=True).run()
117+
SplitSeries(in_file=str(in_file), allow_3D=True).run()
118118

119119
# Test splitting ANTs warpfields
120120
data = np.ones((20, 20, 20, 1, 3), dtype=float)
@@ -124,6 +124,7 @@ def test_SplitSeries(tmp_path):
124124
split = SplitSeries(in_file=str(in_file)).run()
125125
assert len(split.outputs.out_files) == 3
126126

127+
127128
def test_MergeSeries(tmp_path):
128129
"""Test 3-to-4 NIfTI concatenation interface."""
129130
os.chdir(str(tmp_path))

0 commit comments

Comments
 (0)