6
6
import nibabel as nb
7
7
from nibabel .affines import apply_affine
8
8
9
+ from nipype import logging
9
10
from nipype .utils .filemanip import fname_presuffix
10
11
from nipype .interfaces .base import (
11
12
BaseInterfaceInputSpec ,
22
23
DEFAULT_ZOOMS_MM = (40.0 , 40.0 , 20.0 ) # For human adults (mid-frequency), in mm
23
24
DEFAULT_LF_ZOOMS_MM = (100.0 , 100.0 , 40.0 ) # For human adults (low-frequency), in mm
24
25
DEFAULT_HF_ZOOMS_MM = (16.0 , 16.0 , 10.0 ) # For human adults (high-frequency), in mm
26
+ BSPLINE_SUPPORT = 2 - 1.82e-3 # Disallows weights < 1e-9
27
+ LOGGER = logging .getLogger ("nipype.interface" )
25
28
26
29
27
30
class _BSplineApproxInputSpec (BaseInterfaceInputSpec ):
@@ -96,6 +99,7 @@ class BSplineApprox(SimpleInterface):
96
99
97
100
def _run_interface (self , runtime ):
98
101
from sklearn import linear_model as lm
102
+ from scipy .sparse import vstack as sparse_vstack
99
103
100
104
# Load in the fieldmap
101
105
fmapnii = nb .load (self .inputs .in_data )
@@ -119,16 +123,18 @@ def _run_interface(self, runtime):
119
123
120
124
# Calculate the spatial location of control points
121
125
bs_levels = []
122
- w_l = []
123
126
ncoeff = []
127
+ regressors = None
124
128
for sp in bs_spacing :
125
129
level = bspline_grid (fmapnii , control_zooms_mm = sp )
126
130
bs_levels .append (level )
127
131
ncoeff .append (level .dataobj .size )
128
- w_l .append (bspline_weights (fmap_points , level ))
129
132
130
- # Compose the interpolation matrix
131
- regressors = np .vstack (w_l )
133
+ regressors = (
134
+ bspline_weights (fmap_points , level )
135
+ if regressors is None
136
+ else sparse_vstack ((regressors , bspline_weights (fmap_points , level )))
137
+ )
132
138
133
139
# Fit the model
134
140
model = lm .Ridge (alpha = self .inputs .ridge_alpha , fit_intercept = False )
@@ -170,9 +176,12 @@ def _run_interface(self, runtime):
170
176
return runtime
171
177
172
178
bg_indices = np .argwhere (~ mask )
173
- bg_points = apply_affine (fmapnii .affine .astype ("float32" ), bg_indices )
179
+ if not bg_indices .size :
180
+ self ._results ["out_extrapolated" ] = self ._results ["out_field" ]
181
+ return runtime
174
182
175
- extrapolators = np .vstack (
183
+ bg_points = apply_affine (fmapnii .affine .astype ("float32" ), bg_indices )
184
+ extrapolators = sparse_vstack (
176
185
[bspline_weights (bg_points , level ) for level in bs_levels ]
177
186
)
178
187
interp_data [~ mask ] = np .array (model .coef_ ) @ extrapolators # Extrapolation
@@ -227,7 +236,7 @@ class Coefficients2Warp(SimpleInterface):
227
236
output_spec = _Coefficients2WarpOutputSpec
228
237
229
238
def _run_interface (self , runtime ):
230
- from .. utils . misc import get_free_mem
239
+ from scipy . sparse import vstack as sparse_vstack
231
240
232
241
# Calculate the physical coordinates of target grid
233
242
targetnii = nb .load (self .inputs .in_target )
@@ -238,37 +247,18 @@ def _run_interface(self, runtime):
238
247
239
248
weights = []
240
249
coeffs = []
241
- blocksize = LOW_MEM_BLOCK_SIZE if self .inputs .low_mem else len (points )
242
250
243
251
for cname in self .inputs .in_coeff :
244
- cnii = nb .load (cname )
245
- cdata = cnii .get_fdata (dtype = "float32" )
246
- coeffs .append (cdata .reshape (- 1 ))
247
-
248
- # Try to probe the free memory
249
- _free_mem = get_free_mem ()
250
- suggested_blocksize = (
251
- int (np .round ((_free_mem * 0.80 ) / (3 * 32 * cdata .size )))
252
- if _free_mem
253
- else blocksize
254
- )
255
- blocksize = min (blocksize , suggested_blocksize )
256
-
257
- idx = 0
258
- block_w = []
259
- while True :
260
- end = idx + blocksize
261
- subsample = points [idx :end , ...]
262
- if subsample .shape [0 ] == 0 :
263
- break
264
-
265
- idx = end
266
- block_w .append (bspline_weights (subsample , cnii ))
267
-
268
- weights .append (np .hstack (block_w ))
252
+ coeff_nii = nb .load (cname )
253
+ wmat = grid_bspline_weights (targetnii , coeff_nii )
254
+ # wmat = bspline_weights(
255
+ # points, coeff_nii, mem_percent=0.1 if self.inputs.low_mem else None,
256
+ # )
257
+ weights .append (wmat )
258
+ coeffs .append (coeff_nii .get_fdata (dtype = "float32" ).reshape (- 1 ))
269
259
270
260
data = np .zeros (targetnii .shape , dtype = "float32" )
271
- data [allmask == 1 ] = np .squeeze (np .vstack (coeffs ).T ) @ np . vstack (weights )
261
+ data [allmask == 1 ] = np .squeeze (np .vstack (coeffs ).T ) @ sparse_vstack (weights )
272
262
273
263
hdr = targetnii .header .copy ()
274
264
hdr .set_data_dtype ("float32" )
@@ -411,7 +401,62 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
411
401
return img .__class__ (np .zeros (bs_shape , dtype = "float32" ), bs_affine )
412
402
413
403
414
- def bspline_weights (points , ctrl_nii ):
404
+ def grid_bspline_weights (target_nii , ctrl_nii ):
405
+ """Fast, gridded evaluation."""
406
+ from scipy .sparse import csr_matrix , vstack
407
+
408
+ if isinstance (target_nii , (str , bytes , Path )):
409
+ target_nii = nb .load (target_nii )
410
+ if isinstance (ctrl_nii , (str , bytes , Path )):
411
+ ctrl_nii = nb .load (ctrl_nii )
412
+
413
+ shape = target_nii .shape [:3 ]
414
+ ctrl_sp = ctrl_nii .header .get_zooms ()[:3 ]
415
+ ras2ijk = np .linalg .inv (ctrl_nii .affine )
416
+ origin = apply_affine (ras2ijk , [tuple (target_nii .affine [:3 , 3 ])])[0 ]
417
+
418
+ wd = []
419
+ for i , (o , n , sp ) in enumerate (
420
+ zip (origin , shape , target_nii .header .get_zooms ()[:3 ])
421
+ ):
422
+ locations = np .arange (0 , n , dtype = "float32" ) * sp / ctrl_sp [i ] + o
423
+ knots = np .arange (0 , ctrl_nii .shape [i ], dtype = "float32" )
424
+ distance = (locations [np .newaxis , ...] - knots [..., np .newaxis ]).astype (
425
+ "float32"
426
+ )
427
+ weights = np .zeros_like (distance , dtype = "float32" )
428
+ within_support = np .abs (distance ) < 2.0
429
+ d = np .abs (distance [within_support ])
430
+ weights [within_support ] = np .piecewise (
431
+ d ,
432
+ [d < 1.0 , d >= 1.0 ],
433
+ [
434
+ lambda d : (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3 ) / 6.0 ,
435
+ lambda d : (2.0 - d ) ** 3 / 6.0 ,
436
+ ],
437
+ )
438
+ wd .append (weights )
439
+
440
+ ctrl_shape = ctrl_nii .shape [:3 ]
441
+ data_size = np .prod (shape )
442
+ wmat = None
443
+ for i in range (ctrl_shape [0 ]):
444
+ sparse_mat = (
445
+ wd [0 ][i , np .newaxis , np .newaxis , :, np .newaxis , np .newaxis ]
446
+ * wd [1 ][np .newaxis , :, np .newaxis , np .newaxis , :, np .newaxis ]
447
+ * wd [2 ][np .newaxis , np .newaxis , :, np .newaxis , np .newaxis , :]
448
+ ).reshape ((- 1 , data_size ))
449
+ sparse_mat [sparse_mat < 1e-9 ] = 0
450
+
451
+ if wmat is None :
452
+ wmat = csr_matrix (sparse_mat )
453
+ else :
454
+ wmat = vstack ((wmat , csr_matrix (sparse_mat )))
455
+
456
+ return wmat
457
+
458
+
459
+ def bspline_weights (points , ctrl_nii , blocksize = None , mem_percent = None ):
415
460
r"""
416
461
Calculate the tensor-product cubic B-Spline kernel weights for a list of 3D points.
417
462
@@ -456,29 +501,74 @@ def bspline_weights(points, ctrl_nii):
456
501
step of approximation/extrapolation.
457
502
458
503
"""
504
+ from scipy .sparse import csc_matrix , hstack
505
+ from ..utils .misc import get_free_mem
506
+
507
+ if isinstance (ctrl_nii , (str , bytes , Path )):
508
+ ctrl_nii = nb .load (ctrl_nii )
459
509
ncoeff = np .prod (ctrl_nii .shape [:3 ])
460
510
knots = np .argwhere (np .ones (ctrl_nii .shape [:3 ], dtype = "uint8" ) == 1 )
461
- ctl_points = apply_affine (np .linalg .inv (ctrl_nii .affine ).astype ("float32" ), points )
462
-
463
- weights = np .ones ((ncoeff , points .shape [0 ]), dtype = "float32" )
464
- for i in range (3 ):
465
- d = np .abs (
466
- (knots [:, np .newaxis , i ].astype ("float32" ) - ctl_points [np .newaxis , :, i ])[
467
- weights > 1e-6
468
- ]
469
- )
470
- weights [weights > 1e-6 ] *= np .piecewise (
471
- d ,
472
- [d >= 2.0 , d < 1.0 , (d >= 1.0 ) & (d < 2 )],
473
- [
474
- 0.0 ,
475
- lambda d : (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3 ) / 6.0 ,
476
- lambda d : (2.0 - d ) ** 3 / 6.0 ,
477
- ],
478
- )
511
+ ras2ijk = np .linalg .inv (ctrl_nii .affine ).astype ("float32" )
512
+
513
+ if blocksize is None :
514
+ blocksize = len (points )
479
515
480
- weights [weights < 1e-6 ] = 0.0
481
- return weights
516
+ # Try to probe the free memory
517
+ _free_mem = get_free_mem ()
518
+ suggested_blocksize = (
519
+ int (np .round ((_free_mem * (mem_percent or 0.9 )) / (3 * 4 * ncoeff )))
520
+ if _free_mem
521
+ else blocksize
522
+ )
523
+ blocksize = min (blocksize , suggested_blocksize )
524
+ LOGGER .debug (
525
+ f"Determined a block size of { blocksize } , for interpolating "
526
+ f"an image of { len (points )} voxels with a grid of { ncoeff } "
527
+ f"coefficients ({ _free_mem / 1024 ** 3 :.2f} GiB free memory)."
528
+ )
529
+
530
+ idx = 0
531
+ wmatrix = None
532
+ while True :
533
+ end = idx + blocksize
534
+ subsample = points [idx :end , ...]
535
+ if subsample .shape [0 ] == 0 :
536
+ break
537
+
538
+ ctl_points = apply_affine (ras2ijk , subsample )
539
+ weights = np .ones ((ncoeff , len (subsample )), dtype = "float32" )
540
+ for i in range (3 ):
541
+ nonzeros = weights > 1e-6
542
+ distance = np .squeeze (
543
+ np .abs (
544
+ (
545
+ knots [:, np .newaxis , i ].astype ("float32" )
546
+ - ctl_points [np .newaxis , :, i ]
547
+ )[nonzeros ]
548
+ )
549
+ )
550
+ within_support = distance < BSPLINE_SUPPORT
551
+ d = distance [within_support ]
552
+ distance [~ within_support ] = 0
553
+ distance [within_support ] = np .piecewise (
554
+ d ,
555
+ [d < 1.0 , d >= 1.0 ],
556
+ [
557
+ lambda d : (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3 ) / 6.0 ,
558
+ lambda d : (2.0 - d ) ** 3 / 6.0 ,
559
+ ],
560
+ )
561
+ weights [nonzeros ] *= distance
562
+
563
+ weights [weights < 1e-6 ] = 0.0
564
+
565
+ wmatrix = (
566
+ csc_matrix (weights )
567
+ if wmatrix is None
568
+ else hstack ((wmatrix , csc_matrix (weights )))
569
+ )
570
+ idx = end
571
+ return wmatrix .tocsr ()
482
572
483
573
484
574
def _move_coeff (in_coeff , fmap_ref , transform ):
0 commit comments