Skip to content

Commit a92bc6f

Browse files
committed
RF: Calculate b-spline design matrix ourselves
1 parent c05f643 commit a92bc6f

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

sdcflows/transform.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
import numpy as np
2828
from warnings import warn
2929
from scipy import ndimage as ndi
30-
from scipy.interpolate import BSpline
31-
from scipy.sparse import vstack as sparse_vstack, kron
30+
from scipy.signal import cubic
31+
from scipy.sparse import vstack as sparse_vstack, kron, csr_matrix, lil_matrix
3232

3333
import nibabel as nb
3434
import nitransforms as nt
@@ -405,18 +405,18 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
405405
coords[axis] = np.arange(sample_shape[axis], dtype=dtype)
406406

407407
# Calculate the index component of samples w.r.t. B-Spline knots along current axis
408-
x = nb.affines.apply_affine(target_to_grid, coords.T)[:, axis]
409-
pad_left = max(int(-np.rint(x.min())), 0)
410-
pad_right = max(int(np.rint(x.max()) - knots_shape[axis]), 0)
411-
412-
# BSpline.design_matrix requires all x be within -4 and 4 padding
413-
# This padding results from the B-Spline degree (3) plus one
414-
t = np.arange(-4 - pad_left, knots_shape[axis] + 4 + pad_right, dtype=dtype)
415-
416-
# Calculate K x N collocation matrix (discarding extra padding)
417-
colloc_ax = BSpline.design_matrix(x, t, 3)[:, (2 + pad_left):-(2 + pad_right)]
418-
# Design matrix returns K x N and we want N x K
419-
wd.append(colloc_ax.T.tocsr())
408+
locs = nb.affines.apply_affine(target_to_grid, coords.T)[:, axis]
409+
knots = np.arange(knots_shape[axis], dtype=dtype)
410+
411+
distance = np.abs(locs[np.newaxis, ...] - knots[..., np.newaxis])
412+
within_support = distance < 2.0
413+
d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True)
414+
bs_w = cubic(d_vals)
415+
416+
colloc_ax = lil_matrix((knots_shape[axis], sample_shape[axis]), dtype=dtype)
417+
colloc_ax[within_support] = bs_w[d_idxs]
418+
419+
wd.append(csr_matrix(colloc_ax))
420420

421421
# Calculate the tensor product of the three design matrices
422422
return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)

0 commit comments

Comments
 (0)