diff --git a/nitransforms/interp/bspline.py b/nitransforms/interp/bspline.py index d8590bcc..399a506d 100644 --- a/nitransforms/interp/bspline.py +++ b/nitransforms/interp/bspline.py @@ -7,24 +7,28 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Interpolate with 3D tensor-product B-Spline basis.""" + import numpy as np import nibabel as nb from scipy.sparse import csr_matrix, kron def _cubic_bspline(d, order=3): - """Evaluate the cubic bspline at distance d from the center.""" + """Evaluate the cubic B-spline at distance ``d`` from the center.""" + if order != 3: raise NotImplementedError - return np.piecewise( - d, - [d < 1.0, d >= 1.0], - [ - lambda d: (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3) / 6.0, - lambda d: (2.0 - d) ** 3 / 6.0, - ], - ) + d = np.abs(d) + out = np.zeros_like(d, dtype="float32") + + mask1 = d < 1.0 + mask2 = (d >= 1.0) & (d < 2.0) + + out[mask1] = (4.0 - 6.0 * d[mask1] ** 2 + 3.0 * d[mask1] ** 3) / 6.0 + out[mask2] = (2.0 - d[mask2]) ** 3 / 6.0 + + return out def grid_bspline_weights(target_grid, ctrl_grid): diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 0869b9af..24e043c2 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -453,10 +453,10 @@ def map(self, x, inverse=False): >>> xfm = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz") >>> xfm.reference = test_dir / "someones_anatomy.nii.gz" >>> xfm.map([-6.5, -36., -19.5]).tolist() # doctest: +ELLIPSIS - [[-6.5, -31.476097418406..., -19.5]] + [[-6.5, -36.475114..., -19.5]] >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist() # doctest: +ELLIPSIS - [[-6.5, -31.4760974184..., -19.5], [-1.0, -3.807267537712..., -11.25]] + [[-6.5, -36.475114..., -19.5], [-1.0, -42.03878957..., -11.25]] """ vfunc = partial( @@ -499,18 +499,23 @@ def _map_xyz(x, reference, knots, coeffs): # Calculate the index coordinates of the point in the B-Spline grid ijk = (knots.inverse @ _as_homogeneous(x).squeeze())[:ndim] - # Determine the window within distance 2.0 (where the B-Spline is nonzero) + # Determine the window within distance 2.0 (where the B-Spline is nonzero). # Probably this will change if the order of the B-Spline is different w_start, w_end = np.ceil(ijk - 2).astype(int), np.floor(ijk + 2).astype(int) - # Generate a grid of indexes corresponding to the window - nonzero_knots = tuple( - [np.arange(start, end + 1) for start, end in zip(w_start, w_end)] - ) + + # Generate a grid of indexes corresponding to the window, clipped to the + # coefficient grid boundaries + nonzero_knots = [] + for start, end, size in zip(w_start, w_end, knots.shape): + start = max(start, 0) + end = min(end, size - 1) + nonzero_knots.append(np.arange(start, end + 1)) nonzero_knots = tuple(np.meshgrid(*nonzero_knots, indexing="ij")) window = np.array(nonzero_knots).reshape((ndim, -1)) - # Calculate the distance of the location w.r.t. to all voxels in window - distance = window.T - ijk + # Calculate the absolute distance of the location w.r.t. all voxels in + # the window. Distances are expressed in knot-grid voxel units + distance = np.abs(window.T - ijk) # Since this is a grid, distance only takes a few float values unique_d, indices = np.unique(distance.reshape(-1), return_inverse=True) # Calculate the B-Spline weight corresponding to the distance. diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index c879704f..936a62f6 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -190,3 +190,56 @@ def test_densefield_oob_resampling(is_deltas): assert np.allclose(mapped[0], points[0]) assert np.allclose(mapped[2], points[2]) assert np.allclose(mapped[1], points[1] + 1) + + +def test_bspline_map_gridpoints(): + """BSpline mapping matches dense field on grid points.""" + ref = nb.Nifti1Image(np.zeros((5, 5, 5), dtype="uint8"), np.eye(4)) + coeff = nb.Nifti1Image( + np.random.RandomState(0).rand(9, 9, 9, 3).astype("float32"), np.eye(4) + ) + + bspline = BSplineFieldTransform(coeff, reference=ref) + dense = bspline.to_field() + + # Use a couple of voxel centers from the reference grid + ijk = np.array([[1, 1, 1], [2, 3, 0]]) + pts = nb.affines.apply_affine(ref.affine, ijk) + + assert np.allclose(bspline.map(pts), dense.map(pts), atol=1e-6) + + +def test_bspline_map_manual(): + """BSpline interpolation agrees with manual computation.""" + ref = nb.Nifti1Image(np.zeros((5, 5, 5), dtype="uint8"), np.eye(4)) + rng = np.random.RandomState(0) + coeff = nb.Nifti1Image(rng.rand(9, 9, 9, 3).astype("float32"), np.eye(4)) + + bspline = BSplineFieldTransform(coeff, reference=ref) + + from nitransforms.base import _as_homogeneous + from nitransforms.interp.bspline import _cubic_bspline + + def manual_map(x): + ijk = (bspline._knots.inverse @ _as_homogeneous(x).squeeze())[:3] + w_start = np.floor(ijk).astype(int) - 1 + w_end = w_start + 3 + w_start = np.maximum(w_start, 0) + w_end = np.minimum(w_end, np.array(bspline._coeffs.shape[:3]) - 1) + + window = [] + for i in range(w_start[0], w_end[0] + 1): + for j in range(w_start[1], w_end[1] + 1): + for k in range(w_start[2], w_end[2] + 1): + window.append([i, j, k]) + window = np.array(window) + + dist = np.abs(window - ijk) + weights = _cubic_bspline(dist).prod(1) + coeffs = bspline._coeffs[window[:, 0], window[:, 1], window[:, 2]] + + return x + coeffs.T @ weights + + pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]]) + expected = np.vstack([manual_map(p) for p in pts]) + assert np.allclose(bspline.map(pts), expected, atol=1e-6)