Skip to content

Commit 9772710

Browse files
committed
fix: shape and order of resampled array
1 parent 2c36d08 commit 9772710

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

nitransforms/linear.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,11 @@ def __init__(self, transforms, reference=None):
317317
)
318318
self._inverse = np.linalg.inv(self._matrix)
319319

320+
def __iter__(self):
321+
"""Enable iterating over the series of transforms."""
322+
for _m in self.matrix:
323+
yield Affine(_m, reference=self._reference)
324+
320325
def __getitem__(self, i):
321326
"""Enable indexed access to the series of matrices."""
322327
return Affine(self.matrix[i, ...], reference=self._reference)
@@ -458,42 +463,37 @@ def apply(
458463
# Invert target's (moving) affine once
459464
ras2vox = ~Affine(spatialimage.affine)
460465

461-
if spatialimage.ndim == 4:
462-
if len(self) != spatialimage.shape[-1]:
463-
raise ValueError(
464-
"Attempting to apply %d transforms on a file with "
465-
"%d timepoints" % (len(self), spatialimage.shape[-1])
466-
)
467-
468-
# Order F ensures individual volumes are contiguous in memory
469-
# Also matches NIfTI, making final save more efficient
470-
resampled = np.zeros(
471-
(xcoords.T.shape[0], ) + spatialimage.shape[-1:], dtype=output_dtype, order="F"
466+
if spatialimage.ndim == 4 and (len(self) != spatialimage.shape[-1]):
467+
raise ValueError(
468+
"Attempting to apply %d transforms on a file with "
469+
"%d timepoints" % (len(self), spatialimage.shape[-1])
472470
)
473471

474-
for t in range(spatialimage.shape[-1]):
475-
# Map the input coordinates on to timepoint t of the target (moving)
476-
ycoords = Affine(self.matrix[t]).map(xcoords.T)[..., : _ref.ndim]
477-
478-
# Calculate corresponding voxel coordinates
479-
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
480-
481-
# Interpolate
482-
resampled[..., t] = ndi.map_coordinates(
483-
spatialimage.dataobj[..., t].astype(input_dtype, copy=False),
484-
yvoxels.T,
485-
output=output_dtype,
486-
order=order,
487-
mode=mode,
488-
cval=cval,
489-
prefilter=prefilter,
490-
)
491-
elif spatialimage.ndim in (2, 3):
492-
ycoords = self.map(xcoords.T)[..., : _ref.ndim]
472+
# Order F ensures individual volumes are contiguous in memory
473+
# Also matches NIfTI, making final save more efficient
474+
resampled = np.zeros(
475+
(xcoords.T.shape[0], len(self)), dtype=output_dtype, order="F"
476+
)
477+
478+
dataobj = (
479+
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
480+
if spatialimage.ndim in (2, 3)
481+
else None
482+
)
483+
484+
for t, xfm_t in enumerate(self):
485+
# Map the input coordinates on to timepoint t of the target (moving)
486+
ycoords = xfm_t.map(xcoords.T)[..., : _ref.ndim]
487+
488+
# Calculate corresponding voxel coordinates
493489
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
494490

495-
resampled = ndi.map_coordinates(
496-
spatialimage.dataobj.astype(input_dtype, copy=False),
491+
# Interpolate
492+
resampled[..., t] = ndi.map_coordinates(
493+
(
494+
dataobj if dataobj is not None
495+
else np.asanyarray(spatialimage.dataobj[..., t], dtype=input_dtype)
496+
),
497497
yvoxels.T,
498498
output=output_dtype,
499499
order=order,
@@ -503,9 +503,9 @@ def apply(
503503
)
504504

505505
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
506-
newdata = resampled.reshape((len(self), *_ref.shape))
506+
newdata = resampled.reshape(_ref.shape + (len(self), ))
507507
moved = spatialimage.__class__(
508-
np.moveaxis(newdata, 0, -1), _ref.affine, spatialimage.header
508+
newdata, _ref.affine, spatialimage.header
509509
)
510510
moved.header.set_data_dtype(output_dtype)
511511
return moved

0 commit comments

Comments
 (0)