Skip to content

Commit c2c8444

Browse files
committed
ENH: Optimize tensor-product B-Spline kernel evaluation
We're still having memory issues when interpolating with TOPUP-generated coefficients.
1 parent 7c81f25 commit c2c8444

File tree

1 file changed

+78
-53
lines changed

1 file changed

+78
-53
lines changed

sdcflows/interfaces/bspline.py

Lines changed: 78 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm
2323
DEFAULT_LF_ZOOMS_MM = (100.0, 100.0, 40.0) # For human adults (low-frequency), in mm
2424
DEFAULT_HF_ZOOMS_MM = (16.0, 16.0, 10.0) # For human adults (high-frequency), in mm
25+
BSPLINE_SUPPORT = 2 - 1.82e-2 # Disallows weights < 1e-6
2526

2627

2728
class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
@@ -119,16 +120,19 @@ def _run_interface(self, runtime):
119120

120121
# Calculate the spatial location of control points
121122
bs_levels = []
122-
w_l = []
123123
ncoeff = []
124+
regressors = None
124125
for sp in bs_spacing:
125126
level = bspline_grid(fmapnii, control_zooms_mm=sp)
126127
bs_levels.append(level)
127128
ncoeff.append(level.dataobj.size)
128-
w_l.append(bspline_weights(fmap_points, level))
129129

130-
# Compose the interpolation matrix
131-
regressors = np.vstack(w_l)
130+
if regressors is None:
131+
regressors = bspline_weights(fmap_points, level)
132+
else:
133+
regressors = np.vstack(
134+
(regressors, bspline_weights(fmap_points, level))
135+
)
132136

133137
# Fit the model
134138
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False)
@@ -170,8 +174,11 @@ def _run_interface(self, runtime):
170174
return runtime
171175

172176
bg_indices = np.argwhere(~mask)
173-
bg_points = apply_affine(fmapnii.affine.astype("float32"), bg_indices)
177+
if not bg_indices.size:
178+
self._results["out_extrapolated"] = self._results["out_field"]
179+
return runtime
174180

181+
bg_points = apply_affine(fmapnii.affine.astype("float32"), bg_indices)
175182
extrapolators = np.vstack(
176183
[bspline_weights(bg_points, level) for level in bs_levels]
177184
)
@@ -227,8 +234,6 @@ class Coefficients2Warp(SimpleInterface):
227234
output_spec = _Coefficients2WarpOutputSpec
228235

229236
def _run_interface(self, runtime):
230-
from ..utils.misc import get_free_mem
231-
232237
# Calculate the physical coordinates of target grid
233238
targetnii = nb.load(self.inputs.in_target)
234239
targetaff = targetnii.affine
@@ -238,34 +243,16 @@ def _run_interface(self, runtime):
238243

239244
weights = []
240245
coeffs = []
241-
blocksize = LOW_MEM_BLOCK_SIZE if self.inputs.low_mem else len(points)
242246

243247
for cname in self.inputs.in_coeff:
244-
cnii = nb.load(cname)
245-
cdata = cnii.get_fdata(dtype="float32")
246-
coeffs.append(cdata.reshape(-1))
247-
248-
# Try to probe the free memory
249-
_free_mem = get_free_mem()
250-
suggested_blocksize = (
251-
int(np.round((_free_mem * 0.80) / (3 * 32 * cdata.size)))
252-
if _free_mem
253-
else blocksize
248+
coeff_nii = nb.load(cname)
249+
wmat = bspline_weights(
250+
points,
251+
coeff_nii,
252+
blocksize=LOW_MEM_BLOCK_SIZE if self.inputs.low_mem else None,
254253
)
255-
blocksize = min(blocksize, suggested_blocksize)
256-
257-
idx = 0
258-
block_w = []
259-
while True:
260-
end = idx + blocksize
261-
subsample = points[idx:end, ...]
262-
if subsample.shape[0] == 0:
263-
break
264-
265-
idx = end
266-
block_w.append(bspline_weights(subsample, cnii))
267-
268-
weights.append(np.hstack(block_w))
254+
weights.append(wmat)
255+
coeffs.append(coeff_nii.get_fdata(dtype="float32").reshape(-1))
269256

270257
data = np.zeros(targetnii.shape, dtype="float32")
271258
data[allmask == 1] = np.squeeze(np.vstack(coeffs).T) @ np.vstack(weights)
@@ -411,7 +398,7 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
411398
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)
412399

413400

414-
def bspline_weights(points, ctrl_nii):
401+
def bspline_weights(points, ctrl_nii, blocksize=None):
415402
r"""
416403
Calculate the tensor-product cubic B-Spline kernel weights for a list of 3D points.
417404
@@ -456,29 +443,67 @@ def bspline_weights(points, ctrl_nii):
456443
step of approximation/extrapolation.
457444
458445
"""
446+
from ..utils.misc import get_free_mem
447+
448+
if isinstance(ctrl_nii, (str, bytes, Path)):
449+
ctrl_nii = nb.load(ctrl_nii)
459450
ncoeff = np.prod(ctrl_nii.shape[:3])
460451
knots = np.argwhere(np.ones(ctrl_nii.shape[:3], dtype="uint8") == 1)
461-
ctl_points = apply_affine(np.linalg.inv(ctrl_nii.affine).astype("float32"), points)
462-
463-
weights = np.ones((ncoeff, points.shape[0]), dtype="float32")
464-
for i in range(3):
465-
d = np.abs(
466-
(knots[:, np.newaxis, i].astype("float32") - ctl_points[np.newaxis, :, i])[
467-
weights > 1e-6
468-
]
469-
)
470-
weights[weights > 1e-6] *= np.piecewise(
471-
d,
472-
[d >= 2.0, d < 1.0, (d >= 1.0) & (d < 2)],
473-
[
474-
0.0,
475-
lambda d: (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3) / 6.0,
476-
lambda d: (2.0 - d) ** 3 / 6.0,
477-
],
478-
)
452+
ras2ijk = np.linalg.inv(ctrl_nii.affine).astype("float32")
453+
454+
if blocksize is None:
455+
blocksize = len(points)
456+
457+
# Try to probe the free memory
458+
_free_mem = get_free_mem()
459+
suggested_blocksize = (
460+
int(np.round((_free_mem * 0.80) / (3 * 32 * ncoeff)))
461+
if _free_mem
462+
else blocksize
463+
)
464+
blocksize = min(blocksize, suggested_blocksize)
465+
idx = 0
466+
wmatrix = None
467+
while True:
468+
end = idx + blocksize
469+
subsample = points[idx:end, ...]
470+
if subsample.shape[0] == 0:
471+
break
472+
473+
ctl_points = apply_affine(ras2ijk, subsample)
474+
weights = np.ones((ncoeff, len(subsample)), dtype="float32")
475+
for i in range(3):
476+
nonzeros = weights > 1e-6
477+
distance = np.squeeze(
478+
np.abs(
479+
(
480+
knots[:, np.newaxis, i].astype("float32")
481+
- ctl_points[np.newaxis, :, i]
482+
)[nonzeros]
483+
)
484+
)
485+
within_support = distance < BSPLINE_SUPPORT
486+
d = distance[within_support]
487+
distance[~within_support] = 0
488+
distance[within_support] = np.piecewise(
489+
d,
490+
[d < 1.0, d >= 1.0],
491+
[
492+
lambda d: (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3) / 6.0,
493+
lambda d: (2.0 - d) ** 3 / 6.0,
494+
],
495+
)
496+
weights[nonzeros] *= distance
497+
498+
weights[weights < 1e-6] = 0.0
499+
500+
if idx == 0:
501+
wmatrix = weights
502+
else:
503+
wmatrix = np.hstack((wmatrix, weights))
504+
idx = end
479505

480-
weights[weights < 1e-6] = 0.0
481-
return weights
506+
return wmatrix
482507

483508

484509
def _move_coeff(in_coeff, fmap_ref, transform):

0 commit comments

Comments
 (0)