5656import numpy as np
5757from warnings import warn
5858from scipy import ndimage as ndi
59- from scipy .signal import cubic
60- from scipy .sparse import vstack as sparse_vstack , kron , lil_array
59+ from scipy .interpolate import BSpline
60+ from scipy .sparse import hstack as sparse_hstack , kron , lil_array
6161
6262import nibabel as nb
6363import nitransforms as nt
@@ -309,7 +309,6 @@ def fit(
309309 atol = 1e-3 ,
310310 )
311311
312- weights = []
313312 if approx :
314313 from sdcflows .utils .tools import deoblique_and_zooms
315314
@@ -321,17 +320,15 @@ def fit(
321320 )
322321
323322 # Generate tensor-product B-Spline weights
324- coeffs_data = []
325- for level in coeffs :
326- wmat = grid_bspline_weights (target_reference , level )
327- weights .append (wmat )
328- coeffs_data .append (level .get_fdata (dtype = "float32" ).reshape (- 1 ))
323+ colmat = sparse_hstack (
324+ [grid_bspline_weights (projected_reference , level ) for level in coeffs ]
325+ ).tocsr ()
326+ coefficients = np .hstack (
327+ [level .get_fdata (dtype = "float32" ).reshape (- 1 ) for level in coeffs ]
328+ )
329329
330330 # Reconstruct the fieldmap (in Hz) from coefficients
331- fmap = np .zeros (projected_reference .shape [:3 ], dtype = "float32" )
332- fmap = (np .squeeze (np .hstack (coeffs_data ).T ) @ sparse_vstack (weights )).reshape (
333- fmap .shape
334- )
331+ fmap = np .reshape (colmat @ coefficients , projected_reference .shape [:3 ])
335332
336333 # Generate a NIfTI object
337334 hdr = target_reference .header .copy ()
@@ -703,7 +700,7 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
703700
704701 Returns
705702 -------
706- weights : :obj:`numpy.ndarray` (:math:`K \times N `)
703+ weights : :obj:`numpy.ndarray` (:math:`N \times K `)
707704 A sparse matrix of interpolating weights :math:`\Psi^3(\mathbf{k}, \mathbf{s})`
708705 for the *N* voxels of the target EPI, for each of the total *K* knots.
709706 This sparse matrix can be directly used as design matrix for the fitting
@@ -732,21 +729,26 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
732729 coords [axis ] = np .arange (sample_shape [axis ], dtype = dtype )
733730
734731 # Calculate the index component of samples w.r.t. B-Spline knots along current axis
732+ # Size of locations is L
735733 locs = nb .affines .apply_affine (target_to_grid , coords .T )[:, axis ]
736- knots = np .arange (knots_shape [axis ], dtype = dtype )
737734
738- distance = np .abs (locs [np .newaxis , ...] - knots [..., np .newaxis ])
735+ # Size of knots is K + 6 so that all locations are fully covered by basis
736+ knots = np .arange (- 3 , knots_shape [axis ] + 3 , dtype = dtype )
737+
738+ bspl = BSpline (knots , np .eye (len (knots ) - 3 - 1 ), 3 )
739+
740+ # Construct a sparse design matrix (L, K)
741+ distance = np .abs (locs [..., np .newaxis ] - knots [np .newaxis , 3 :- 3 ])
739742 within_support = distance < 2.0
740- d_vals , d_idxs = np .unique (distance [within_support ], return_inverse = True )
741- bs_w = cubic (d_vals )
742743
743- colloc_ax = lil_array (( knots_shape [ axis ], sample_shape [ axis ]) , dtype = dtype )
744- colloc_ax [within_support ] = bs_w [ d_idxs ]
744+ colloc_ax = lil_array (distance . shape , dtype = dtype )
745+ colloc_ax [within_support ] = bspl ( locs )[:, 1 : - 1 ][ within_support ]
745746
746- wd .append (colloc_ax )
747+ # Convert to CSR for efficient multiplication
748+ wd .append (colloc_ax .tocsr ())
747749
748750 # Calculate the tensor product of the three design matrices
749- return kron (kron (wd [0 ], wd [1 ]), wd [2 ]). astype ( dtype )
751+ return kron (kron (wd [0 ], wd [1 ]), wd [2 ])
750752
751753
752754def _move_coeff (in_coeff , fmap_ref , transform , fmap_target = None ):
0 commit comments