Skip to content

Commit 9bc71b6

Browse files
committed
FIX/TEST: MGHImages must be 3D or 4D
1 parent 31a9a9f commit 9bc71b6

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

nibabel/freesurfer/mghformat.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ..volumeutils import (array_to_file, array_from_file, Recoder)
1818
from ..spatialimages import HeaderDataError, SpatialImage
1919
from ..fileholders import FileHolder
20-
from ..arrayproxy import ArrayProxy
20+
from ..arrayproxy import ArrayProxy, reshape_dataobj
2121
from ..keywordonly import kw_only_meth
2222
from ..openers import ImageOpener
2323
from ..wrapstruct import LabeledWrapStruct
@@ -390,6 +390,14 @@ class MGHImage(SpatialImage):
390390

391391
ImageArrayProxy = ArrayProxy
392392

393+
def __init__(self, dataobj, affine, header=None,
394+
extra=None, file_map=None):
395+
shape = dataobj.shape
396+
if len(shape) < 3:
397+
dataobj = reshape_dataobj(dataobj, shape + (1,) * (3 - len(shape)))
398+
super(MGHImage, self).__init__(dataobj, affine, header=header,
399+
extra=extra, file_map=file_map)
400+
393401
@classmethod
394402
def filespec_to_file_map(klass, filespec):
395403
""" Check for compressed .mgz format, then .mgh format """

nibabel/spatialimages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def __init__(self, dataobj, affine, header=None,
354354
'''
355355
super(SpatialImage, self).__init__(dataobj, header=header, extra=extra,
356356
file_map=file_map)
357-
if not affine is None:
357+
if affine is not None:
358358
# Check that affine is array-like 4,4. Maybe this is too strict at
359359
# this abstract level, but so far I think all image formats we know
360360
# do need 4,4.

nibabel/tests/test_spatialimages.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class DataLike(object):
195195
shape = (3,)
196196

197197
def __array__(self):
198-
return np.arange(3)
198+
return np.arange(3, dtype=np.int16)
199199

200200

201201
class TestSpatialImage(TestCase):
@@ -249,8 +249,11 @@ def test_default_header(self):
249249
def test_data_api(self):
250250
# Test minimal api data object can initialize
251251
img = self.image_class(DataLike(), None)
252-
assert_array_equal(img.get_data(), np.arange(3))
253-
assert_equal(img.shape, (3,))
252+
# Shape may be promoted to higher dimension, but may not reorder or
253+
# change size
254+
assert_array_equal(img.get_data().flatten(), np.arange(3))
255+
assert_equal(img.get_shape()[:1], (3,))
256+
assert_equal(np.prod(img.get_shape()), 3)
254257

255258
def check_dtypes(self, expected, actual):
256259
# Some images will want dtypes to be equal including endianness,
@@ -278,7 +281,10 @@ def test_data_shape(self):
278281
# See https://github.com/nipy/nibabel/issues/58
279282
arr = np.arange(4, dtype=np.int16)
280283
img = img_klass(arr, np.eye(4))
281-
assert_equal(img.shape, (4,))
284+
# Shape may be promoted to higher dimension, but may not reorder or
285+
# change size
286+
assert_equal(img.get_shape()[:1], (4,))
287+
assert_equal(np.prod(img.get_shape()), 4)
282288
img = img_klass(np.zeros((2, 3, 4), dtype=np.float32), np.eye(4))
283289
assert_equal(img.shape, (2, 3, 4))
284290

@@ -290,7 +296,10 @@ def test_str(self):
290296
arr = np.arange(5, dtype=np.int16)
291297
img = img_klass(arr, np.eye(4))
292298
assert_true(len(str(img)) > 0)
293-
assert_equal(img.shape, (5,))
299+
# Shape may be promoted to higher dimension, but may not reorder or
300+
# change size
301+
assert_equal(img.shape[:1], (5,))
302+
assert_equal(np.prod(img.shape), 5)
294303
img = img_klass(np.zeros((2, 3, 4), dtype=np.int16), np.eye(4))
295304
assert_true(len(str(img)) > 0)
296305

@@ -302,7 +311,10 @@ def test_get_shape(self):
302311
# See https://github.com/nipy/nibabel/issues/58
303312
img = img_klass(np.arange(1, dtype=np.int16), np.eye(4))
304313
with suppress_warnings():
305-
assert_equal(img.get_shape(), (1,))
314+
# Shape may be promoted to higher dimension, but may not reorder or
315+
# change size
316+
assert_equal(img.get_shape()[:1], (1,))
317+
assert_equal(np.prod(img.get_shape()), 1)
306318
img = img_klass(np.zeros((2, 3, 4), np.int16), np.eye(4))
307319
assert_equal(img.get_shape(), (2, 3, 4))
308320

0 commit comments

Comments
 (0)