Skip to content

Commit 4a33af1

Browse files
matthew-bretteffigies
authored andcommitted
RF+TST: drop CIFTI Matrix check, add tests
CIFTI matrix check was incorrect, metadata never required. However, we do need to have one or more MatrixIndicesMap entries referring to all the axes in the data matrix. Need to think about how to do this check, as the XML generator does not know how many data dimensions we have.
1 parent 9af9f4f commit 4a33af1

File tree

3 files changed

+70
-12
lines changed

3 files changed

+70
-12
lines changed

nibabel/cifti2/cifti2.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,23 @@ def _to_xml_element(self):
867867

868868

869869
class Cifti2Matrix(xml.XmlSerializable, collections.MutableSequence):
870+
""" CIFTI2 Matrix object
871+
872+
This is a list-like container where the elements are instances of
873+
:class:`Cifti2MatrixIndicesMap`.
874+
875+
* Description: contains child elements that describe the meaning of the
876+
values in the matrix.
877+
* Attributes: [NA]
878+
* Child Elements
879+
* MetaData (0 .. 1)
880+
* MatrixIndicesMap (1 .. N)
881+
* Text Content: [NA]
882+
* Parent Element: CIFTI
883+
884+
For each matrix (data) dimension, exactly one MatrixIndicesMap element must
885+
list it in the AppliesToMatrixDimension attribute.
886+
"""
870887
def __init__(self):
871888
self._mims = []
872889
self.metadata = None
@@ -909,11 +926,9 @@ def insert(self, index, value):
909926
self._mims.insert(index, value)
910927

911928
def _to_xml_element(self):
912-
if (len(self) == 0 and self.metadata is None):
913-
raise CIFTI2HeaderError(
914-
'Matrix element requires either a MatrixIndicesMap or a Metadata element'
915-
)
916-
929+
# From the spec: "For each matrix dimension, exactly one
930+
# MatrixIndicesMap element must list it in the AppliesToMatrixDimension
931+
# attribute."
917932
mat = xml.Element('Matrix')
918933
if self.metadata:
919934
mat.append(self.metadata._to_xml_element())
@@ -970,9 +985,9 @@ def __init__(self,
970985
Parameters
971986
----------
972987
dataobj : object
973-
Object containing image data. It should be some object that returns
974-
an array from ``np.asanyarray``. It should have a ``shape``
975-
attribute or property.
988+
Object containing image data. It should be some object that
989+
returns an array from ``np.asanyarray``. It should have a
990+
``shape`` attribute or property.
976991
header : Cifti2Header instance
977992
Header with data for / from XML part of CIFTI2 format.
978993
nifti_header : None or mapping or NIfTI2 header instance, optional
@@ -985,6 +1000,11 @@ def __init__(self,
9851000
super(Cifti2Image, self).__init__(dataobj, header=header,
9861001
extra=extra, file_map=file_map)
9871002
self._nifti_header = Nifti2Header.from_header(nifti_header)
1003+
# if NIfTI header not specified, get data type from input array
1004+
if nifti_header is None:
1005+
if hasattr(dataobj, 'dtype'):
1006+
self._nifti_header.set_data_dtype(dataobj.dtype)
1007+
self.update_headers()
9881008

9891009
@property
9901010
def nifti_header(self):
@@ -1055,6 +1075,7 @@ def to_file_map(self, file_map=None):
10551075
None
10561076
"""
10571077
from .parse_cifti2 import Cifti2Extension
1078+
self.update_headers()
10581079
header = self._nifti_header
10591080
extension = Cifti2Extension(content=self.header.to_xml())
10601081
header.extensions.append(extension)
@@ -1066,6 +1087,26 @@ def to_file_map(self, file_map=None):
10661087
img = Nifti2Image(data, None, header)
10671088
img.to_file_map(file_map or self.file_map)
10681089

1090+
def update_headers(self):
1091+
''' Harmonize CIFTI2 and NIfTI headers with image data
1092+
1093+
>>> import numpy as np
1094+
>>> data = np.zeros((2,3,4))
1095+
>>> img = Cifti2Image(data)
1096+
>>> img.shape == (2, 3, 4)
1097+
True
1098+
>>> img.update_headers()
1099+
>>> img.nifti_header.get_data_shape() == (2, 3, 4)
1100+
True
1101+
'''
1102+
self._nifti_header.set_data_shape(self._dataobj.shape)
1103+
1104+
def get_data_dtype(self):
1105+
return self._nifti_header.get_data_dtype()
1106+
1107+
def set_data_dtype(self, dtype):
1108+
self._nifti_header.set_data_dtype(dtype)
1109+
10691110

10701111
def load(filename):
10711112
""" Load cifti2 from `filename`

nibabel/cifti2/tests/test_cifti2.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77

88
from nibabel import cifti2 as ci
9+
from nibabel.nifti2 import Nifti2Header
910
from nibabel.cifti2.cifti2 import _float_01
1011

1112
from nose.tools import assert_true, assert_equal, assert_raises, assert_is_none
@@ -292,3 +293,19 @@ def test_underscoring():
292293

293294
for camel, underscored in pairs:
294295
assert_equal(ci.cifti2._underscore(camel), underscored)
296+
297+
298+
class TestCifti2ImageAPI(_TDA):
299+
""" Basic validation for Cifti2Image instances
300+
"""
301+
# A callable returning an image from ``image_maker(data, header)``
302+
image_maker = ci.Cifti2Image
303+
# A callable returning a header from ``header_maker()``
304+
header_maker = ci.Cifti2Header
305+
# A callable returning a nifti header
306+
ni_header_maker = Nifti2Header
307+
example_shapes = ((2,), (2, 3), (2, 3, 4))
308+
standard_extension = '.nii'
309+
310+
def make_imaker(self, arr, header=None, ni_header=None):
311+
return lambda: self.image_maker(arr.copy(), header, ni_header)

nibabel/tests/test_filebasedimages.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,15 @@ class TestFBImageAPI(GenericImageAPI):
6060
can_save = True
6161
standard_extension = '.npy'
6262

63+
def make_imaker(self, arr, header=None):
64+
return lambda: self.image_maker(arr, header)
65+
6366
def obj_params(self):
6467
# Create new images
65-
def make_imaker(arr, header=None):
66-
return lambda: self.image_maker(arr, header)
67-
6868
for shape, dtype in product(self.example_shapes, self.example_dtypes):
6969
arr = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
7070
hdr = self.header_maker()
71-
func = make_imaker(arr.copy(), hdr)
71+
func = self.make_imaker(arr.copy(), hdr)
7272
params = dict(
7373
dtype=dtype,
7474
data=arr,

0 commit comments

Comments
 (0)