diff --git a/napari_nibabel/_tests/test_nibabel.py b/napari_nibabel/_tests/test_nibabel.py index 688c724..30ba190 100644 --- a/napari_nibabel/_tests/test_nibabel.py +++ b/napari_nibabel/_tests/test_nibabel.py @@ -1,5 +1,7 @@ +import atexit import os import shutil +import tempfile import nibabel as nib import numpy as np @@ -15,8 +17,11 @@ def test_reader(tmp_path): # write some fake data in NIFTI-1 format my_test_file = str(tmp_path / "myfile.nii") - original_data = np.random.rand(20, 20) - nii = nib.Nifti1Image(original_data, affine=np.eye(4)) + original_data = np.random.rand(20, 20, 1) + + # Set affine to an LPS affine here so internal reorientation will not be + # needed. + nii = nib.Nifti1Image(original_data, affine=np.diag((-1, -1, 1, 1))) nii.to_filename(my_test_file) np.save(my_test_file, original_data) @@ -168,3 +173,38 @@ def test_analyze_hdr_only(): filename = os.path.join(data_path, 'analyze.hdr') with pytest.raises(FileNotFoundError): _test_basic_read(filename) + + +def test_read_filelist(): + filename = os.path.join(data_path, 'example4d.nii.gz') + n_files = 3 + data = _test_basic_read([filename,] * n_files) + assert data.ndim == 5 + assert data.shape[0] == n_files + + +def test_read_filelist_mismatched_shape(): + # cannot stack multiple files when the shapes are different + filename = os.path.join(data_path, 'example_nifti2.nii.gz') + filename2 = os.path.join(data_path, 'example4d.nii.gz') + with pytest.raises(ValueError): + _test_basic_read([filename, filename2]) + + +def test_read_filelist_mismatched_affine(): + # cannot stack multiple files when the affines are different + tmp_dir = tempfile.mkdtemp() + atexit.register(shutil.rmtree, tmp_dir) + + filename = os.path.join(data_path, 'anatomical.nii') + nii1 = nib.load(filename) + data = nii1.get_fdata() + affine2 = nii1.affine.copy() + affine2[0, 0] *= 2 + affine2[1, 1] *= -1 + nii2 = nib.Nifti1Image(data, affine=affine2, header=nii1.header) + filename2 = os.path.join(tmp_dir, 
'anatomical_affine2.nii') + nii2.to_filename(filename2) + + with pytest.raises(ValueError): + _test_basic_read([filename, filename2]) diff --git a/napari_nibabel/nibabel.py b/napari_nibabel/nibabel.py index 10dd78d..f944a0a 100644 --- a/napari_nibabel/nibabel.py +++ b/napari_nibabel/nibabel.py @@ -17,12 +17,60 @@ from napari_plugin_engine import napari_hook_implementation +from nibabel import orientations from nibabel.imageclasses import all_image_classes from nibabel.filename_parser import splitext_addext +valid_volume_exts = {klass.valid_exts for klass in all_image_classes} +valid_volume_exts = set(functools.reduce(operator.add, valid_volume_exts)) + + +def get_transform_ornt(affine, target=('L', 'P', 'S')): + current_ornt = orientations.io_orientation(affine) + target_ornt = orientations.axcodes2ornt(target) + return orientations.ornt_transform(current_ornt, target_ornt) + + +def adjust_translation(affine, affine_plumb, data_shape): + """Adjust translation vector of affine_plumb. + + The goal is to have affine_plumb result in the same data center + point in world coordinates as the original affine. + + Parameters + ---------- + affine : ndarray + The shape (4, 4) affine matrix read in by nibabel. + affine_plumb: ndarray + The affine after permutation to LPS space followed by discarding + of any rotation/shear elements. + data_shape : tuple of int + The shape of the data array + + Returns + ------- + affine_plumb : ndarray + A copy of affine_plumb with the 3 translation elements updated. + """ + data_shape = data_shape[-3:] + if len(data_shape) < 3: + # TODO: prepend or append? 
+ data_shape = data_shape + (1,) * (3 - len(data_shape)) + + # get center in world coordinates for the original affine + center_ijk = (np.array(data_shape) - 1) / 2 + center_world = np.dot(affine[:3, :3], center_ijk) + affine[:3, 3] + + # make a copy to avoid in-place modification of affine_plumb + affine_plumb = affine_plumb.copy() + + # center in world coordinates with the current affine_plumb + center_world_plumb = np.dot(affine_plumb[:3, :3], center_ijk) + + # adjust the translation elements + affine_plumb[:3, 3] = center_world - center_world_plumb + return affine_plumb -all_valid_exts = {klass.valid_exts for klass in all_image_classes} -all_valid_exts = set(functools.reduce(operator.add, all_valid_exts)) @napari_hook_implementation def napari_get_reader(path): @@ -48,7 +96,7 @@ def napari_get_reader(path): froot, ext, addext = splitext_addext(path) # if we know we cannot read the file, we immediately return None. - if not ext.lower() in valid_volume_exts: return None # otherwise we return the *function* that can read ``path``. 
@@ -82,22 +130,33 @@ def reader_function(path): paths = [path] if isinstance(path, str) else path n_spatial = 3 + # note: we don't squeeze the data below, so 2D data will be 3D with 1 slice if len(paths) > 1: # load all files into a single array objects = [nib.load(_path) for _path in paths] - header = objects[0].header affine = objects[0].affine - if not all([_obj.shape == _obj[0].shape for _obj in objects]): + header = objects[0].header + if not all([_obj.shape == objects[0].shape for _obj in objects]): raise ValueError( "all selected files must contain data of the same shape") - + if not all(np.allclose(affine, _obj.affine) for _obj in objects): + raise ValueError( + "all selected files must share a common affine") + # reorient volumes to the desired orientation + transform_ornt = get_transform_ornt(affine, target=('L', 'P', 'S')) + objects = [_obj.as_reoriented(transform_ornt) for _obj in objects] arrays = [_obj.get_fdata() for _obj in objects] + affine = objects[0].affine + header = objects[0].header # stack arrays into single array data = np.stack(arrays) else: img = nib.load(paths[0]) + # reorient volume to the desired orientation + transform_ornt = get_transform_ornt(img.affine, target=('L', 'P', 'S')) + img = img.as_reoriented(transform_ornt) header = img.header affine = img.affine data = img.get_fdata() # keep this as dataobj or use get_fdata()? 
@@ -114,43 +173,32 @@ def reader_function(path): if spatial_axis_order != (0, 1, 2): data = data.transpose(spatial_axis_order[:data.ndim]) - try: - # only get zooms for the spatial axes - zooms = np.asarray(header.get_zooms())[:n_spatial] - if np.any(zooms == 0): - raise ValueError("invalid zoom = 0 found in header") - # normalize so values are all >= 1.0 (not strictly necessary) - # zooms = zooms / zooms.min() - zooms = tuple(zooms) - if data.ndim > 3: - zooms = (1.0, ) * (data.ndim - n_spatial) + zooms - except (AttributeError, ValueError): - zooms = (1.0, ) * data.ndim - - apply_translation = False - if apply_translation: - translate = tuple(affine[:n_spatial, 3]) - if data.ndim > 3: - # set translate = 0.0 on non-spatial dimensions - translate = (0.0,) * (data.ndim - n_spatial) + translate + if np.all(affine[:3, :3] == (np.eye(3) * affine[:3, :3])): + # no rotation or shear components + affine_plumb = affine else: - translate = (0.0,) * data.ndim + # Set any remaining non-diagonal elements of the affine to 0 + # (napari currently cannot display with rotate/shear) + affine_plumb = np.diag(np.diag(affine)) + + # Set translation elements of affine_plumb to get the center of the + # data cube in the same position in world coordinates + affine_plumb = adjust_translation(affine, affine_plumb, data.shape) + + # Note: The translate, scale, rotate, shear kwargs correspond to the + # 'data2physical' component of a composite affine transform. + # https://github.com/napari/napari/blob/v0.4.11/napari/layers/base/base.py#L254-L268 #noqa + # However, the affine kwarg corresponds instead to the 'physical2world' + # affine. Here, we will extract the scale and translate components from + # affine_plumb so that we are specifying 'data2physical' to napari. 
- # optional kwargs for the corresponding viewer.add_* method - # https://napari.org/docs/api/napari.components.html#module-napari.components.add_layers_mixin - # see also: https://napari.org/tutorials/fundamentals/image add_kwargs = dict( metadata=dict(affine=affine, header=header), rgb=False, - scale=zooms, - translate=translate, - # contrast_limits=, + scale=np.diag(affine_plumb[:3, :3]), + translate=affine_plumb[:3, 3], + affine=None, + channel_axis=None, ) - # TODO: potential kwargs to set for viewer.add_image - # contrast_limits kwarg based on info in image header? - # e.g. for NIFTI: nii.header._structarr['cal_min'] - # nii.header._structarr['cal_max'] - - layer_type = "image" # optional, default is "image" - return [(data, add_kwargs, layer_type)] + return [(data, add_kwargs, "image")]