14
14
15
15
from nibabel .loadsave import load as _nbload
16
16
from nibabel .affines import from_matvec
17
+ from nibabel .arrayproxy import get_obj_dtype
17
18
18
19
from nitransforms .base import (
19
20
ImageGrid ,
@@ -216,14 +217,13 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
216
217
is_array = cls != Affine
217
218
errors = []
218
219
for potential_fmt in fmtlist :
219
- if ( potential_fmt == "itk" and Path (filename ).suffix == ".mat" ) :
220
+ if potential_fmt == "itk" and Path (filename ).suffix == ".mat" :
220
221
is_array = False
221
222
cls = Affine
222
223
223
224
try :
224
225
struct = get_linear_factory (
225
- potential_fmt ,
226
- is_array = is_array
226
+ potential_fmt , is_array = is_array
227
227
).from_filename (filename )
228
228
except (TransformFileError , FileNotFoundError ) as err :
229
229
errors .append ((potential_fmt , err ))
@@ -316,6 +316,11 @@ def __init__(self, transforms, reference=None):
316
316
)
317
317
self ._inverse = np .linalg .inv (self ._matrix )
318
318
319
+ def __iter__ (self ):
320
+ """Enable iterating over the series of transforms."""
321
+ for _m in self .matrix :
322
+ yield Affine (_m , reference = self ._reference )
323
+
319
324
def __getitem__ (self , i ):
320
325
"""Enable indexed access to the series of matrices."""
321
326
return Affine (self .matrix [i , ...], reference = self ._reference )
@@ -436,6 +441,7 @@ def apply(
436
441
The data imaged after resampling to reference space.
437
442
438
443
"""
444
+
439
445
if reference is not None and isinstance (reference , (str , Path )):
440
446
reference = _nbload (str (reference ))
441
447
@@ -446,40 +452,49 @@ def apply(
446
452
if isinstance (spatialimage , (str , Path )):
447
453
spatialimage = _nbload (str (spatialimage ))
448
454
449
- data = np .squeeze (np .asanyarray (spatialimage .dataobj ))
450
- output_dtype = output_dtype or data .dtype
455
+ # Avoid opening the data array just yet
456
+ input_dtype = get_obj_dtype (spatialimage .dataobj )
457
+ output_dtype = output_dtype or input_dtype
451
458
452
- ycoords = self .map (_ref .ndcoords .T )
453
- targets = ImageGrid (spatialimage ).index ( # data should be an image
454
- _as_homogeneous (np .vstack (ycoords ), dim = _ref .ndim )
455
- )
459
+ # Prepare physical coordinates of input (grid, points)
460
+ xcoords = _ref .ndcoords .astype ("f4" ).T
456
461
457
- if data .ndim == 4 :
458
- if len (self ) != data .shape [- 1 ]:
459
- raise ValueError (
460
- "Attempting to apply %d transforms on a file with "
461
- "%d timepoints" % (len (self ), data .shape [- 1 ])
462
- )
463
- targets = targets .reshape ((len (self ), - 1 , targets .shape [- 1 ]))
464
- resampled = np .stack (
465
- [
466
- ndi .map_coordinates (
467
- data [..., t ],
468
- targets [t , ..., : _ref .ndim ].T ,
469
- output = output_dtype ,
470
- order = order ,
471
- mode = mode ,
472
- cval = cval ,
473
- prefilter = prefilter ,
474
- )
475
- for t in range (data .shape [- 1 ])
476
- ],
477
- axis = 0 ,
462
+ # Invert target's (moving) affine once
463
+ ras2vox = ~ Affine (spatialimage .affine )
464
+
465
+ if spatialimage .ndim == 4 and (len (self ) != spatialimage .shape [- 1 ]):
466
+ raise ValueError (
467
+ "Attempting to apply %d transforms on a file with "
468
+ "%d timepoints" % (len (self ), spatialimage .shape [- 1 ])
478
469
)
479
- elif data .ndim in (2 , 3 ):
480
- resampled = ndi .map_coordinates (
481
- data ,
482
- targets [..., : _ref .ndim ].T ,
470
+
471
+ # Order F ensures individual volumes are contiguous in memory
472
+ # Also matches NIfTI, making final save more efficient
473
+ resampled = np .zeros (
474
+ (xcoords .shape [0 ], len (self )), dtype = output_dtype , order = "F"
475
+ )
476
+
477
+ dataobj = (
478
+ np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
479
+ if spatialimage .ndim in (2 , 3 )
480
+ else None
481
+ )
482
+
483
+ for t , xfm_t in enumerate (self ):
484
+ # Map the input coordinates on to timepoint t of the target (moving)
485
+ ycoords = xfm_t .map (xcoords )[..., : _ref .ndim ]
486
+
487
+ # Calculate corresponding voxel coordinates
488
+ yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
489
+
490
+ # Interpolate
491
+ resampled [..., t ] = ndi .map_coordinates (
492
+ (
493
+ dataobj
494
+ if dataobj is not None
495
+ else spatialimage .dataobj [..., t ].astype (input_dtype , copy = False )
496
+ ),
497
+ yvoxels .T ,
483
498
output = output_dtype ,
484
499
order = order ,
485
500
mode = mode ,
@@ -488,10 +503,8 @@ def apply(
488
503
)
489
504
490
505
if isinstance (_ref , ImageGrid ): # If reference is grid, reshape
491
- newdata = resampled .reshape ((len (self ), * _ref .shape ))
492
- moved = spatialimage .__class__ (
493
- np .moveaxis (newdata , 0 , - 1 ), _ref .affine , spatialimage .header
494
- )
506
+ newdata = resampled .reshape (_ref .shape + (len (self ),))
507
+ moved = spatialimage .__class__ (newdata , _ref .affine , spatialimage .header )
495
508
moved .header .set_data_dtype (output_dtype )
496
509
return moved
497
510
0 commit comments