22
22
DEFAULT_ZOOMS_MM = (40.0 , 40.0 , 20.0 ) # For human adults (mid-frequency), in mm
23
23
DEFAULT_LF_ZOOMS_MM = (100.0 , 100.0 , 40.0 ) # For human adults (low-frequency), in mm
24
24
DEFAULT_HF_ZOOMS_MM = (16.0 , 16.0 , 10.0 ) # For human adults (high-frequency), in mm
25
+ BSPLINE_SUPPORT = 2 - 1.82e-2 # Disallows weights < 1e-6
25
26
26
27
27
28
class _BSplineApproxInputSpec (BaseInterfaceInputSpec ):
@@ -119,16 +120,19 @@ def _run_interface(self, runtime):
119
120
120
121
# Calculate the spatial location of control points
121
122
bs_levels = []
122
- w_l = []
123
123
ncoeff = []
124
+ regressors = None
124
125
for sp in bs_spacing :
125
126
level = bspline_grid (fmapnii , control_zooms_mm = sp )
126
127
bs_levels .append (level )
127
128
ncoeff .append (level .dataobj .size )
128
- w_l .append (bspline_weights (fmap_points , level ))
129
129
130
- # Compose the interpolation matrix
131
- regressors = np .vstack (w_l )
130
+ if regressors is None :
131
+ regressors = bspline_weights (fmap_points , level )
132
+ else :
133
+ regressors = np .vstack (
134
+ (regressors , bspline_weights (fmap_points , level ))
135
+ )
132
136
133
137
# Fit the model
134
138
model = lm .Ridge (alpha = self .inputs .ridge_alpha , fit_intercept = False )
@@ -170,8 +174,11 @@ def _run_interface(self, runtime):
170
174
return runtime
171
175
172
176
bg_indices = np .argwhere (~ mask )
173
- bg_points = apply_affine (fmapnii .affine .astype ("float32" ), bg_indices )
177
+ if not bg_indices .size :
178
+ self ._results ["out_extrapolated" ] = self ._results ["out_field" ]
179
+ return runtime
174
180
181
+ bg_points = apply_affine (fmapnii .affine .astype ("float32" ), bg_indices )
175
182
extrapolators = np .vstack (
176
183
[bspline_weights (bg_points , level ) for level in bs_levels ]
177
184
)
@@ -227,8 +234,6 @@ class Coefficients2Warp(SimpleInterface):
227
234
output_spec = _Coefficients2WarpOutputSpec
228
235
229
236
def _run_interface (self , runtime ):
230
- from ..utils .misc import get_free_mem
231
-
232
237
# Calculate the physical coordinates of target grid
233
238
targetnii = nb .load (self .inputs .in_target )
234
239
targetaff = targetnii .affine
@@ -238,34 +243,16 @@ def _run_interface(self, runtime):
238
243
239
244
weights = []
240
245
coeffs = []
241
- blocksize = LOW_MEM_BLOCK_SIZE if self .inputs .low_mem else len (points )
242
246
243
247
for cname in self .inputs .in_coeff :
244
- cnii = nb .load (cname )
245
- cdata = cnii .get_fdata (dtype = "float32" )
246
- coeffs .append (cdata .reshape (- 1 ))
247
-
248
- # Try to probe the free memory
249
- _free_mem = get_free_mem ()
250
- suggested_blocksize = (
251
- int (np .round ((_free_mem * 0.80 ) / (3 * 32 * cdata .size )))
252
- if _free_mem
253
- else blocksize
248
+ coeff_nii = nb .load (cname )
249
+ wmat = bspline_weights (
250
+ points ,
251
+ coeff_nii ,
252
+ blocksize = LOW_MEM_BLOCK_SIZE if self .inputs .low_mem else None ,
254
253
)
255
- blocksize = min (blocksize , suggested_blocksize )
256
-
257
- idx = 0
258
- block_w = []
259
- while True :
260
- end = idx + blocksize
261
- subsample = points [idx :end , ...]
262
- if subsample .shape [0 ] == 0 :
263
- break
264
-
265
- idx = end
266
- block_w .append (bspline_weights (subsample , cnii ))
267
-
268
- weights .append (np .hstack (block_w ))
254
+ weights .append (wmat )
255
+ coeffs .append (coeff_nii .get_fdata (dtype = "float32" ).reshape (- 1 ))
269
256
270
257
data = np .zeros (targetnii .shape , dtype = "float32" )
271
258
data [allmask == 1 ] = np .squeeze (np .vstack (coeffs ).T ) @ np .vstack (weights )
@@ -411,7 +398,7 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
411
398
return img .__class__ (np .zeros (bs_shape , dtype = "float32" ), bs_affine )
412
399
413
400
414
- def bspline_weights (points , ctrl_nii ):
401
+ def bspline_weights (points , ctrl_nii , blocksize = None ):
415
402
r"""
416
403
Calculate the tensor-product cubic B-Spline kernel weights for a list of 3D points.
417
404
@@ -456,29 +443,67 @@ def bspline_weights(points, ctrl_nii):
456
443
step of approximation/extrapolation.
457
444
458
445
"""
446
+ from ..utils .misc import get_free_mem
447
+
448
+ if isinstance (ctrl_nii , (str , bytes , Path )):
449
+ ctrl_nii = nb .load (ctrl_nii )
459
450
ncoeff = np .prod (ctrl_nii .shape [:3 ])
460
451
knots = np .argwhere (np .ones (ctrl_nii .shape [:3 ], dtype = "uint8" ) == 1 )
461
- ctl_points = apply_affine (np .linalg .inv (ctrl_nii .affine ).astype ("float32" ), points )
462
-
463
- weights = np .ones ((ncoeff , points .shape [0 ]), dtype = "float32" )
464
- for i in range (3 ):
465
- d = np .abs (
466
- (knots [:, np .newaxis , i ].astype ("float32" ) - ctl_points [np .newaxis , :, i ])[
467
- weights > 1e-6
468
- ]
469
- )
470
- weights [weights > 1e-6 ] *= np .piecewise (
471
- d ,
472
- [d >= 2.0 , d < 1.0 , (d >= 1.0 ) & (d < 2 )],
473
- [
474
- 0.0 ,
475
- lambda d : (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3 ) / 6.0 ,
476
- lambda d : (2.0 - d ) ** 3 / 6.0 ,
477
- ],
478
- )
452
+ ras2ijk = np .linalg .inv (ctrl_nii .affine ).astype ("float32" )
453
+
454
+ if blocksize is None :
455
+ blocksize = len (points )
456
+
457
+ # Try to probe the free memory
458
+ _free_mem = get_free_mem ()
459
+ suggested_blocksize = (
460
+ int (np .round ((_free_mem * 0.80 ) / (3 * 32 * ncoeff )))
461
+ if _free_mem
462
+ else blocksize
463
+ )
464
+ blocksize = min (blocksize , suggested_blocksize )
465
+ idx = 0
466
+ wmatrix = None
467
+ while True :
468
+ end = idx + blocksize
469
+ subsample = points [idx :end , ...]
470
+ if subsample .shape [0 ] == 0 :
471
+ break
472
+
473
+ ctl_points = apply_affine (ras2ijk , subsample )
474
+ weights = np .ones ((ncoeff , len (subsample )), dtype = "float32" )
475
+ for i in range (3 ):
476
+ nonzeros = weights > 1e-6
477
+ distance = np .squeeze (
478
+ np .abs (
479
+ (
480
+ knots [:, np .newaxis , i ].astype ("float32" )
481
+ - ctl_points [np .newaxis , :, i ]
482
+ )[nonzeros ]
483
+ )
484
+ )
485
+ within_support = distance < BSPLINE_SUPPORT
486
+ d = distance [within_support ]
487
+ distance [~ within_support ] = 0
488
+ distance [within_support ] = np .piecewise (
489
+ d ,
490
+ [d < 1.0 , d >= 1.0 ],
491
+ [
492
+ lambda d : (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3 ) / 6.0 ,
493
+ lambda d : (2.0 - d ) ** 3 / 6.0 ,
494
+ ],
495
+ )
496
+ weights [nonzeros ] *= distance
497
+
498
+ weights [weights < 1e-6 ] = 0.0
499
+
500
+ if idx == 0 :
501
+ wmatrix = weights
502
+ else :
503
+ wmatrix = np .hstack ((wmatrix , weights ))
504
+ idx = end
479
505
480
- weights [weights < 1e-6 ] = 0.0
481
- return weights
506
+ return wmatrix
482
507
483
508
484
509
def _move_coeff (in_coeff , fmap_ref , transform ):
0 commit comments