Skip to content

Commit db1b250

Browse files
committed
fix: postpone coordinate mapping on linear array transforms
Resolves: #173.
1 parent 6e70c02 commit db1b250

File tree

1 file changed

+41
-27
lines changed

1 file changed

+41
-27
lines changed

nitransforms/linear.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ def apply(
436436
The data imaged after resampling to reference space.
437437
438438
"""
439+
439440
if reference is not None and isinstance(reference, (str, Path)):
440441
reference = _nbload(str(reference))
441442

@@ -446,40 +447,53 @@ def apply(
446447
if isinstance(spatialimage, (str, Path)):
447448
spatialimage = _nbload(str(spatialimage))
448449

449-
data = np.squeeze(np.asanyarray(spatialimage.dataobj))
450-
output_dtype = output_dtype or data.dtype
450+
# Avoid opening the data array just yet
451+
input_dtype = spatialimage.header.get_data_dtype()
452+
output_dtype = output_dtype or input_dtype
451453

452-
ycoords = self.map(_ref.ndcoords.T)
453-
targets = ImageGrid(spatialimage).index( # data should be an image
454-
_as_homogeneous(np.vstack(ycoords), dim=_ref.ndim)
455-
)
454+
# Prepare physical coordinates of input (grid, points)
455+
xcoords = _ref.ndcoords.astype("f4")
456456

457-
if data.ndim == 4:
458-
if len(self) != data.shape[-1]:
457+
# Invert target's (moving) affine once
458+
ras2vox = ~Affine(spatialimage.affine)
459+
460+
if spatialimage.ndim == 4:
461+
if len(self) != spatialimage.shape[-1]:
459462
raise ValueError(
460463
"Attempting to apply %d transforms on a file with "
461-
"%d timepoints" % (len(self), data.shape[-1])
464+
"%d timepoints" % (len(self), spatialimage.shape[-1])
462465
)
463-
targets = targets.reshape((len(self), -1, targets.shape[-1]))
464-
resampled = np.stack(
465-
[
466-
ndi.map_coordinates(
467-
data[..., t],
468-
targets[t, ..., : _ref.ndim].T,
469-
output=output_dtype,
470-
order=order,
471-
mode=mode,
472-
cval=cval,
473-
prefilter=prefilter,
474-
)
475-
for t in range(data.shape[-1])
476-
],
477-
axis=0,
466+
467+
# Order F ensures individual volumes are contiguous in memory
468+
# Also matches NIfTI, making final save more efficient
469+
resampled = np.zeros(
470+
(xcoords.T.shape[0], ) + spatialimage.shape[-1:], dtype=output_dtype, order="F"
478471
)
479-
elif data.ndim in (2, 3):
472+
473+
for t in range(spatialimage.shape[-1]):
474+
# Map the input coordinates on to timepoint t of the target (moving)
475+
ycoords = Affine(self.matrix[t]).map(xcoords.T)[..., : _ref.ndim]
476+
477+
# Calculate corresponding voxel coordinates
478+
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
479+
480+
# Interpolate
481+
resampled[..., t] = ndi.map_coordinates(
482+
spatialimage.dataobj[..., t].astype(input_dtype, copy=False),
483+
yvoxels.T,
484+
output=output_dtype,
485+
order=order,
486+
mode=mode,
487+
cval=cval,
488+
prefilter=prefilter,
489+
)
490+
elif spatialimage.ndim in (2, 3):
491+
ycoords = self.map(xcoords.T)[..., : _ref.ndim]
492+
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
493+
480494
resampled = ndi.map_coordinates(
481-
data,
482-
targets[..., : _ref.ndim].T,
495+
spatialimage.dataobj.astype(input_dtype, copy=False),
496+
yvoxels.T,
483497
output=output_dtype,
484498
order=order,
485499
mode=mode,

0 commit comments

Comments
 (0)