Skip to content

Commit 7b28b4f

Browse files
authored
Merge pull request #489 from dPys/enh/nibabel_splitmerge_interfaces
ENH: Add nibabel-based split and merge interfaces
2 parents 602a7ca + d657546 commit 7b28b4f

File tree

2 files changed

+181
-37
lines changed

2 files changed

+181
-37
lines changed

niworkflows/interfaces/nibabel.py

Lines changed: 110 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,34 @@
11
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""Nibabel-based interfaces."""
4+
from pathlib import Path
45
import numpy as np
56
import nibabel as nb
67
from nipype import logging
78
from nipype.utils.filemanip import fname_presuffix
89
from nipype.interfaces.base import (
9-
traits, TraitedSpec, BaseInterfaceInputSpec, File,
10-
SimpleInterface
10+
traits,
11+
TraitedSpec,
12+
BaseInterfaceInputSpec,
13+
File,
14+
SimpleInterface,
15+
OutputMultiObject,
16+
InputMultiObject,
1117
)
1218

13-
IFLOGGER = logging.getLogger('nipype.interface')
19+
IFLOGGER = logging.getLogger("nipype.interface")
1420

1521

1622
class _ApplyMaskInputSpec(BaseInterfaceInputSpec):
17-
in_file = File(exists=True, mandatory=True, desc='an image')
18-
in_mask = File(exists=True, mandatory=True, desc='a mask')
19-
threshold = traits.Float(0.5, usedefault=True,
20-
desc='a threshold to the mask, if it is nonbinary')
23+
in_file = File(exists=True, mandatory=True, desc="an image")
24+
in_mask = File(exists=True, mandatory=True, desc="a mask")
25+
threshold = traits.Float(
26+
0.5, usedefault=True, desc="a threshold to the mask, if it is nonbinary"
27+
)
2128

2229

2330
class _ApplyMaskOutputSpec(TraitedSpec):
24-
out_file = File(exists=True, desc='masked file')
31+
out_file = File(exists=True, desc="masked file")
2532

2633

2734
class ApplyMask(SimpleInterface):
@@ -35,8 +42,9 @@ def _run_interface(self, runtime):
3542
msknii = nb.load(self.inputs.in_mask)
3643
msk = msknii.get_fdata() > self.inputs.threshold
3744

38-
self._results['out_file'] = fname_presuffix(
39-
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
45+
self._results["out_file"] = fname_presuffix(
46+
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
47+
)
4048

4149
if img.dataobj.shape[:3] != msk.shape:
4250
raise ValueError("Image and mask sizes do not match.")
@@ -48,19 +56,18 @@ def _run_interface(self, runtime):
4856
msk = msk[..., np.newaxis]
4957

5058
masked = img.__class__(img.dataobj * msk, None, img.header)
51-
masked.to_filename(self._results['out_file'])
59+
masked.to_filename(self._results["out_file"])
5260
return runtime
5361

5462

5563
class _BinarizeInputSpec(BaseInterfaceInputSpec):
56-
in_file = File(exists=True, mandatory=True, desc='input image')
57-
thresh_low = traits.Float(mandatory=True,
58-
desc='non-inclusive lower threshold')
64+
in_file = File(exists=True, mandatory=True, desc="input image")
65+
thresh_low = traits.Float(mandatory=True, desc="non-inclusive lower threshold")
5966

6067

6168
class _BinarizeOutputSpec(TraitedSpec):
62-
out_file = File(exists=True, desc='masked file')
63-
out_mask = File(exists=True, desc='output mask')
69+
out_file = File(exists=True, desc="masked file")
70+
out_mask = File(exists=True, desc="output mask")
6471

6572

6673
class Binarize(SimpleInterface):
@@ -72,20 +79,98 @@ class Binarize(SimpleInterface):
7279
def _run_interface(self, runtime):
7380
img = nb.load(self.inputs.in_file)
7481

75-
self._results['out_file'] = fname_presuffix(
76-
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
77-
self._results['out_mask'] = fname_presuffix(
78-
self.inputs.in_file, suffix='_mask', newpath=runtime.cwd)
82+
self._results["out_file"] = fname_presuffix(
83+
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
84+
)
85+
self._results["out_mask"] = fname_presuffix(
86+
self.inputs.in_file, suffix="_mask", newpath=runtime.cwd
87+
)
7988

8089
data = img.get_fdata()
8190
mask = data > self.inputs.thresh_low
8291
data[~mask] = 0.0
8392
masked = img.__class__(data, img.affine, img.header)
84-
masked.to_filename(self._results['out_file'])
93+
masked.to_filename(self._results["out_file"])
8594

86-
img.header.set_data_dtype('uint8')
87-
maskimg = img.__class__(mask.astype('uint8'), img.affine,
88-
img.header)
89-
maskimg.to_filename(self._results['out_mask'])
95+
img.header.set_data_dtype("uint8")
96+
maskimg = img.__class__(mask.astype("uint8"), img.affine, img.header)
97+
maskimg.to_filename(self._results["out_mask"])
9098

9199
return runtime
100+
101+
102+
class _SplitSeriesInputSpec(BaseInterfaceInputSpec):
103+
in_file = File(exists=True, mandatory=True, desc="input 4d image")
104+
105+
106+
class _SplitSeriesOutputSpec(TraitedSpec):
107+
out_files = OutputMultiObject(File(exists=True), desc="output list of 3d images")
108+
109+
110+
class SplitSeries(SimpleInterface):
111+
"""Split a 4D dataset along the last dimension into a series of 3D volumes."""
112+
113+
input_spec = _SplitSeriesInputSpec
114+
output_spec = _SplitSeriesOutputSpec
115+
116+
def _run_interface(self, runtime):
117+
in_file = self.inputs.in_file
118+
img = nb.load(in_file)
119+
extra_dims = tuple(dim for dim in img.shape[3:] if dim > 1) or (1,)
120+
if len(extra_dims) != 1:
121+
raise ValueError(f"Invalid shape {'x'.join(str(s) for s in img.shape)}")
122+
img = img.__class__(img.dataobj.reshape(img.shape[:3] + extra_dims),
123+
img.affine, img.header)
124+
125+
self._results["out_files"] = []
126+
for i, img_3d in enumerate(nb.four_to_three(img)):
127+
out_file = str(
128+
Path(fname_presuffix(in_file, suffix=f"_idx-{i:03}")).absolute()
129+
)
130+
img_3d.to_filename(out_file)
131+
self._results["out_files"].append(out_file)
132+
133+
return runtime
134+
135+
136+
class _MergeSeriesInputSpec(BaseInterfaceInputSpec):
137+
in_files = InputMultiObject(
138+
File(exists=True, mandatory=True, desc="input list of 3d images")
139+
)
140+
allow_4D = traits.Bool(
141+
True, usedefault=True, desc="whether 4D images are allowed to be concatenated"
142+
)
143+
144+
145+
class _MergeSeriesOutputSpec(TraitedSpec):
146+
out_file = File(exists=True, desc="output 4d image")
147+
148+
149+
class MergeSeries(SimpleInterface):
150+
"""Merge a series of 3D volumes along the last dimension into a single 4D image."""
151+
152+
input_spec = _MergeSeriesInputSpec
153+
output_spec = _MergeSeriesOutputSpec
154+
155+
def _run_interface(self, runtime):
156+
nii_list = []
157+
for f in self.inputs.in_files:
158+
filenii = nb.squeeze_image(nb.load(f))
159+
ndim = filenii.dataobj.ndim
160+
if ndim == 3:
161+
nii_list.append(filenii)
162+
continue
163+
elif self.inputs.allow_4D and ndim == 4:
164+
nii_list += nb.four_to_three(filenii)
165+
continue
166+
else:
167+
raise ValueError(
168+
"Input image has an incorrect number of dimensions" f" ({ndim})."
169+
)
170+
171+
img_4d = nb.concat_images(nii_list)
172+
out_file = fname_presuffix(self.inputs.in_files[0], suffix="_merged")
173+
img_4d.to_filename(out_file)
174+
175+
self._results["out_file"] = out_file
176+
return runtime

niworkflows/interfaces/tests/test_nibabel.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import nibabel as nb
55
import pytest
66

7-
from ..nibabel import Binarize, ApplyMask
7+
from ..nibabel import Binarize, ApplyMask, SplitSeries, MergeSeries
88

99

1010
def test_Binarize(tmp_path):
@@ -14,10 +14,10 @@ def test_Binarize(tmp_path):
1414
mask = np.zeros((20, 20, 20), dtype=bool)
1515
mask[5:15, 5:15, 5:15] = bool
1616

17-
data = np.zeros_like(mask, dtype='float32')
17+
data = np.zeros_like(mask, dtype="float32")
1818
data[mask] = np.random.gamma(2, size=mask.sum())
1919

20-
in_file = tmp_path / 'input.nii.gz'
20+
in_file = tmp_path / "input.nii.gz"
2121
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
2222

2323
binif = Binarize(thresh_low=0.0, in_file=str(in_file)).run()
@@ -36,28 +36,32 @@ def test_ApplyMask(tmp_path):
3636
mask[8:11, 8:11, 8:11] = 1.0
3737

3838
# Test the 3D
39-
in_file = tmp_path / 'input3D.nii.gz'
39+
in_file = tmp_path / "input3D.nii.gz"
4040
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
4141

42-
in_mask = tmp_path / 'mask.nii.gz'
42+
in_mask = tmp_path / "mask.nii.gz"
4343
nb.Nifti1Image(mask, np.eye(4), None).to_filename(str(in_mask))
4444

4545
masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
46-
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3
46+
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3
4747

4848
masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.6).run()
49-
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3
49+
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3
5050

5151
data4d = np.stack((data, 2 * data, 3 * data), axis=-1)
5252
# Test the 4D case
53-
in_file4d = tmp_path / 'input4D.nii.gz'
53+
in_file4d = tmp_path / "input4D.nii.gz"
5454
nb.Nifti1Image(data4d, np.eye(4), None).to_filename(str(in_file4d))
5555

56-
masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()
57-
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3 * 6
56+
masked1 = ApplyMask(
57+
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4
58+
).run()
59+
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3 * 6
5860

59-
masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6).run()
60-
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3 * 6
61+
masked1 = ApplyMask(
62+
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6
63+
).run()
64+
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3 * 6
6165

6266
# Test errors
6367
nb.Nifti1Image(mask, 2 * np.eye(4), None).to_filename(str(in_mask))
@@ -69,3 +73,58 @@ def test_ApplyMask(tmp_path):
6973
ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
7074
with pytest.raises(ValueError):
7175
ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()
76+
77+
78+
@pytest.mark.parametrize("shape,exp_n", [
79+
((20, 20, 20, 15), 15),
80+
((20, 20, 20), 1),
81+
((20, 20, 20, 1), 1),
82+
((20, 20, 20, 1, 3), 3),
83+
((20, 20, 20, 3, 1), 3),
84+
((20, 20, 20, 1, 3, 3), -1),
85+
((20, 1, 20, 15), 15),
86+
((20, 1, 20), 1),
87+
((20, 1, 20, 1), 1),
88+
((20, 1, 20, 1, 3), 3),
89+
((20, 1, 20, 3, 1), 3),
90+
((20, 1, 20, 1, 3, 3), -1),
91+
])
92+
def test_SplitSeries(tmp_path, shape, exp_n):
93+
"""Test 4-to-3 NIfTI split interface."""
94+
os.chdir(tmp_path)
95+
96+
in_file = str(tmp_path / "input.nii.gz")
97+
nb.Nifti1Image(np.ones(shape, dtype=float), np.eye(4), None).to_filename(in_file)
98+
99+
_interface = SplitSeries(in_file=in_file)
100+
if exp_n > 0:
101+
split = _interface.run()
102+
n = int(isinstance(split.outputs.out_files, str)) or len(split.outputs.out_files)
103+
assert n == exp_n
104+
else:
105+
with pytest.raises(ValueError):
106+
_interface.run()
107+
108+
109+
def test_MergeSeries(tmp_path):
110+
"""Test 3-to-4 NIfTI concatenation interface."""
111+
os.chdir(str(tmp_path))
112+
113+
in_file = tmp_path / "input3D.nii.gz"
114+
nb.Nifti1Image(np.ones((20, 20, 20), dtype=float), np.eye(4), None).to_filename(
115+
str(in_file)
116+
)
117+
118+
merge = MergeSeries(in_files=[str(in_file)] * 5).run()
119+
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)
120+
121+
in_4D = tmp_path / "input4D.nii.gz"
122+
nb.Nifti1Image(np.ones((20, 20, 20, 4), dtype=float), np.eye(4), None).to_filename(
123+
str(in_4D)
124+
)
125+
126+
merge = MergeSeries(in_files=[str(in_file)] + [str(in_4D)]).run()
127+
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)
128+
129+
with pytest.raises(ValueError):
130+
MergeSeries(in_files=[str(in_file)] + [str(in_4D)], allow_4D=False).run()

0 commit comments

Comments
 (0)