Skip to content
101 changes: 95 additions & 6 deletions nibabel/cifti2/cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..nifti1 import Nifti1Extensions
from ..nifti2 import Nifti2Image, Nifti2Header
from ..arrayproxy import reshape_dataobj
from ..volumeutils import Recoder
from warnings import warn


Expand Down Expand Up @@ -89,6 +90,53 @@ class Cifti2HeaderError(Exception):
'CIFTI_STRUCTURE_THALAMUS_LEFT',
'CIFTI_STRUCTURE_THALAMUS_RIGHT')

# "Standard CIFTI Mapping Combinations" within CIFTI-2 spec
# https://www.nitrc.org/forum/attachment.php?attachid=341&group_id=454&forum_id=1955
CIFTI_CODES = Recoder((
('.dconn.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE', (
'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dtseries.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', (
'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.pconn.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.ptseries.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES', (
'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.dscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dlabel.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS', (
'CIFTI_INDEX_TYPE_LABELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.pscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.pdconn.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE', (
'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('.dpconn.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.pconnseries.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SERIES', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SERIES',
)),
('.pconnscalar.nii', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SCALAR', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SCALARS',
)),
('.dfan.nii', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dfibersamp.nii', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('.dfansamp.nii', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
), fields=('extension', 'niistring', 'map_types'))


def _value_if_klass(val, klass):
if val is None or isinstance(val, klass):
Expand Down Expand Up @@ -1466,11 +1514,7 @@ def to_file_map(self, file_map=None):
raise ValueError(
f"Dataobj shape {self._dataobj.shape} does not match shape "
f"expected from CIFTI-2 header {self.header.matrix.get_data_shape()}")
# if intent code is not set, default to unknown CIFTI
if header.get_intent()[0] == 'none':
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')
data = reshape_dataobj(self.dataobj,
(1, 1, 1, 1) + self.dataobj.shape)
data = reshape_dataobj(self.dataobj, (1, 1, 1, 1) + self.dataobj.shape)
# If qform not set, reset pixdim values so Nifti2 does not complain
if header['qform_code'] == 0:
header['pixdim'][:4] = 1
Expand Down Expand Up @@ -1501,14 +1545,59 @@ def update_headers(self):
>>> img.shape == (2, 3, 4)
True
"""
self._nifti_header.set_data_shape((1, 1, 1, 1) + self._dataobj.shape)
header = self._nifti_header
header.set_data_shape((1, 1, 1, 1) + self._dataobj.shape)
# if intent code is not set, default to unknown
if header.get_intent()[0] == 'none':
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')

def get_data_dtype(self):
return self._nifti_header.get_data_dtype()

def set_data_dtype(self, dtype):
self._nifti_header.set_data_dtype(dtype)

def to_filename(self, filename, validate=True):
"""
Ensures NIfTI header intent code is set prior to saving.

Parameters
----------
validate : boolean, optional
If ``True``, infer and validate CIFTI type based on MatrixIndicesMap values.
This includes the setting of the relevant intent code within the NIfTI header.
If validation fails, a UserWarning is issued and saving continues.
"""
if validate:
# Determine CIFTI type via index maps
from .parse_cifti2 import intent_codes

matrix = self.header.matrix
map_types = tuple(
matrix.get_index_map(idx).indices_map_to_data_type for idx
in sorted(matrix.mapped_indices)
)
try:
expected_intent = CIFTI_CODES.niistring[map_types]
expected_ext = CIFTI_CODES.extension[map_types]
except KeyError: # unknown
expected_intent = "NIFTI_INTENT_CONNECTIVITY_UNKNOWN"
expected_ext = None
warn(
"No information found for matrix containing the following index maps:"
f"{map_types}, defaulting to unknown."
)

orig_intent = self._nifti_header.get_intent()[0]
if expected_intent != intent_codes.niistring[orig_intent]:
warn(
f"Expected NIfTI intent: {expected_intent} has been automatically set."
)
self._nifti_header.set_intent(expected_intent)
if expected_ext is not None and not filename.endswith(expected_ext):
warn(f"Filename does not end with expected extension: {expected_ext}")
super().to_filename(filename)


load = Cifti2Image.from_filename
save = Cifti2Image.instance_to_filename
15 changes: 15 additions & 0 deletions nibabel/cifti2/tests/test_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,18 @@ def make_imaker(self, arr, header=None, ni_header=None):
)
header.matrix.append(mim)
return lambda: self.image_maker(arr.copy(), header, ni_header)

def validate_filenames(self, imaker, params, validate=False):
super().validate_filenames(imaker, params, validate=validate)

def validate_mmap_parameter(self, imaker, params, validate=False):
super().validate_mmap_parameter(imaker, params, validate=validate)

def validate_to_bytes(self, imaker, params, validate=False):
super().validate_to_bytes(imaker, params, validate=validate)

def validate_from_bytes(self, imaker, params, validate=False):
super().validate_from_bytes(imaker, params, validate=validate)

def validate_to_from_bytes(self, imaker, params, validate=False):
super().validate_to_from_bytes(imaker, params, validate=validate)
2 changes: 1 addition & 1 deletion nibabel/cifti2/tests/test_cifti2io_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def check_rewrite(arr, axes, extension='.nii'):
custom extension to use
"""
(fd, name) = tempfile.mkstemp(extension)
cifti2.Cifti2Image(arr, header=axes).to_filename(name)
cifti2.Cifti2Image(arr, header=axes).to_filename(name, validate=False)
img = nib.load(name)
arr2 = img.get_fdata()
assert np.allclose(arr, arr2)
Expand Down
4 changes: 2 additions & 2 deletions nibabel/cifti2/tests/test_cifti2io_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_readwritedata():
with InTemporaryDirectory():
for name in datafiles:
img = ci.load(name)
ci.save(img, 'test.nii')
ci.save(img, 'test.nii', validate=False)
img2 = ci.load('test.nii')
assert len(img.header.matrix) == len(img2.header.matrix)
# Order should be preserved in load/save
Expand All @@ -109,7 +109,7 @@ def test_nibabel_readwritedata():
with InTemporaryDirectory():
for name in datafiles:
img = nib.load(name)
nib.save(img, 'test.nii')
nib.save(img, 'test.nii', validate=False)
img2 = nib.load('test.nii')
assert len(img.header.matrix) == len(img2.header.matrix)
# Order should be preserved in load/save
Expand Down
57 changes: 41 additions & 16 deletions nibabel/cifti2/tests/test_new_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
scratch.
"""
import numpy as np

import nibabel as nib
from nibabel import cifti2 as ci
from nibabel.tmpdirs import InTemporaryDirectory

import pytest

from ...testing import (
clear_and_catch_warnings, error_warnings, suppress_warnings, assert_array_equal)

Expand Down Expand Up @@ -237,7 +236,6 @@ def test_dtseries():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES')

with InTemporaryDirectory():
ci.save(img, 'test.dtseries.nii')
Expand Down Expand Up @@ -281,7 +279,6 @@ def test_dlabel():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS')

with InTemporaryDirectory():
ci.save(img, 'test.dlabel.nii')
Expand All @@ -301,7 +298,6 @@ def test_dconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE')

with InTemporaryDirectory():
ci.save(img, 'test.dconn.nii')
Expand All @@ -323,7 +319,6 @@ def test_ptseries():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES')

with InTemporaryDirectory():
ci.save(img, 'test.ptseries.nii')
Expand All @@ -345,7 +340,6 @@ def test_pscalar():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR')

with InTemporaryDirectory():
ci.save(img, 'test.pscalar.nii')
Expand All @@ -367,7 +361,6 @@ def test_pdconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE')

with InTemporaryDirectory():
ci.save(img, 'test.pdconn.nii')
Expand All @@ -389,7 +382,6 @@ def test_dpconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED')

with InTemporaryDirectory():
ci.save(img, 'test.dpconn.nii')
Expand All @@ -413,7 +405,7 @@ def test_plabel():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.plabel.nii')
ci.save(img, 'test.plabel.nii', validate=False)
img2 = ci.load('test.plabel.nii')
assert img.nifti_header.get_intent()[0] == 'ConnUnknown'
assert isinstance(img2, ci.Cifti2Image)
Expand All @@ -430,7 +422,6 @@ def test_pconn():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED')

with InTemporaryDirectory():
ci.save(img, 'test.pconn.nii')
Expand All @@ -453,8 +444,6 @@ def test_pconnseries():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 4, 13)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SERIES')

with InTemporaryDirectory():
ci.save(img, 'test.pconnseries.nii')
Expand All @@ -478,8 +467,6 @@ def test_pconnscalar():
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(4, 4, 2)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SCALAR')

with InTemporaryDirectory():
ci.save(img, 'test.pconnscalar.nii')
Expand Down Expand Up @@ -517,7 +504,45 @@ def test_wrong_shape():
ci.Cifti2Image(data, hdr)
with suppress_warnings():
img = ci.Cifti2Image(data, hdr)

with pytest.raises(ValueError):
img.to_file_map()


def test_cifti_validation():
# flip label / brain_model index maps
geometry_map = create_geometry_map((0, ))
label_map = create_label_map((1, ))
matrix = ci.Cifti2Matrix()
matrix.append(geometry_map)
matrix.append(label_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 2)
img = ci.Cifti2Image(data, hdr)
# flipped index maps will warn
with InTemporaryDirectory(), pytest.warns(UserWarning):
ci.save(img, 'test.dlabel.nii')

label_map = create_label_map((0, ))
geometry_map = create_geometry_map((1, ))
matrix = ci.Cifti2Matrix()
matrix.append(label_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.validate.nii', validate=False)
ci.save(img, 'test.dlabel.nii')

img2 = nib.load('test.dlabel.nii')
img3 = nib.load('test.validate.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnDenseLabel'
assert img3.nifti_header.get_intent()[0] == 'ConnUnknown'
assert isinstance(img2, ci.Cifti2Image)
assert isinstance(img3, ci.Cifti2Image)
assert_array_equal(img2.get_fdata(), data)
check_label_map(img2.header.matrix.get_index_map(0))
check_geometry_map(img2.header.matrix.get_index_map(1))
del img2, img3
6 changes: 3 additions & 3 deletions nibabel/filebasedimages.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def filespec_to_file_map(klass, filespec):
def filespec_to_files(klass, filespec):
return klass.filespec_to_file_map(filespec)

def to_filename(self, filename):
def to_filename(self, filename, **kwargs):
""" Write image to files implied by filename string

Parameters
Expand Down Expand Up @@ -381,7 +381,7 @@ def make_file_map(klass, mapping=None):
load = from_filename

@classmethod
def instance_to_filename(klass, img, filename):
def instance_to_filename(klass, img, filename, **kwargs):
""" Save `img` in our own format, to name implied by `filename`

This is a class method
Expand All @@ -394,7 +394,7 @@ def instance_to_filename(klass, img, filename):
Filename, implying name to which to save image.
"""
img = klass.from_image(img)
img.to_filename(filename)
img.to_filename(filename, **kwargs)

@classmethod
def from_image(klass, img):
Expand Down
6 changes: 3 additions & 3 deletions nibabel/loadsave.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def guessed_image_type(filename):
raise ImageFileError(f'Cannot work out file type of "{filename}"')


def save(img, filename):
def save(img, filename, **kwargs):
""" Save an image to file adapting format to `filename`

Parameters
Expand All @@ -96,7 +96,7 @@ def save(img, filename):

# Save the type as expected
try:
img.to_filename(filename)
img.to_filename(filename, **kwargs)
except ImageFileError:
pass
else:
Expand Down Expand Up @@ -144,7 +144,7 @@ def save(img, filename):
# Here, we either have a klass or a converted image.
if converted is None:
converted = klass.from_image(img)
converted.to_filename(filename)
converted.to_filename(filename, **kwargs)


@deprecate_with_version('read_img_data deprecated. '
Expand Down
Loading