Skip to content

Commit 0dcec02

Browse files
committed
ENH: Add option to change data dtype
1 parent b6f810e commit 0dcec02

File tree

2 files changed

+59
-43
lines changed

2 files changed

+59
-43
lines changed

niworkflows/interfaces/bids.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from shutil import copytree, rmtree
88

99
import nibabel as nb
10+
import numpy as np
1011

1112
from nipype import logging
1213
from nipype.interfaces.base import (
@@ -238,6 +239,7 @@ class _DerivativesDataSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
238239
in_file = InputMultiObject(File(exists=True), mandatory=True,
239240
desc='the object to be saved')
240241
keep_dtype = traits.Bool(False, usedefault=True, desc='keep datatype suffix')
242+
data_dtype = Str(desc='NumPy datatype to coerce NIfTI data to')
241243
meta_dict = traits.DictStrAny(desc='an input dictionary containing metadata')
242244
source_file = File(exists=False, mandatory=True, desc='the input func file')
243245
space = Str('', usedefault=True, desc='Label for space field')
@@ -464,34 +466,45 @@ def _run_interface(self, runtime):
464466
self._results['out_file'].append(out_file)
465467
self._results['compression'].append(_copy_any(fname, out_file))
466468

467-
is_nii = out_file.endswith('.nii') or out_file.endswith('.nii.gz')
468-
if self.inputs.check_hdr and is_nii:
469+
is_nii = out_file.endswith(('.nii', '.nii.gz'))
470+
if is_nii and any((self.inputs.check_hdr, self.inputs.data_dtype)):
469471
# Do not use mmap; if we need to access the data at all, it will be to
470472
# rewrite, risking a BusError
471473
nii = nb.load(out_file, mmap=False)
472474
if not isinstance(nii, (nb.Nifti1Image, nb.Nifti2Image)):
473475
# .dtseries.nii are CIfTI2, therefore skip check
474476
return runtime
475-
hdr = nii.header
476-
curr_units = tuple([None if u == 'unknown' else u
477-
for u in hdr.get_xyzt_units()])
478-
curr_codes = (int(hdr['qform_code']), int(hdr['sform_code']))
479-
480-
# Default to mm, use sec if data type is bold
481-
units = (curr_units[0] or 'mm', 'sec' if dtype == '_bold' else None)
482-
xcodes = (1, 1) # Derivative in its original scanner space
483-
if self.inputs.space:
484-
xcodes = (4, 4) if self.inputs.space in STANDARD_SPACES \
485-
else (2, 2)
486-
487-
if curr_codes != xcodes or curr_units != units:
488-
self._results['fixed_hdr'][i] = True
489-
hdr.set_qform(nii.affine, xcodes[0])
490-
hdr.set_sform(nii.affine, xcodes[1])
491-
hdr.set_xyzt_units(*units)
492-
493-
# Rewrite file with new header
494-
overwrite_header(nii, out_file)
477+
478+
if self.inputs.check_hdr:
479+
hdr = nii.header
480+
curr_units = tuple([None if u == 'unknown' else u
481+
for u in hdr.get_xyzt_units()])
482+
curr_codes = (int(hdr['qform_code']), int(hdr['sform_code']))
483+
484+
# Default to mm, use sec if data type is bold
485+
units = (curr_units[0] or 'mm', 'sec' if dtype == '_bold' else None)
486+
xcodes = (1, 1) # Derivative in its original scanner space
487+
if self.inputs.space:
488+
xcodes = (4, 4) if self.inputs.space in STANDARD_SPACES \
489+
else (2, 2)
490+
491+
if curr_codes != xcodes or curr_units != units:
492+
self._results['fixed_hdr'][i] = True
493+
hdr.set_qform(nii.affine, xcodes[0])
494+
hdr.set_sform(nii.affine, xcodes[1])
495+
hdr.set_xyzt_units(*units)
496+
497+
# Rewrite file with new header
498+
overwrite_header(nii, out_file)
499+
500+
if self.inputs.data_dtype:
501+
if self.inputs.check_hdr:
502+
# load updated NIfTI
503+
nii = nb.load(out_file, mmap=False)
504+
data_dtype = np.dtype(self.inputs.data_dtype)
505+
if nii.get_data_dtype() != data_dtype:
506+
nii.set_data_dtype(data_dtype)
507+
nii.to_filename(out_file)
495508

496509
if len(self._results['out_file']) == 1:
497510
meta_fields = self.inputs.copyable_trait_names()

niworkflows/interfaces/tests/test_bids.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,28 @@
1919
BOLD_PATH = 'ds054/sub-100185/func/sub-100185_task-machinegame_run-01_bold.nii.gz'
2020

2121

22-
@pytest.mark.parametrize('space, size, units, xcodes, zipped, fixed', [
23-
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (2, 2), True, [False]),
24-
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True]),
25-
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True]),
26-
('T1w', (30, 30, 30, 10), ('mm', None), (2, 2), True, [True]),
27-
('T1w', (30, 30, 30, 10), (None, None), (0, 2), True, [True]),
28-
('T1w', (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True]),
29-
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (4, 4), True, [False]),
30-
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True]),
31-
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True]),
32-
('MNI152Lin', (30, 30, 30, 10), ('mm', None), (4, 4), True, [True]),
33-
('MNI152Lin', (30, 30, 30, 10), (None, None), (0, 2), True, [True]),
34-
('MNI152Lin', (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True]),
35-
(None, (30, 30, 30, 10), ('mm', 'sec'), (1, 1), True, [False]),
36-
(None, (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True]),
37-
(None, (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True]),
38-
(None, (30, 30, 30, 10), ('mm', None), (1, 1), True, [True]),
39-
(None, (30, 30, 30, 10), (None, None), (0, 2), True, [True]),
40-
(None, (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True]),
41-
(None, (30, 30, 30, 10), (None, 'sec'), (0, 0), False, [True]),
22+
@pytest.mark.parametrize('space, size, units, xcodes, zipped, fixed, data_dtype', [
23+
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (2, 2), True, [False], None),
24+
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True], None),
25+
('T1w', (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True], '<i4'),
26+
('T1w', (30, 30, 30, 10), ('mm', None), (2, 2), True, [True], '<f4'),
27+
('T1w', (30, 30, 30, 10), (None, None), (0, 2), True, [True], None),
28+
('T1w', (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True], None),
29+
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (4, 4), True, [False], None),
30+
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True], None),
31+
('MNI152Lin', (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True], None),
32+
('MNI152Lin', (30, 30, 30, 10), ('mm', None), (4, 4), True, [True], None),
33+
('MNI152Lin', (30, 30, 30, 10), (None, None), (0, 2), True, [True], None),
34+
('MNI152Lin', (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True], None),
35+
(None, (30, 30, 30, 10), ('mm', 'sec'), (1, 1), True, [False], None),
36+
(None, (30, 30, 30, 10), ('mm', 'sec'), (0, 0), True, [True], None),
37+
(None, (30, 30, 30, 10), ('mm', 'sec'), (0, 2), True, [True], None),
38+
(None, (30, 30, 30, 10), ('mm', None), (1, 1), True, [True], None),
39+
(None, (30, 30, 30, 10), (None, None), (0, 2), True, [True], None),
40+
(None, (30, 30, 30, 10), (None, 'sec'), (0, 0), True, [True], None),
41+
(None, (30, 30, 30, 10), (None, 'sec'), (0, 0), False, [True], None),
4242
])
43-
def test_DerivativesDataSink_bold(tmp_path, space, size, units, xcodes, zipped, fixed):
43+
def test_DerivativesDataSink_bold(tmp_path, space, size, units, xcodes, zipped, fixed, data_dtype):
4444
fname = str(tmp_path / 'source.nii') + ('.gz' if zipped else '')
4545

4646
hdr = nb.Nifti1Header()
@@ -53,6 +53,7 @@ def test_DerivativesDataSink_bold(tmp_path, space, size, units, xcodes, zipped,
5353
dds = bintfs.DerivativesDataSink(
5454
base_directory=str(tmp_path),
5555
keep_dtype=True,
56+
data_dtype=data_dtype or Undefined,
5657
desc='preproc',
5758
source_file=BOLD_PATH,
5859
space=space or Undefined,
@@ -61,6 +62,8 @@ def test_DerivativesDataSink_bold(tmp_path, space, size, units, xcodes, zipped,
6162

6263
nii = nb.load(dds.outputs.out_file)
6364
assert dds.outputs.fixed_hdr == fixed
65+
if data_dtype:
66+
assert nii.get_data_dtype() == np.dtype(data_dtype)
6467
assert int(nii.header['qform_code']) == XFORM_CODES[space]
6568
assert int(nii.header['sform_code']) == XFORM_CODES[space]
6669
assert nii.header.get_xyzt_units() == ('mm', 'sec')

0 commit comments

Comments
 (0)