@@ -317,6 +317,11 @@ def __init__(self, transforms, reference=None):
317
317
)
318
318
self ._inverse = np .linalg .inv (self ._matrix )
319
319
320
+ def __iter__ (self ):
321
+ """Enable iterating over the series of transforms."""
322
+ for _m in self .matrix :
323
+ yield Affine (_m , reference = self ._reference )
324
+
320
325
def __getitem__ (self , i ):
321
326
"""Enable indexed access to the series of matrices."""
322
327
return Affine (self .matrix [i , ...], reference = self ._reference )
@@ -458,42 +463,37 @@ def apply(
458
463
# Invert target's (moving) affine once
459
464
ras2vox = ~ Affine (spatialimage .affine )
460
465
461
- if spatialimage .ndim == 4 :
462
- if len (self ) != spatialimage .shape [- 1 ]:
463
- raise ValueError (
464
- "Attempting to apply %d transforms on a file with "
465
- "%d timepoints" % (len (self ), spatialimage .shape [- 1 ])
466
- )
467
-
468
- # Order F ensures individual volumes are contiguous in memory
469
- # Also matches NIfTI, making final save more efficient
470
- resampled = np .zeros (
471
- (xcoords .T .shape [0 ], ) + spatialimage .shape [- 1 :], dtype = output_dtype , order = "F"
466
+ if spatialimage .ndim == 4 and (len (self ) != spatialimage .shape [- 1 ]):
467
+ raise ValueError (
468
+ "Attempting to apply %d transforms on a file with "
469
+ "%d timepoints" % (len (self ), spatialimage .shape [- 1 ])
472
470
)
473
471
474
- for t in range (spatialimage .shape [- 1 ]):
475
- # Map the input coordinates on to timepoint t of the target (moving)
476
- ycoords = Affine (self .matrix [t ]).map (xcoords .T )[..., : _ref .ndim ]
477
-
478
- # Calculate corresponding voxel coordinates
479
- yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
480
-
481
- # Interpolate
482
- resampled [..., t ] = ndi .map_coordinates (
483
- spatialimage .dataobj [..., t ].astype (input_dtype , copy = False ),
484
- yvoxels .T ,
485
- output = output_dtype ,
486
- order = order ,
487
- mode = mode ,
488
- cval = cval ,
489
- prefilter = prefilter ,
490
- )
491
- elif spatialimage .ndim in (2 , 3 ):
492
- ycoords = self .map (xcoords .T )[..., : _ref .ndim ]
472
+ # Order F ensures individual volumes are contiguous in memory
473
+ # Also matches NIfTI, making final save more efficient
474
+ resampled = np .zeros (
475
+ (xcoords .T .shape [0 ], len (self )), dtype = output_dtype , order = "F"
476
+ )
477
+
478
+ dataobj = (
479
+ np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
480
+ if spatialimage .ndim in (2 , 3 )
481
+ else None
482
+ )
483
+
484
+ for t , xfm_t in enumerate (self ):
485
+ # Map the input coordinates on to timepoint t of the target (moving)
486
+ ycoords = xfm_t .map (xcoords .T )[..., : _ref .ndim ]
487
+
488
+ # Calculate corresponding voxel coordinates
493
489
yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
494
490
495
- resampled = ndi .map_coordinates (
496
- spatialimage .dataobj .astype (input_dtype , copy = False ),
491
+ # Interpolate
492
+ resampled [..., t ] = ndi .map_coordinates (
493
+ (
494
+ dataobj if dataobj is not None
495
+ else np .asanyarray (spatialimage .dataobj [..., t ], dtype = input_dtype )
496
+ ),
497
497
yvoxels .T ,
498
498
output = output_dtype ,
499
499
order = order ,
@@ -503,9 +503,9 @@ def apply(
503
503
)
504
504
505
505
if isinstance (_ref , ImageGrid ): # If reference is grid, reshape
506
- newdata = resampled .reshape ((len (self ), * _ref . shape ))
506
+ newdata = resampled .reshape (_ref . shape + (len (self ), ))
507
507
moved = spatialimage .__class__ (
508
- np . moveaxis ( newdata , 0 , - 1 ) , _ref .affine , spatialimage .header
508
+ newdata , _ref .affine , spatialimage .header
509
509
)
510
510
moved .header .set_data_dtype (output_dtype )
511
511
return moved
0 commit comments