|
27 | 27 | import numpy as np
|
28 | 28 | from warnings import warn
|
29 | 29 | from scipy import ndimage as ndi
|
30 |
| -from scipy.interpolate import BSpline |
31 |
| -from scipy.sparse import vstack as sparse_vstack, kron |
| 30 | +from scipy.signal import cubic |
| 31 | +from scipy.sparse import vstack as sparse_vstack, kron, lil_array |
32 | 32 |
|
33 | 33 | import nibabel as nb
|
34 | 34 | import nitransforms as nt
|
@@ -405,18 +405,18 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
|
405 | 405 | coords[axis] = np.arange(sample_shape[axis], dtype=dtype)
|
406 | 406 |
|
407 | 407 | # Calculate the index component of samples w.r.t. B-Spline knots along current axis
|
408 |
| - x = nb.affines.apply_affine(target_to_grid, coords.T)[:, axis] |
409 |
| - pad_left = max(int(-np.rint(x.min())), 0) |
410 |
| - pad_right = max(int(np.rint(x.max()) - knots_shape[axis]), 0) |
411 |
| - |
412 |
| - # BSpline.design_matrix requires all x be within -4 and 4 padding |
413 |
| - # This padding results from the B-Spline degree (3) plus one |
414 |
| - t = np.arange(-4 - pad_left, knots_shape[axis] + 4 + pad_right, dtype=dtype) |
415 |
| - |
416 |
| - # Calculate K x N collocation matrix (discarding extra padding) |
417 |
| - colloc_ax = BSpline.design_matrix(x, t, 3)[:, (2 + pad_left):-(2 + pad_right)] |
418 |
| - # Design matrix returns K x N and we want N x K |
419 |
| - wd.append(colloc_ax.T.tocsr()) |
| 408 | + locs = nb.affines.apply_affine(target_to_grid, coords.T)[:, axis] |
| 409 | + knots = np.arange(knots_shape[axis], dtype=dtype) |
| 410 | + |
| 411 | + distance = np.abs(locs[np.newaxis, ...] - knots[..., np.newaxis]) |
| 412 | + within_support = distance < 2.0 |
| 413 | + d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True) |
| 414 | + bs_w = cubic(d_vals) |
| 415 | + |
| 416 | + colloc_ax = lil_array((knots_shape[axis], sample_shape[axis]), dtype=dtype) |
| 417 | + colloc_ax[within_support] = bs_w[d_idxs] |
| 418 | + |
| 419 | + wd.append(colloc_ax) |
420 | 420 |
|
421 | 421 | # Calculate the tensor product of the three design matrices
|
422 | 422 | return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)
|
|
0 commit comments