Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nibabel/batteryrunners.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def check_only(self, obj):
-------
reports : sequence
sequence of report objects reporting on result of running
checks (withou fixes) on `obj`
checks (without fixes) on `obj`
'''
reports = []
for check in self._checks:
Expand Down
49 changes: 46 additions & 3 deletions nibabel/cifti2/cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ..nifti2 import Nifti2Image, Nifti2Header
from ..arrayproxy import reshape_dataobj
from ..keywordonly import kw_only_meth
from warnings import warn


def _float_01(val):
Expand Down Expand Up @@ -1209,6 +1210,38 @@ def _to_xml_element(self):
mat.append(mim._to_xml_element())
return mat

def get_axis(self, index):
'''
Generates the Cifti2 axis for a given dimension

Parameters
----------
index : int
Dimension for which we want to obtain the mapping.

Returns
-------
axis : :class:`.cifti2_axes.Axis`
'''
from . import cifti2_axes
return cifti2_axes.from_index_mapping(self.get_index_map(index))

def get_data_shape(self):
"""
Returns data shape expected based on the CIFTI-2 header

Any dimensions omitted in the CIFIT-2 header will be given a default size of None.
"""
from . import cifti2_axes
if len(self.mapped_indices) == 0:
return ()
base_shape = [None] * (max(self.mapped_indices) + 1)
for mim in self:
size = len(cifti2_axes.from_index_mapping(mim))
for idx in mim.applies_to_matrix_dimension:
base_shape[idx] = size
return tuple(base_shape)


class Cifti2Header(FileBasedHeader, xml.XmlSerializable):
''' Class for CIFTI-2 header extension '''
Expand Down Expand Up @@ -1279,8 +1312,7 @@ def get_axis(self, index):
-------
axis : :class:`.cifti2_axes.Axis`
'''
from . import cifti2_axes
return cifti2_axes.from_index_mapping(self.matrix.get_index_map(index))
return self.matrix.get_axis(index)

@classmethod
def from_axes(cls, axes):
Expand Down Expand Up @@ -1345,12 +1377,18 @@ def __init__(self,
super(Cifti2Image, self).__init__(dataobj, header=header,
extra=extra, file_map=file_map)
self._nifti_header = Nifti2Header.from_header(nifti_header)

# if NIfTI header not specified, get data type from input array
if nifti_header is None:
if hasattr(dataobj, 'dtype'):
self._nifti_header.set_data_dtype(dataobj.dtype)
self.update_headers()

if self._dataobj.shape != self.header.matrix.get_data_shape():
warn("Dataobj shape {} does not match shape expected from CIFTI-2 header {}".format(
self._dataobj.shape, self.header.matrix.get_data_shape()
))

@property
def nifti_header(self):
return self._nifti_header
Expand Down Expand Up @@ -1426,6 +1464,11 @@ def to_file_map(self, file_map=None):
header = self._nifti_header
extension = Cifti2Extension(content=self.header.to_xml())
header.extensions.append(extension)
if self._dataobj.shape != self.header.matrix.get_data_shape():
raise ValueError(
"Dataobj shape {} does not match shape expected from CIFTI-2 header {}".format(
self._dataobj.shape, self.header.matrix.get_data_shape()
))
# if intent code is not set, default to unknown CIFTI
if header.get_intent()[0] == 'none':
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')
Expand All @@ -1438,7 +1481,7 @@ def to_file_map(self, file_map=None):
img.to_file_map(file_map or self.file_map)

def update_headers(self):
''' Harmonize CIFTI-2 and NIfTI headers with image data
''' Harmonize NIfTI headers with image data

>>> import numpy as np
>>> data = np.zeros((2,3,4))
Expand Down
6 changes: 6 additions & 0 deletions nibabel/cifti2/tests/test_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,4 +358,10 @@ class TestCifti2ImageAPI(_TDA):
standard_extension = '.nii'

def make_imaker(self, arr, header=None, ni_header=None):
for idx, sz in enumerate(arr.shape):
maps = [ci.Cifti2NamedMap(str(value)) for value in range(sz)]
mim = ci.Cifti2MatrixIndicesMap(
(idx, ), 'CIFTI_INDEX_TYPE_SCALARS', maps=maps
)
header.matrix.append(mim)
return lambda: self.image_maker(arr.copy(), header, ni_header)
56 changes: 42 additions & 14 deletions nibabel/cifti2/tests/test_new_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
from nibabel import cifti2 as ci
from nibabel.tmpdirs import InTemporaryDirectory

from nose.tools import assert_true, assert_equal
from nose.tools import assert_true, assert_equal, assert_raises
from nibabel.testing import clear_and_catch_warnings, error_warnings, suppress_warnings

affine = [[-1.5, 0, 0, 90],
[0, 1.5, 0, -85],
[0, 0, 1.5, -71]]
[0, 0, 1.5, -71],
[0, 0, 0, 1.]]

dimensions = (120, 83, 78)

Expand Down Expand Up @@ -234,7 +236,7 @@ def test_dtseries():
matrix.append(series_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 9)
data = np.random.randn(13, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES')

Expand All @@ -257,7 +259,7 @@ def test_dscalar():
matrix.append(scalar_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 9)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS')

Expand All @@ -279,7 +281,7 @@ def test_dlabel():
matrix.append(label_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 9)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS')

Expand All @@ -299,7 +301,7 @@ def test_dconn():
matrix = ci.Cifti2Matrix()
matrix.append(mapping)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(9, 9)
data = np.random.randn(10, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE')

Expand All @@ -322,7 +324,7 @@ def test_ptseries():
matrix.append(series_map)
matrix.append(parcel_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 3)
data = np.random.randn(13, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES')

Expand All @@ -344,7 +346,7 @@ def test_pscalar():
matrix.append(scalar_map)
matrix.append(parcel_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 3)
data = np.random.randn(2, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR')

Expand All @@ -366,7 +368,7 @@ def test_pdconn():
matrix.append(geometry_map)
matrix.append(parcel_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 3)
data = np.random.randn(10, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE')

Expand All @@ -388,7 +390,7 @@ def test_dpconn():
matrix.append(parcel_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 3)
data = np.random.randn(4, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED')

Expand All @@ -410,7 +412,7 @@ def test_plabel():
matrix.append(label_map)
matrix.append(parcel_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 3)
data = np.random.randn(2, 4)
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
Expand All @@ -429,7 +431,7 @@ def test_pconn():
matrix = ci.Cifti2Matrix()
matrix.append(mapping)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(3, 3)
data = np.random.randn(4, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED')

Expand All @@ -453,7 +455,7 @@ def test_pconnseries():
matrix.append(parcel_map)
matrix.append(series_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(3, 3, 13)
data = np.random.randn(4, 4, 13)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SERIES')
Expand All @@ -479,7 +481,7 @@ def test_pconnscalar():
matrix.append(parcel_map)
matrix.append(scalar_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(3, 3, 13)
data = np.random.randn(4, 4, 2)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SCALAR')
Expand All @@ -496,3 +498,29 @@ def test_pconnscalar():
check_parcel_map(img2.header.matrix.get_index_map(0))
check_scalar_map(img2.header.matrix.get_index_map(2))
del img2


def test_wrong_shape():
scalar_map = create_scalar_map((0, ))
brain_model_map = create_geometry_map((1, ))

matrix = ci.Cifti2Matrix()
matrix.append(scalar_map)
matrix.append(brain_model_map)
hdr = ci.Cifti2Header(matrix)

# correct shape is (2, 10)
for data in (
np.random.randn(1, 11),
np.random.randn(2, 10, 1),
np.random.randn(1, 2, 10),
np.random.randn(3, 10),
np.random.randn(2, 9),
):
with clear_and_catch_warnings():
with error_warnings():
assert_raises(UserWarning, ci.Cifti2Image, data, hdr)
with suppress_warnings():
img = ci.Cifti2Image(data, hdr)
assert_raises(ValueError, img.to_file_map)