Skip to content

Commit 49343a0

Browse files
authored
Merge pull request #492 from mgxd/enh/derivative-dtype
ENH: Add option to change data dtype
2 parents 9e41877 + f57e2a1 commit 49343a0

File tree

2 files changed

+73
-43
lines changed

2 files changed

+73
-43
lines changed

niworkflows/interfaces/bids.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""Interfaces for handling BIDS-like neuroimaging structures."""
44

5+
from collections import defaultdict
56
from json import dumps
67
from pathlib import Path
78
from shutil import copytree, rmtree
89

910
import nibabel as nb
11+
import numpy as np
1012

1113
from nipype import logging
1214
from nipype.interfaces.base import (
@@ -26,6 +28,18 @@
2628
LOGGER = logging.getLogger('nipype.interface')
2729

2830

31+
def _none():
32+
return None
33+
34+
35+
# Automatically coerce certain suffixes (DerivativesDataSink)
36+
DEFAULT_DTYPES = defaultdict(_none, (
37+
("mask", "uint8"),
38+
("dseg", "int16"),
39+
("probseg", "float32"))
40+
)
41+
42+
2943
class _BIDSBaseInputSpec(BaseInterfaceInputSpec):
3044
bids_dir = traits.Either(
3145
(None, Directory(exists=True)), usedefault=True,
@@ -238,6 +252,7 @@ class _DerivativesDataSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
238252
in_file = InputMultiObject(File(exists=True), mandatory=True,
239253
desc='the object to be saved')
240254
keep_dtype = traits.Bool(False, usedefault=True, desc='keep datatype suffix')
255+
data_dtype = Str(desc='NumPy datatype to coerce NIfTI data to')
241256
meta_dict = traits.DictStrAny(desc='an input dictionary containing metadata')
242257
source_file = File(exists=False, mandatory=True, desc='the input func file')
243258
space = Str('', usedefault=True, desc='Label for space field')
@@ -464,34 +479,46 @@ def _run_interface(self, runtime):
464479
self._results['out_file'].append(out_file)
465480
self._results['compression'].append(_copy_any(fname, out_file))
466481

467-
is_nii = out_file.endswith('.nii') or out_file.endswith('.nii.gz')
468-
if self.inputs.check_hdr and is_nii:
482+
is_nii = out_file.endswith(('.nii', '.nii.gz'))
483+
data_dtype = self.inputs.data_dtype or DEFAULT_DTYPES[self.inputs.suffix]
484+
if is_nii and any((self.inputs.check_hdr, data_dtype)):
469485
# Do not use mmap; if we need to access the data at all, it will be to
470486
# rewrite, risking a BusError
471487
nii = nb.load(out_file, mmap=False)
472488
if not isinstance(nii, (nb.Nifti1Image, nb.Nifti2Image)):
473489
# .dtseries.nii are CIfTI2, therefore skip check
474490
return runtime
475-
hdr = nii.header
476-
curr_units = tuple([None if u == 'unknown' else u
477-
for u in hdr.get_xyzt_units()])
478-
curr_codes = (int(hdr['qform_code']), int(hdr['sform_code']))
479-
480-
# Default to mm, use sec if data type is bold
481-
units = (curr_units[0] or 'mm', 'sec' if dtype == '_bold' else None)
482-
xcodes = (1, 1) # Derivative in its original scanner space
483-
if self.inputs.space:
484-
xcodes = (4, 4) if self.inputs.space in STANDARD_SPACES \
485-
else (2, 2)
486-
487-
if curr_codes != xcodes or curr_units != units:
488-
self._results['fixed_hdr'][i] = True
489-
hdr.set_qform(nii.affine, xcodes[0])
490-
hdr.set_sform(nii.affine, xcodes[1])
491-
hdr.set_xyzt_units(*units)
492-
493-
# Rewrite file with new header
494-
overwrite_header(nii, out_file)
491+
492+
if self.inputs.check_hdr:
493+
hdr = nii.header
494+
curr_units = tuple([None if u == 'unknown' else u
495+
for u in hdr.get_xyzt_units()])
496+
curr_codes = (int(hdr['qform_code']), int(hdr['sform_code']))
497+
498+
# Default to mm, use sec if data type is bold
499+
units = (curr_units[0] or 'mm', 'sec' if dtype == '_bold' else None)
500+
xcodes = (1, 1) # Derivative in its original scanner space
501+
if self.inputs.space:
502+
xcodes = (4, 4) if self.inputs.space in STANDARD_SPACES \
503+
else (2, 2)
504+
505+
if curr_codes != xcodes or curr_units != units:
506+
self._results['fixed_hdr'][i] = True
507+
hdr.set_qform(nii.affine, xcodes[0])
508+
hdr.set_sform(nii.affine, xcodes[1])
509+
hdr.set_xyzt_units(*units)
510+
511+
# Rewrite file with new header
512+
overwrite_header(nii, out_file)
513+
514+
if data_dtype:
515+
if self.inputs.check_hdr:
516+
# load updated NIfTI
517+
nii = nb.load(out_file, mmap=False)
518+
data_dtype = np.dtype(data_dtype)
519+
if nii.get_data_dtype() != data_dtype:
520+
nii.set_data_dtype(data_dtype)
521+
nii.to_filename(out_file)
495522

496523
if len(self._results['out_file']) == 1:
497524
meta_fields = self.inputs.copyable_trait_names()

niworkflows/interfaces/tests/test_bids.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,28 @@
1919
BOLD_PATH = 'ds054/sub-100185/func/sub-100185_task-machinegame_run-01_bold.nii.gz'
2020

2121

22-
@pytest.mark.parametrize('space, size, units, xcodes, zipped, fixed', [
23-
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (2, 2), True, [False]),
24-
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True]),
25-
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True]),
26-
('T1w', (30, 30, 30, 10), ('mm', None), (2, 2), True, [True]),
27-
('T1w', (30, 30, 30, 10), (None, None), (0, 2), True, [True]),
28-
('T1w', (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True]),
29-
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (4, 4), True, [False]),
30-
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True]),
31-
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True]),
32-
('MNI152Lin', (30, 30, 30, 10), ('mm', None), (4, 4), True, [True]),
33-
('MNI152Lin', (30, 30, 30, 10), (None, None), (0, 2), True, [True]),
34-
('MNI152Lin', (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True]),
35-
(None, (30, 30, 30, 10), ('mm', 'sec'), (1, 1), True, [False]),
36-
(None, (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True]),
37-
(None, (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True]),
38-
(None, (30, 30, 30, 10), ('mm', None), (1, 1), True, [True]),
39-
(None, (30, 30, 30, 10), (None, None), (0, 2), True, [True]),
40-
(None, (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True]),
41-
(None, (30, 30, 30, 10), (None, 'sec'), (0, 0), False, [True]),
22+
@pytest.mark.parametrize('space, size, units, xcodes, zipped, fixed, data_dtype', [
23+
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (2, 2), True, [False], None),
24+
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True], None),
25+
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True], '<i4'),
26+
('T1w', (30, 30, 30, 10), ('mm', None), (2, 2), True, [True], '<f4'),
27+
('T1w', (30, 30, 30, 10), (None, None), (0, 2), True, [True], None),
28+
('T1w', (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True], None),
29+
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (4, 4), True, [False], None),
30+
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True], None),
31+
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True], None),
32+
('MNI152Lin', (30, 30, 30, 10), ('mm', None), (4, 4), True, [True], None),
33+
('MNI152Lin', (30, 30, 30, 10), (None, None), (0, 2), True, [True], None),
34+
('MNI152Lin', (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True], None),
35+
(None, (30, 30, 30, 10), ('mm', 'sec'), (1, 1), True, [False], None),
36+
(None, (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True], None),
37+
(None, (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True], None),
38+
(None, (30, 30, 30, 10), ('mm', None), (1, 1), True, [True], None),
39+
(None, (30, 30, 30, 10), (None, None), (0, 2), True, [True], None),
40+
(None, (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True], None),
41+
(None, (30, 30, 30, 10), (None, 'sec'), (0, 0), False, [True], None),
4242
])
43-
def test_DerivativesDataSink_bold(tmp_path, space, size, units, xcodes, zipped, fixed):
43+
def test_DerivativesDataSink_bold(tmp_path, space, size, units, xcodes, zipped, fixed, data_dtype):
4444
fname = str(tmp_path / 'source.nii') + ('.gz' if zipped else '')
4545

4646
hdr = nb.Nifti1Header()
@@ -53,6 +53,7 @@ def test_DerivativesDataSink_bold(tmp_path, space, size, units, xcodes, zipped,
5353
dds = bintfs.DerivativesDataSink(
5454
base_directory=str(tmp_path),
5555
keep_dtype=True,
56+
data_dtype=data_dtype or Undefined,
5657
desc='preproc',
5758
source_file=BOLD_PATH,
5859
space=space or Undefined,
@@ -61,6 +62,8 @@ def test_DerivativesDataSink_bold(tmp_path, space, size, units, xcodes, zipped,
6162

6263
nii = nb.load(dds.outputs.out_file)
6364
assert dds.outputs.fixed_hdr == fixed
65+
if data_dtype:
66+
assert nii.get_data_dtype() == np.dtype(data_dtype)
6467
assert int(nii.header['qform_code']) == XFORM_CODES[space]
6568
assert int(nii.header['sform_code']) == XFORM_CODES[space]
6669
assert nii.header.get_xyzt_units() == ('mm', 'sec')

0 commit comments

Comments
 (0)