Skip to content

Commit 7d1294f

Browse files
committed
TEST: Image slicing
1 parent 8749b5c commit 7d1294f

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

nibabel/tests/test_spatialimages.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,113 @@ def test_get_data(self):
411411
assert_false(rt_img.get_data() is out_data)
412412
assert_array_equal(rt_img.get_data(), in_data)
413413

414+
def test_slicer(self):
415+
img_klass = self.image_class
416+
in_data_template = np.arange(240, dtype=np.int16)
417+
base_affine = np.eye(4)
418+
t_axis = None
419+
for dshape in ((4, 5, 6, 2), # Time series
420+
(8, 5, 6)): # Volume
421+
in_data = in_data_template.copy().reshape(dshape)
422+
img = img_klass(in_data, base_affine.copy())
423+
424+
# Detect time axis on first loop (4D image)
425+
if t_axis is None:
426+
t_axis = 3 if img._spatial_dims.start == 0 else 0
427+
428+
assert_true(hasattr(img.slicer, '__getitem__'))
429+
if not img_klass.makeable:
430+
assert_raises(NotImplementedError, img.slicer[:])
431+
continue
432+
# Note spatial zooms are always first 3, even when
433+
spatial_zooms = img.header.get_zooms()[:3]
434+
435+
# Down-sample with [::2, ::2, ::2] along spatial dimensions
436+
sliceobj = [slice(None)] * len(dshape)
437+
sliceobj[img._spatial_dims] = [slice(None, None, 2)] * 3
438+
downsampled_img = img.slicer[tuple(sliceobj)]
439+
assert_array_equal(downsampled_img.header.get_zooms()[:3],
440+
np.array(spatial_zooms) * 2)
441+
442+
# Check newaxis errors
443+
if t_axis == 3:
444+
with assert_raises(IndexError):
445+
img.slicer[None]
446+
elif len(img.shape) == 4:
447+
with assert_raises(IndexError):
448+
img.slicer[None]
449+
# 4D Minc to 3D
450+
assert_equal(img.slicer[0].shape, img.shape[1:])
451+
else:
452+
# 3D Minc to 4D
453+
assert_equal(img.slicer[None].shape, (1,) + img.shape)
454+
# Axes 1 and 2 are always spatial
455+
with assert_raises(IndexError):
456+
img.slicer[:, None]
457+
with assert_raises(IndexError):
458+
img.slicer[:, :, None]
459+
if t_axis == 0:
460+
with assert_raises(IndexError):
461+
img.slicer[:, :, :, None]
462+
elif len(img.shape) == 4:
463+
# Reorder non-spatial axes
464+
assert_equal(img.slicer[:, :, :, None].shape, img.shape[:3] + (1,) + img.shape[3:])
465+
else:
466+
# 3D Analyze/NIfTI/MGH to 4D
467+
assert_equal(img.slicer[:, :, :, None].shape, img.shape + (1,))
468+
if len(img.shape) == 3:
469+
# Slices exceed dimensions
470+
with assert_raises(IndexError):
471+
img.slicer[:, :, :, :, None]
472+
else:
473+
assert_equal(img.slicer[:, :, :, :, None].shape, img.shape + (1,))
474+
475+
# Crop by one voxel in each dimension
476+
if len(img.shape) == 3 or t_axis == 3:
477+
sliced_i = img.slicer[1:]
478+
sliced_j = img.slicer[:, 1:]
479+
sliced_k = img.slicer[:, :, 1:]
480+
sliced_ijk = img.slicer[1:, 1:, 1:]
481+
else:
482+
# 4D Minc
483+
sliced_i = img.slicer[:, 1:]
484+
sliced_j = img.slicer[:, :, 1:]
485+
sliced_k = img.slicer[:, :, :, 1:]
486+
sliced_ijk = img.slicer[:, 1:, 1:, 1:]
487+
488+
# No scaling change
489+
assert_array_equal(sliced_i.affine[:3, :3], img.affine[:3, :3])
490+
assert_array_equal(sliced_j.affine[:3, :3], img.affine[:3, :3])
491+
assert_array_equal(sliced_k.affine[:3, :3], img.affine[:3, :3])
492+
assert_array_equal(sliced_ijk.affine[:3, :3], img.affine[:3, :3])
493+
# Translation
494+
assert_array_equal(sliced_i.affine[:, 3], [1, 0, 0, 1])
495+
assert_array_equal(sliced_j.affine[:, 3], [0, 1, 0, 1])
496+
assert_array_equal(sliced_k.affine[:, 3], [0, 0, 1, 1])
497+
assert_array_equal(sliced_ijk.affine[:, 3], [1, 1, 1, 1])
498+
499+
# No change to affines with upper-bound slices
500+
assert_array_equal(img.slicer[:1, :1, :1].affine, img.affine)
501+
502+
# Check data is consistent with slicing numpy arrays
503+
slice_elems = (None, Ellipsis, 0, 1, -1, slice(None), slice(1),
504+
slice(-1), slice(1, -1))
505+
for n_elems in range(6):
506+
for _ in range(10):
507+
sliceobj = tuple(
508+
np.random.choice(slice_elems, n_elems).tolist())
509+
try:
510+
sliced_img = img.slicer[sliceobj]
511+
except IndexError:
512+
# Only checking valid slices
513+
pass
514+
else:
515+
sliced_data = in_data[sliceobj]
516+
assert_array_equal(sliced_data, sliced_img.get_data())
517+
assert_array_equal(sliced_data, sliced_img.dataobj)
518+
assert_array_equal(sliced_data, img.dataobj[sliceobj])
519+
assert_array_equal(sliced_data, img.get_data()[sliceobj])
520+
414521
def test_api_deprecations(self):
415522

416523
class FakeImage(self.image_class):

0 commit comments

Comments
 (0)