Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 75 additions & 15 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
BSplineFieldTransform,
DenseFieldTransform,
)
from nitransforms.tests.utils import get_points

rng = np.random.default_rng()


def test_displacements_init():
Expand Down Expand Up @@ -74,24 +77,81 @@ def test_bsplines_references(testdata_path):
)


@pytest.mark.xfail(
reason="Disable while #266 is developed.",
strict=False,
)
def test_bspline(tmp_path, testdata_path):
"""
Cross-check B-Splines and deformation field.
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
@pytest.mark.parametrize("ongrid", [True, False])
def test_densefield_map(get_testdata, image_orientation, ongrid):
"""Create a constant displacement field and compare mappings."""

nii = get_testdata[image_orientation]

# Get sampling indices
coords_xyz, points_ijk, grid_xyz, shape, ref_affine, reference, subsample = (
get_points(nii, ongrid, rng=rng)
)

coords_map = grid_xyz.reshape(*shape, 3)
deltas = np.stack(
(
np.zeros(np.prod(shape), dtype="float32").reshape(shape),
np.linspace(-80, 80, num=np.prod(shape), dtype="float32").reshape(shape),
np.linspace(-50, 50, num=np.prod(shape), dtype="float32").reshape(shape),
),
axis=-1,
)

if ongrid:
atol = 1e-3 if image_orientation == "oblique" or not ongrid else 1e-7
# Build an identity transform (deltas)
id_xfm_deltas = DenseFieldTransform(reference=reference)
np.testing.assert_array_equal(coords_map, id_xfm_deltas._field)
np.testing.assert_allclose(coords_xyz, id_xfm_deltas.map(coords_xyz))

# Build an identity transform (deformation)
id_xfm_field = DenseFieldTransform(
coords_map, is_deltas=False, reference=reference
)
np.testing.assert_array_equal(coords_map, id_xfm_field._field)
np.testing.assert_allclose(coords_xyz, id_xfm_field.map(coords_xyz), atol=atol)

This test is disabled and will be split into two separate tests.
The current implementation will be moved into test_resampling.py,
since that's what it actually tests.
# Collapse to zero transform (deltas)
zero_xfm_deltas = DenseFieldTransform(-coords_map, reference=reference)
np.testing.assert_array_equal(
np.zeros_like(zero_xfm_deltas._field), zero_xfm_deltas._field
)
np.testing.assert_allclose(
np.zeros_like(coords_xyz), zero_xfm_deltas.map(coords_xyz), atol=atol
)

In GH-266, this test will be re-implemented by testing the equivalence
of the B-Spline and deformation field transforms by calling the
transform's `map()` method on points.
# Collapse to zero transform (deformation)
zero_xfm_field = DenseFieldTransform(
np.zeros_like(deltas), is_deltas=False, reference=reference
)
np.testing.assert_array_equal(
np.zeros_like(zero_xfm_field._field), zero_xfm_field._field
)
np.testing.assert_allclose(
np.zeros_like(coords_xyz), zero_xfm_field.map(coords_xyz), atol=atol
)

# Now let's apply a transform
xfm = DenseFieldTransform(deltas, reference=reference)
np.testing.assert_array_equal(deltas, xfm._deltas)
np.testing.assert_array_equal(coords_map + deltas, xfm._field)

"""
assert True
mapped = xfm.map(coords_xyz)
nit_deltas = mapped - coords_xyz

if ongrid:
mapped_image = mapped.reshape(*shape, 3)
np.testing.assert_allclose(deltas + coords_map, mapped_image)
np.testing.assert_allclose(deltas, nit_deltas.reshape(*shape, 3), atol=1e-4)
np.testing.assert_allclose(xfm._field, mapped_image)
else:
ongrid_xyz = xfm.map(grid_xyz[subsample])
assert (
(np.linalg.norm(ongrid_xyz - mapped, axis=1) > 2).sum()
/ ongrid_xyz.shape[0]
) < 0.5


def test_map_bspline_vs_displacement(tmp_path, testdata_path):
Expand Down
27 changes: 26 additions & 1 deletion nitransforms/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from pathlib import Path
import numpy as np
import nibabel as nb

from .. import linear as nbl
from nitransforms import linear as nbl
from nitransforms.base import ImageGrid


def assert_affines_by_filename(affine1, affine2):
Expand All @@ -26,3 +28,26 @@ def assert_affines_by_filename(affine1, affine2):
xfm1 = np.loadtxt(str(affine1))
xfm2 = np.loadtxt(str(affine2))
assert np.allclose(xfm1, xfm2, atol=1e-04)


def get_points(reference_nii, ongrid, npoints=5000, rng=None):
"""Get points in RAS space."""
if rng is None:
rng = np.random.default_rng()

# Get sampling indices
shape = reference_nii.shape[:3]
ref_affine = reference_nii.affine.copy()
reference = ImageGrid(nb.Nifti1Image(np.zeros(shape), ref_affine, None))
grid_ijk = reference.ndindex
grid_xyz = reference.ras(grid_ijk)

subsample = rng.choice(grid_ijk.shape[0], npoints)
points_ijk = grid_ijk.copy() if ongrid else grid_ijk[subsample]
coords_xyz = (
grid_xyz
if ongrid
else reference.ras(points_ijk) + rng.normal(size=points_ijk.shape)
)

return coords_xyz, points_ijk, grid_xyz, shape, ref_affine, reference, subsample
Loading