Skip to content

Commit 71a9ae3

Browse files
authored
Merge pull request #393 from effigies/rf/scipy-bspline-take-3
RF: Use scipy.interpolate.BSpline to construct spline basis
2 parents c90c4ed + 1d715a9 commit 71a9ae3

File tree

3 files changed

+33
-31
lines changed

3 files changed

+33
-31
lines changed

sdcflows/interfaces/bspline.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class BSplineApprox(SimpleInterface):
130130

131131
def _run_interface(self, runtime):
132132
from sklearn import linear_model as lm
133-
from scipy.sparse import vstack as sparse_vstack
133+
from scipy.sparse import hstack as sparse_hstack
134134

135135
# Output name baseline
136136
out_name = fname_presuffix(
@@ -197,9 +197,9 @@ def _run_interface(self, runtime):
197197
data -= center
198198

199199
# Calculate collocation matrix from (possibly resized) image and knot grids
200-
colmat = sparse_vstack(
201-
grid_bspline_weights(fmapnii, grid) for grid in bs_grids
202-
).T.tocsr()
200+
colmat = sparse_hstack(
201+
[grid_bspline_weights(fmapnii, grid) for grid in bs_grids]
202+
).tocsr()
203203

204204
bs_grids_str = ["x".join(str(s) for s in grid.shape) for grid in bs_grids]
205205
bs_grids_str[-1] = f"and {bs_grids_str[-1]}"
@@ -254,9 +254,9 @@ def _run_interface(self, runtime):
254254
mask = np.asanyarray(masknii.dataobj) > 1e-4
255255
else:
256256
mask = np.ones_like(fmapnii.dataobj, dtype=bool)
257-
colmat = sparse_vstack(
258-
grid_bspline_weights(fmapnii, grid) for grid in bs_grids
259-
).T.tocsr()
257+
colmat = sparse_hstack(
258+
[grid_bspline_weights(fmapnii, grid) for grid in bs_grids]
259+
).tocsr()
260260

261261
regressors = colmat[mask.reshape(-1), :]
262262
interp_data = np.zeros_like(data)

sdcflows/tests/test_transform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,11 +348,11 @@ def test_grid_bspline_weights():
348348
nb.Nifti1Image(np.zeros(target_shape), target_aff),
349349
nb.Nifti1Image(np.zeros(ctrl_shape), ctrl_aff),
350350
).tocsr()
351-
assert weights.shape == (64, 1000)
351+
assert weights.shape == (1000, 64)
352352
# Empirically determined numbers intended to indicate that something
353353
# significant has changed. If it turns out we've been doing this wrong,
354354
# these numbers will probably change.
355355
assert np.isclose(weights[0, 0], 0.00089725334)
356356
assert np.isclose(weights[-1, -1], 0.18919244)
357-
assert np.isclose(weights.sum(axis=1).max(), 129.3907)
358-
assert np.isclose(weights.sum(axis=1).min(), 0.0052327816)
357+
assert np.isclose(weights.sum(axis=0).max(), 129.3907)
358+
assert np.isclose(weights.sum(axis=0).min(), 0.0052327816)

sdcflows/transform.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@
5656
import numpy as np
5757
from warnings import warn
5858
from scipy import ndimage as ndi
59-
from scipy.signal import cubic
60-
from scipy.sparse import vstack as sparse_vstack, kron, lil_array
59+
from scipy.interpolate import BSpline
60+
from scipy.sparse import hstack as sparse_hstack, kron, lil_array
6161

6262
import nibabel as nb
6363
import nitransforms as nt
@@ -309,7 +309,6 @@ def fit(
309309
atol=1e-3,
310310
)
311311

312-
weights = []
313312
if approx:
314313
from sdcflows.utils.tools import deoblique_and_zooms
315314

@@ -321,17 +320,15 @@ def fit(
321320
)
322321

323322
# Generate tensor-product B-Spline weights
324-
coeffs_data = []
325-
for level in coeffs:
326-
wmat = grid_bspline_weights(target_reference, level)
327-
weights.append(wmat)
328-
coeffs_data.append(level.get_fdata(dtype="float32").reshape(-1))
323+
colmat = sparse_hstack(
324+
[grid_bspline_weights(projected_reference, level) for level in coeffs]
325+
).tocsr()
326+
coefficients = np.hstack(
327+
[level.get_fdata(dtype="float32").reshape(-1) for level in coeffs]
328+
)
329329

330330
# Reconstruct the fieldmap (in Hz) from coefficients
331-
fmap = np.zeros(projected_reference.shape[:3], dtype="float32")
332-
fmap = (np.squeeze(np.hstack(coeffs_data).T) @ sparse_vstack(weights)).reshape(
333-
fmap.shape
334-
)
331+
fmap = np.reshape(colmat @ coefficients, projected_reference.shape[:3])
335332

336333
# Generate a NIfTI object
337334
hdr = target_reference.header.copy()
@@ -703,7 +700,7 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
703700
704701
Returns
705702
-------
706-
weights : :obj:`numpy.ndarray` (:math:`K \times N`)
703+
weights : :obj:`numpy.ndarray` (:math:`N \times K`)
707704
A sparse matrix of interpolating weights :math:`\Psi^3(\mathbf{k}, \mathbf{s})`
708705
for the *N* voxels of the target EPI, for each of the total *K* knots.
709706
This sparse matrix can be directly used as design matrix for the fitting
@@ -732,21 +729,26 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
732729
coords[axis] = np.arange(sample_shape[axis], dtype=dtype)
733730

734731
# Calculate the index component of samples w.r.t. B-Spline knots along current axis
732+
# Size of locations is L
735733
locs = nb.affines.apply_affine(target_to_grid, coords.T)[:, axis]
736-
knots = np.arange(knots_shape[axis], dtype=dtype)
737734

738-
distance = np.abs(locs[np.newaxis, ...] - knots[..., np.newaxis])
735+
# Size of knots is K + 6 so that all locations are fully covered by basis
736+
knots = np.arange(-3, knots_shape[axis] + 3, dtype=dtype)
737+
738+
bspl = BSpline(knots, np.eye(len(knots) - 3 - 1), 3)
739+
740+
# Construct a sparse design matrix (L, K)
741+
distance = np.abs(locs[..., np.newaxis] - knots[np.newaxis, 3:-3])
739742
within_support = distance < 2.0
740-
d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True)
741-
bs_w = cubic(d_vals)
742743

743-
colloc_ax = lil_array((knots_shape[axis], sample_shape[axis]), dtype=dtype)
744-
colloc_ax[within_support] = bs_w[d_idxs]
744+
colloc_ax = lil_array(distance.shape, dtype=dtype)
745+
colloc_ax[within_support] = bspl(locs)[:, 1:-1][within_support]
745746

746-
wd.append(colloc_ax)
747+
# Convert to CSR for efficient multiplication
748+
wd.append(colloc_ax.tocsr())
747749

748750
# Calculate the tensor product of the three design matrices
749-
return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)
751+
return kron(kron(wd[0], wd[1]), wd[2])
750752

751753

752754
def _move_coeff(in_coeff, fmap_ref, transform, fmap_target=None):

0 commit comments

Comments
 (0)