@@ -436,6 +436,7 @@ def apply(
436
436
The data imaged after resampling to reference space.
437
437
438
438
"""
439
+
439
440
if reference is not None and isinstance (reference , (str , Path )):
440
441
reference = _nbload (str (reference ))
441
442
@@ -446,40 +447,53 @@ def apply(
446
447
if isinstance (spatialimage , (str , Path )):
447
448
spatialimage = _nbload (str (spatialimage ))
448
449
449
- data = np .squeeze (np .asanyarray (spatialimage .dataobj ))
450
- output_dtype = output_dtype or data .dtype
450
+ # Avoid opening the data array just yet
451
+ input_dtype = spatialimage .header .get_data_dtype ()
452
+ output_dtype = output_dtype or input_dtype
451
453
452
- ycoords = self .map (_ref .ndcoords .T )
453
- targets = ImageGrid (spatialimage ).index ( # data should be an image
454
- _as_homogeneous (np .vstack (ycoords ), dim = _ref .ndim )
455
- )
454
+ # Prepare physical coordinates of input (grid, points)
455
+ xcoords = _ref .ndcoords .astype ("f4" )
456
456
457
- if data .ndim == 4 :
458
- if len (self ) != data .shape [- 1 ]:
457
+ # Invert target's (moving) affine once
458
+ ras2vox = ~ Affine (spatialimage .affine )
459
+
460
+ if spatialimage .ndim == 4 :
461
+ if len (self ) != spatialimage .shape [- 1 ]:
459
462
raise ValueError (
460
463
"Attempting to apply %d transforms on a file with "
461
- "%d timepoints" % (len (self ), data .shape [- 1 ])
464
+ "%d timepoints" % (len (self ), spatialimage .shape [- 1 ])
462
465
)
463
- targets = targets .reshape ((len (self ), - 1 , targets .shape [- 1 ]))
464
- resampled = np .stack (
465
- [
466
- ndi .map_coordinates (
467
- data [..., t ],
468
- targets [t , ..., : _ref .ndim ].T ,
469
- output = output_dtype ,
470
- order = order ,
471
- mode = mode ,
472
- cval = cval ,
473
- prefilter = prefilter ,
474
- )
475
- for t in range (data .shape [- 1 ])
476
- ],
477
- axis = 0 ,
466
+
467
+ # Order F ensures individual volumes are contiguous in memory
468
+ # Also matches NIfTI, making final save more efficient
469
+ resampled = np .zeros (
470
+ (xcoords .T .shape [0 ], ) + spatialimage .shape [- 1 :], dtype = output_dtype , order = "F"
478
471
)
479
- elif data .ndim in (2 , 3 ):
472
+
473
+ for t in range (spatialimage .shape [- 1 ]):
474
+ # Map the input coordinates on to timepoint t of the target (moving)
475
+ ycoords = Affine (self .matrix [t ]).map (xcoords .T )[..., : _ref .ndim ]
476
+
477
+ # Calculate corresponding voxel coordinates
478
+ yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
479
+
480
+ # Interpolate
481
+ resampled [..., t ] = ndi .map_coordinates (
482
+ spatialimage .dataobj [..., t ].astype (input_dtype , copy = False ),
483
+ yvoxels .T ,
484
+ output = output_dtype ,
485
+ order = order ,
486
+ mode = mode ,
487
+ cval = cval ,
488
+ prefilter = prefilter ,
489
+ )
490
+ elif spatialimage .ndim in (2 , 3 ):
491
+ ycoords = self .map (xcoords .T )[..., : _ref .ndim ]
492
+ yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
493
+
480
494
resampled = ndi .map_coordinates (
481
- data ,
482
- targets [..., : _ref . ndim ] .T ,
495
+ spatialimage . dataobj . astype ( input_dtype , copy = False ) ,
496
+ yvoxels .T ,
483
497
output = output_dtype ,
484
498
order = order ,
485
499
mode = mode ,
0 commit comments