Skip to content

Commit 23d8db4

Browse files
authored
Merge pull request #81 from josephmje/enh/skullstrip
ENH: Update image utility output path behaviour
2 parents 199a496 + a0e205c commit 23d8db4

File tree

2 files changed

+109
-73
lines changed

2 files changed

+109
-73
lines changed

dmriprep/interfaces/images.py

Lines changed: 39 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
"""Image tools interfaces."""
2-
import numpy as np
3-
import nibabel as nb
4-
from nipype.utils.filemanip import fname_presuffix
2+
from pathlib import Path
3+
54
from nipype import logging
65
from nipype.interfaces.base import (
7-
traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File
6+
BaseInterfaceInputSpec,
7+
File,
8+
SimpleInterface,
9+
TraitedSpec,
10+
traits
811
)
912

13+
from dmriprep.utils.images import extract_b0, median, rescale_b0
14+
1015
LOGGER = logging.getLogger('nipype.interface')
1116

1217

1318
class _ExtractB0InputSpec(BaseInterfaceInputSpec):
1419
in_file = File(exists=True, mandatory=True, desc='dwi file')
15-
b0_ixs = traits.List(traits.Int, mandatory=True,
16-
desc='Index of b0s')
20+
b0_ixs = traits.List(traits.Int, mandatory=True, desc='Index of b0s')
1721

1822

1923
class _ExtractB0OutputSpec(TraitedSpec):
20-
out_file = File(exists=True, desc='b0 file')
24+
out_file = File(exists=True, desc='output b0 file')
2125

2226

2327
class ExtractB0(SimpleInterface):
@@ -38,29 +42,19 @@ class ExtractB0(SimpleInterface):
3842
output_spec = _ExtractB0OutputSpec
3943

4044
def _run_interface(self, runtime):
41-
self._results['out_file'] = extract_b0(
42-
self.inputs.in_file,
43-
self.inputs.b0_ixs,
44-
newpath=runtime.cwd)
45-
return runtime
46-
47-
48-
def extract_b0(in_file, b0_ixs, newpath=None):
49-
"""Extract the *b0* volumes from a DWI dataset."""
50-
out_file = fname_presuffix(
51-
in_file, suffix='_b0', newpath=newpath)
52-
53-
img = nb.load(in_file)
54-
data = img.get_fdata(dtype='float32')
45+
from nipype.utils.filemanip import fname_presuffix
5546

56-
b0 = data[..., b0_ixs]
47+
out_file = fname_presuffix(
48+
self.inputs.in_file,
49+
suffix='_b0',
50+
use_ext=True,
51+
newpath=str(Path(runtime.cwd).absolute()),
52+
)
5753

58-
hdr = img.header.copy()
59-
hdr.set_data_shape(b0.shape)
60-
hdr.set_xyzt_units('mm')
61-
hdr.set_data_dtype(np.float32)
62-
nb.Nifti1Image(b0, img.affine, hdr).to_filename(out_file)
63-
return out_file
54+
self._results['out_file'] = extract_b0(
55+
self.inputs.in_file, self.inputs.b0_ixs, out_file
56+
)
57+
return runtime
6458

6559

6660
class _RescaleB0InputSpec(BaseInterfaceInputSpec):
@@ -91,55 +85,27 @@ class RescaleB0(SimpleInterface):
9185
output_spec = _RescaleB0OutputSpec
9286

9387
def _run_interface(self, runtime):
88+
from nipype.utils.filemanip import fname_presuffix
89+
90+
out_b0s = fname_presuffix(
91+
self.inputs.in_file,
92+
suffix='_rescaled',
93+
use_ext=True,
94+
newpath=str(Path(runtime.cwd).absolute())
95+
)
96+
out_ref = fname_presuffix(
97+
self.inputs.in_file,
98+
suffix='_ref',
99+
use_ext=True,
100+
newpath=str(Path(runtime.cwd).absolute())
101+
)
102+
94103
self._results['out_b0s'] = rescale_b0(
95104
self.inputs.in_file,
96-
self.inputs.mask_file,
97-
newpath=runtime.cwd
105+
self.inputs.mask_file, out_b0s
98106
)
99107
self._results['out_ref'] = median(
100108
self._results['out_b0s'],
101-
newpath=runtime.cwd
109+
out_path=out_ref
102110
)
103111
return runtime
104-
105-
106-
def rescale_b0(in_file, mask_file, newpath=None):
107-
"""Rescale the input volumes using the median signal intensity."""
108-
out_file = fname_presuffix(
109-
in_file, suffix='_rescaled_b0', newpath=newpath)
110-
111-
img = nb.load(in_file)
112-
if img.dataobj.ndim == 3:
113-
return in_file
114-
115-
data = img.get_fdata(dtype='float32')
116-
mask_img = nb.load(mask_file)
117-
mask_data = mask_img.get_fdata(dtype='float32')
118-
119-
median_signal = np.median(data[mask_data > 0, ...], axis=0)
120-
rescaled_data = 1000 * data / median_signal
121-
hdr = img.header.copy()
122-
nb.Nifti1Image(rescaled_data, img.affine, hdr).to_filename(out_file)
123-
return out_file
124-
125-
126-
def median(in_file, newpath=None):
127-
"""Average a 4D dataset across the last dimension using median."""
128-
out_file = fname_presuffix(
129-
in_file, suffix='_b0ref', newpath=newpath)
130-
131-
img = nb.load(in_file)
132-
if img.dataobj.ndim == 3:
133-
return in_file
134-
if img.shape[-1] == 1:
135-
nb.squeeze_image(img).to_filename(out_file)
136-
return out_file
137-
138-
median_data = np.median(img.get_fdata(dtype='float32'),
139-
axis=-1)
140-
141-
hdr = img.header.copy()
142-
hdr.set_xyzt_units('mm')
143-
hdr.set_data_dtype(np.float32)
144-
nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file)
145-
return out_file

dmriprep/utils/images.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import nibabel as nb
2+
import numpy as np
3+
from nipype.utils.filemanip import fname_presuffix
4+
5+
6+
def extract_b0(in_file, b0_ixs, out_path=None):
7+
"""Extract the *b0* volumes from a DWI dataset."""
8+
if out_path is None:
9+
out_path = fname_presuffix(
10+
in_file, suffix='_b0', use_ext=True)
11+
12+
img = nb.load(in_file)
13+
data = img.get_fdata()
14+
15+
b0 = data[..., b0_ixs]
16+
17+
hdr = img.header.copy()
18+
hdr.set_data_shape(b0.shape)
19+
hdr.set_xyzt_units('mm')
20+
nb.Nifti1Image(b0.astype(hdr.get_data_dtype()), img.affine, hdr).to_filename(out_path)
21+
return out_path
22+
23+
24+
def rescale_b0(in_file, mask_file, out_path=None):
25+
"""Rescale the input volumes using the median signal intensity."""
26+
if out_path is None:
27+
out_path = fname_presuffix(
28+
in_file, suffix='_rescaled', use_ext=True)
29+
30+
img = nb.load(in_file)
31+
if img.dataobj.ndim == 3:
32+
return in_file
33+
34+
data = img.get_fdata()
35+
mask_img = nb.load(mask_file)
36+
mask_data = mask_img.get_fdata()
37+
38+
median_signal = np.median(data[mask_data > 0, ...], axis=0)
39+
rescaled_data = 1000 * data / median_signal
40+
hdr = img.header.copy()
41+
nb.Nifti1Image(
42+
rescaled_data.astype(hdr.get_data_dtype()), img.affine, hdr
43+
).to_filename(out_path)
44+
return out_path
45+
46+
47+
def median(in_file, dtype=None, out_path=None):
48+
"""Average a 4D dataset across the last dimension using median."""
49+
if out_path is None:
50+
out_path = fname_presuffix(
51+
in_file, suffix='_b0ref', use_ext=True)
52+
53+
img = nb.load(in_file)
54+
if img.dataobj.ndim == 3:
55+
return in_file
56+
if img.shape[-1] == 1:
57+
nb.squeeze_image(img).to_filename(out_path)
58+
return out_path
59+
60+
median_data = np.median(img.get_fdata(dtype=dtype),
61+
axis=-1)
62+
63+
hdr = img.header.copy()
64+
hdr.set_xyzt_units('mm')
65+
if dtype is not None:
66+
hdr.set_data_dtype(dtype)
67+
else:
68+
dtype = hdr.get_data_dtype()
69+
nb.Nifti1Image(median_data.astype(dtype), img.affine, hdr).to_filename(out_path)
70+
return out_path

0 commit comments

Comments
 (0)