diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index 8510d993..0a208a96 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -13,6 +13,9 @@ BSplineFieldTransform, DenseFieldTransform, ) +from nitransforms.tests.utils import get_points + +rng = np.random.default_rng() def test_displacements_init(): @@ -74,24 +77,81 @@ def test_bsplines_references(testdata_path): ) -@pytest.mark.xfail( - reason="Disable while #266 is developed.", - strict=False, -) -def test_bspline(tmp_path, testdata_path): - """ - Cross-check B-Splines and deformation field. +@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"]) +@pytest.mark.parametrize("ongrid", [True, False]) +def test_densefield_map(get_testdata, image_orientation, ongrid): + """Create a constant displacement field and compare mappings.""" + + nii = get_testdata[image_orientation] + + # Get sampling indices + coords_xyz, points_ijk, grid_xyz, shape, ref_affine, reference, subsample = ( + get_points(nii, ongrid, rng=rng) + ) + + coords_map = grid_xyz.reshape(*shape, 3) + deltas = np.stack( + ( + np.zeros(np.prod(shape), dtype="float32").reshape(shape), + np.linspace(-80, 80, num=np.prod(shape), dtype="float32").reshape(shape), + np.linspace(-50, 50, num=np.prod(shape), dtype="float32").reshape(shape), + ), + axis=-1, + ) + + if ongrid: + atol = 1e-3 if image_orientation == "oblique" or not ongrid else 1e-7 + # Build an identity transform (deltas) + id_xfm_deltas = DenseFieldTransform(reference=reference) + np.testing.assert_array_equal(coords_map, id_xfm_deltas._field) + np.testing.assert_allclose(coords_xyz, id_xfm_deltas.map(coords_xyz)) + + # Build an identity transform (deformation) + id_xfm_field = DenseFieldTransform( + coords_map, is_deltas=False, reference=reference + ) + np.testing.assert_array_equal(coords_map, id_xfm_field._field) + np.testing.assert_allclose(coords_xyz, id_xfm_field.map(coords_xyz), atol=atol) - This test is disabled and will be split into two separate tests. - The current implementation will be moved into test_resampling.py, - since that's what it actually tests. + # Collapse to zero transform (deltas) + zero_xfm_deltas = DenseFieldTransform(-coords_map, reference=reference) + np.testing.assert_array_equal( + np.zeros_like(zero_xfm_deltas._field), zero_xfm_deltas._field + ) + np.testing.assert_allclose( + np.zeros_like(coords_xyz), zero_xfm_deltas.map(coords_xyz), atol=atol + ) - In GH-266, this test will be re-implemented by testing the equivalence - of the B-Spline and deformation field transforms by calling the - transform's `map()` method on points. + # Collapse to zero transform (deformation) + zero_xfm_field = DenseFieldTransform( + np.zeros_like(deltas), is_deltas=False, reference=reference + ) + np.testing.assert_array_equal( + np.zeros_like(zero_xfm_field._field), zero_xfm_field._field + ) + np.testing.assert_allclose( + np.zeros_like(coords_xyz), zero_xfm_field.map(coords_xyz), atol=atol + ) + + # Now let's apply a transform + xfm = DenseFieldTransform(deltas, reference=reference) + np.testing.assert_array_equal(deltas, xfm._deltas) + np.testing.assert_array_equal(coords_map + deltas, xfm._field) - """ - assert True + mapped = xfm.map(coords_xyz) + nit_deltas = mapped - coords_xyz + + if ongrid: + mapped_image = mapped.reshape(*shape, 3) + np.testing.assert_allclose(deltas + coords_map, mapped_image) + np.testing.assert_allclose(deltas, nit_deltas.reshape(*shape, 3), atol=1e-4) + np.testing.assert_allclose(xfm._field, mapped_image) + else: + ongrid_xyz = xfm.map(grid_xyz[subsample]) + assert ( + (np.linalg.norm(ongrid_xyz - mapped, axis=1) > 2).sum() + / ongrid_xyz.shape[0] + ) < 0.5 def test_map_bspline_vs_displacement(tmp_path, testdata_path): diff --git a/nitransforms/tests/utils.py b/nitransforms/tests/utils.py index e653113e..e3e8e4d9 100644 --- a/nitransforms/tests/utils.py +++ b/nitransforms/tests/utils.py @@ -2,8 +2,10 @@ from pathlib import Path import numpy as np +import nibabel as nb -from .. import linear as nbl +from nitransforms import linear as nbl +from nitransforms.base import ImageGrid def assert_affines_by_filename(affine1, affine2): @@ -26,3 +28,26 @@ def assert_affines_by_filename(affine1, affine2): xfm1 = np.loadtxt(str(affine1)) xfm2 = np.loadtxt(str(affine2)) assert np.allclose(xfm1, xfm2, atol=1e-04) + + +def get_points(reference_nii, ongrid, npoints=5000, rng=None): + """Get points in RAS space.""" + if rng is None: + rng = np.random.default_rng() + + # Get sampling indices + shape = reference_nii.shape[:3] + ref_affine = reference_nii.affine.copy() + reference = ImageGrid(nb.Nifti1Image(np.zeros(shape), ref_affine, None)) + grid_ijk = reference.ndindex + grid_xyz = reference.ras(grid_ijk) + + subsample = rng.choice(grid_ijk.shape[0], npoints) + points_ijk = grid_ijk.copy() if ongrid else grid_ijk[subsample] + coords_xyz = ( + grid_xyz + if ongrid + else reference.ras(points_ijk) + rng.normal(size=points_ijk.shape) + ) + + return coords_xyz, points_ijk, grid_xyz, shape, ref_affine, reference, subsample