Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions nitransforms/interp/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,28 @@
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Interpolate with 3D tensor-product B-Spline basis."""

import numpy as np
import nibabel as nb
from scipy.sparse import csr_matrix, kron


def _cubic_bspline(d, order=3):
"""Evaluate the cubic bspline at distance d from the center."""
"""Evaluate the cubic B-spline at distance ``d`` from the center."""

if order != 3:
raise NotImplementedError

return np.piecewise(
d,
[d < 1.0, d >= 1.0],
[
lambda d: (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3) / 6.0,
lambda d: (2.0 - d) ** 3 / 6.0,
],
)
d = np.abs(d)
out = np.zeros_like(d, dtype="float32")

mask1 = d < 1.0
mask2 = (d >= 1.0) & (d < 2.0)

out[mask1] = (4.0 - 6.0 * d[mask1] ** 2 + 3.0 * d[mask1] ** 3) / 6.0
out[mask2] = (2.0 - d[mask2]) ** 3 / 6.0

return out


def grid_bspline_weights(target_grid, ctrl_grid):
Expand Down
23 changes: 14 additions & 9 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,10 @@ def map(self, x, inverse=False):
>>> xfm = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz")
>>> xfm.reference = test_dir / "someones_anatomy.nii.gz"
>>> xfm.map([-6.5, -36., -19.5]).tolist() # doctest: +ELLIPSIS
[[-6.5, -31.476097418406..., -19.5]]
[[-6.5, -36.475114..., -19.5]]

>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist() # doctest: +ELLIPSIS
[[-6.5, -31.4760974184..., -19.5], [-1.0, -3.807267537712..., -11.25]]
[[-6.5, -36.475114..., -19.5], [-1.0, -42.03878957..., -11.25]]

"""
vfunc = partial(
Expand Down Expand Up @@ -499,18 +499,23 @@ def _map_xyz(x, reference, knots, coeffs):
# Calculate the index coordinates of the point in the B-Spline grid
ijk = (knots.inverse @ _as_homogeneous(x).squeeze())[:ndim]

# Determine the window within distance 2.0 (where the B-Spline is nonzero)
# Determine the window within distance 2.0 (where the B-Spline is nonzero).
# Probably this will change if the order of the B-Spline is different
w_start, w_end = np.ceil(ijk - 2).astype(int), np.floor(ijk + 2).astype(int)
# Generate a grid of indexes corresponding to the window
nonzero_knots = tuple(
[np.arange(start, end + 1) for start, end in zip(w_start, w_end)]
)

# Generate a grid of indexes corresponding to the window, clipped to the
# coefficient grid boundaries
nonzero_knots = []
for start, end, size in zip(w_start, w_end, knots.shape):
start = max(start, 0)
end = min(end, size - 1)
nonzero_knots.append(np.arange(start, end + 1))
nonzero_knots = tuple(np.meshgrid(*nonzero_knots, indexing="ij"))
window = np.array(nonzero_knots).reshape((ndim, -1))

# Calculate the distance of the location w.r.t. to all voxels in window
distance = window.T - ijk
# Calculate the absolute distance of the location w.r.t. all voxels in
# the window. Distances are expressed in knot-grid voxel units
distance = np.abs(window.T - ijk)
# Since this is a grid, distance only takes a few float values
unique_d, indices = np.unique(distance.reshape(-1), return_inverse=True)
# Calculate the B-Spline weight corresponding to the distance.
Expand Down
53 changes: 53 additions & 0 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,56 @@ def test_densefield_oob_resampling(is_deltas):
assert np.allclose(mapped[0], points[0])
assert np.allclose(mapped[2], points[2])
assert np.allclose(mapped[1], points[1] + 1)


def test_bspline_map_gridpoints():
"""BSpline mapping matches dense field on grid points."""
ref = nb.Nifti1Image(np.zeros((5, 5, 5), dtype="uint8"), np.eye(4))
coeff = nb.Nifti1Image(
np.random.RandomState(0).rand(9, 9, 9, 3).astype("float32"), np.eye(4)
)

bspline = BSplineFieldTransform(coeff, reference=ref)
dense = bspline.to_field()

# Use a couple of voxel centers from the reference grid
ijk = np.array([[1, 1, 1], [2, 3, 0]])
pts = nb.affines.apply_affine(ref.affine, ijk)

assert np.allclose(bspline.map(pts), dense.map(pts), atol=1e-6)


def test_bspline_map_manual():
"""BSpline interpolation agrees with manual computation."""
ref = nb.Nifti1Image(np.zeros((5, 5, 5), dtype="uint8"), np.eye(4))
rng = np.random.RandomState(0)
coeff = nb.Nifti1Image(rng.rand(9, 9, 9, 3).astype("float32"), np.eye(4))

bspline = BSplineFieldTransform(coeff, reference=ref)

from nitransforms.base import _as_homogeneous
from nitransforms.interp.bspline import _cubic_bspline

def manual_map(x):
ijk = (bspline._knots.inverse @ _as_homogeneous(x).squeeze())[:3]
w_start = np.floor(ijk).astype(int) - 1
w_end = w_start + 3
w_start = np.maximum(w_start, 0)
w_end = np.minimum(w_end, np.array(bspline._coeffs.shape[:3]) - 1)

window = []
for i in range(w_start[0], w_end[0] + 1):
for j in range(w_start[1], w_end[1] + 1):
for k in range(w_start[2], w_end[2] + 1):
window.append([i, j, k])
window = np.array(window)

dist = np.abs(window - ijk)
weights = _cubic_bspline(dist).prod(1)
coeffs = bspline._coeffs[window[:, 0], window[:, 1], window[:, 2]]

return x + coeffs.T @ weights

pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]])
expected = np.vstack([manual_map(p) for p in pts])
assert np.allclose(bspline.map(pts), expected, atol=1e-6)
Loading