Skip to content

Commit b7ed49c

Browse files
committed
enh: optimize implementation using tensor-product of 1D vectors
This implementation hopes to be more optimal than the previous one. Instead of calculating the tensor-product B-Spline weights of every point w.r.t. every control point, it calculates the weights along each axis and then calculates the tensor-product for the whole grid. In principle, the number of calls to ``np.piecewise`` has to have dramatically dropped. Memory utilization should be also more optimal, as there're only one very short-lived and large array (before it is converted to sparse matrix).
1 parent 33b6290 commit b7ed49c

File tree

1 file changed

+52
-5
lines changed

1 file changed

+52
-5
lines changed

sdcflows/interfaces/bspline.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,10 @@ def _run_interface(self, runtime):
251251

252252
for cname in self.inputs.in_coeff:
253253
coeff_nii = nb.load(cname)
254-
wmat = bspline_weights(
255-
points, coeff_nii, mem_percent=0.1 if self.inputs.low_mem else None,
256-
)
254+
wmat = grid_bspline_weights(targetnii, coeff_nii)
255+
# wmat = bspline_weights(
256+
# points, coeff_nii, mem_percent=0.1 if self.inputs.low_mem else None,
257+
# )
257258
weights.append(wmat)
258259
coeffs.append(coeff_nii.get_fdata(dtype="float32").reshape(-1))
259260

@@ -401,6 +402,53 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
401402
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)
402403

403404

405+
def grid_bspline_weights(target_nii, ctrl_nii):
406+
"""Fast, gridded evaluation."""
407+
from scipy.sparse import csr_matrix
408+
409+
if isinstance(target_nii, (str, bytes, Path)):
410+
target_nii = nb.load(target_nii)
411+
if isinstance(ctrl_nii, (str, bytes, Path)):
412+
ctrl_nii = nb.load(ctrl_nii)
413+
414+
shape = target_nii.shape[:3]
415+
ctrl_sp = ctrl_nii.header.get_zooms()[:3]
416+
ras2ijk = np.linalg.inv(ctrl_nii.affine)
417+
origin = apply_affine(ras2ijk, [tuple(target_nii.affine[:3, 3])])[0]
418+
419+
wd = []
420+
for i, (o, n, sp) in enumerate(
421+
zip(origin, shape, target_nii.header.get_zooms()[:3])
422+
):
423+
locations = np.arange(0, n, dtype="float32") * sp / ctrl_sp[i] + o
424+
knots = np.arange(0, ctrl_nii.shape[i], dtype="float32")
425+
distance = (locations[np.newaxis, ...] - knots[..., np.newaxis]).astype(
426+
"float32"
427+
)
428+
weights = np.zeros_like(distance, dtype="float32")
429+
within_support = np.abs(distance) < BSPLINE_SUPPORT
430+
d = np.abs(distance[within_support])
431+
weights[within_support] = np.piecewise(
432+
d,
433+
[d < 1.0, d >= 1.0],
434+
[
435+
lambda d: (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3) / 6.0,
436+
lambda d: (2.0 - d) ** 3 / 6.0,
437+
],
438+
)
439+
wd.append(weights)
440+
441+
wmat = csr_matrix(
442+
(
443+
wd[0][:, np.newaxis, np.newaxis, :, np.newaxis, np.newaxis]
444+
* wd[1][np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis]
445+
* wd[2][np.newaxis, np.newaxis, :, np.newaxis, np.newaxis, :]
446+
).reshape((np.prod(ctrl_nii.shape[:3]), np.prod(shape))),
447+
dtype="float32",
448+
)
449+
return wmat
450+
451+
404452
def bspline_weights(points, ctrl_nii, blocksize=None, mem_percent=None):
405453
r"""
406454
Calculate the tensor-product cubic B-Spline kernel weights for a list of 3D points.
@@ -466,7 +514,7 @@ def bspline_weights(points, ctrl_nii, blocksize=None, mem_percent=None):
466514
else blocksize
467515
)
468516
blocksize = min(blocksize, suggested_blocksize)
469-
LOGGER.info(
517+
LOGGER.debug(
470518
f"Determined a block size of {blocksize}, for interpolating "
471519
f"an image of {len(points)} voxels with a grid of {ncoeff} "
472520
f"coefficients ({_free_mem / 1024**3:.2f} GiB free memory)."
@@ -512,7 +560,6 @@ def bspline_weights(points, ctrl_nii, blocksize=None, mem_percent=None):
512560
else:
513561
wmatrix = hstack((wmatrix, csc_matrix(weights)))
514562
idx = end
515-
516563
return wmatrix.tocsr()
517564

518565

0 commit comments

Comments
 (0)