@@ -867,6 +867,23 @@ def _to_xml_element(self):
867
867
868
868
869
869
class Cifti2Matrix (xml .XmlSerializable , collections .MutableSequence ):
870
+ """ CIFTI2 Matrix object
871
+
872
+ This is a list-like container where the elements are instances of
873
+ :class:`Cifti2MatrixIndicesMap`.
874
+
875
+ * Description: contains child elements that describe the meaning of the
876
+ values in the matrix.
877
+ * Attributes: [NA]
878
+ * Child Elements
879
+ * MetaData (0 .. 1)
880
+ * MatrixIndicesMap (1 .. N)
881
+ * Text Content: [NA]
882
+ * Parent Element: CIFTI
883
+
884
+ For each matrix (data) dimension, exactly one MatrixIndicesMap element must
885
+ list it in the AppliesToMatrixDimension attribute.
886
+ """
870
887
def __init__ (self ):
871
888
self ._mims = []
872
889
self .metadata = None
@@ -909,11 +926,9 @@ def insert(self, index, value):
909
926
self ._mims .insert (index , value )
910
927
911
928
def _to_xml_element (self ):
912
- if (len (self ) == 0 and self .metadata is None ):
913
- raise CIFTI2HeaderError (
914
- 'Matrix element requires either a MatrixIndicesMap or a Metadata element'
915
- )
916
-
929
+ # From the spec: "For each matrix dimension, exactly one
930
+ # MatrixIndicesMap element must list it in the AppliesToMatrixDimension
931
+ # attribute."
917
932
mat = xml .Element ('Matrix' )
918
933
if self .metadata :
919
934
mat .append (self .metadata ._to_xml_element ())
@@ -970,9 +985,9 @@ def __init__(self,
970
985
Parameters
971
986
----------
972
987
dataobj : object
973
- Object containing image data. It should be some object that returns
974
- an array from ``np.asanyarray``. It should have a ``shape``
975
- attribute or property.
988
+ Object containing image data. It should be some object that
989
+ returns an array from ``np.asanyarray``. It should have a
990
+ ``shape`` attribute or property.
976
991
header : Cifti2Header instance
977
992
Header with data for / from XML part of CIFTI2 format.
978
993
nifti_header : None or mapping or NIfTI2 header instance, optional
@@ -985,6 +1000,11 @@ def __init__(self,
985
1000
super (Cifti2Image , self ).__init__ (dataobj , header = header ,
986
1001
extra = extra , file_map = file_map )
987
1002
self ._nifti_header = Nifti2Header .from_header (nifti_header )
1003
+ # if NIfTI header not specified, get data type from input array
1004
+ if nifti_header is None :
1005
+ if hasattr (dataobj , 'dtype' ):
1006
+ self ._nifti_header .set_data_dtype (dataobj .dtype )
1007
+ self .update_headers ()
988
1008
989
1009
@property
990
1010
def nifti_header (self ):
@@ -1055,6 +1075,7 @@ def to_file_map(self, file_map=None):
1055
1075
None
1056
1076
"""
1057
1077
from .parse_cifti2 import Cifti2Extension
1078
+ self .update_headers ()
1058
1079
header = self ._nifti_header
1059
1080
extension = Cifti2Extension (content = self .header .to_xml ())
1060
1081
header .extensions .append (extension )
@@ -1066,6 +1087,26 @@ def to_file_map(self, file_map=None):
1066
1087
img = Nifti2Image (data , None , header )
1067
1088
img .to_file_map (file_map or self .file_map )
1068
1089
1090
+ def update_headers (self ):
1091
+ ''' Harmonize CIFTI2 and NIfTI headers with image data
1092
+
1093
+ >>> import numpy as np
1094
+ >>> data = np.zeros((2,3,4))
1095
+ >>> img = Cifti2Image(data)
1096
+ >>> img.shape == (2, 3, 4)
1097
+ True
1098
+ >>> img.update_headers()
1099
+ >>> img.nifti_header.get_data_shape() == (2, 3, 4)
1100
+ True
1101
+ '''
1102
+ self ._nifti_header .set_data_shape (self ._dataobj .shape )
1103
+
1104
+ def get_data_dtype (self ):
1105
+ return self ._nifti_header .get_data_dtype ()
1106
+
1107
+ def set_data_dtype (self , dtype ):
1108
+ self ._nifti_header .set_data_dtype (dtype )
1109
+
1069
1110
1070
1111
def load (filename ):
1071
1112
""" Load cifti2 from `filename`
0 commit comments