Skip to content

Commit 28737f4

Browse files
authored
Merge pull request #187 from nipy/fix/memory-issues-173
FIX: Postpone coordinate mapping on linear array transforms
2 parents 6e70c02 + d148e85 commit 28737f4

File tree

1 file changed

+51
-38
lines changed

1 file changed

+51
-38
lines changed

nitransforms/linear.py

Lines changed: 51 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from nibabel.loadsave import load as _nbload
1616
from nibabel.affines import from_matvec
17+
from nibabel.arrayproxy import get_obj_dtype
1718

1819
from nitransforms.base import (
1920
ImageGrid,
@@ -216,14 +217,13 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
216217
is_array = cls != Affine
217218
errors = []
218219
for potential_fmt in fmtlist:
219-
if (potential_fmt == "itk" and Path(filename).suffix == ".mat"):
220+
if potential_fmt == "itk" and Path(filename).suffix == ".mat":
220221
is_array = False
221222
cls = Affine
222223

223224
try:
224225
struct = get_linear_factory(
225-
potential_fmt,
226-
is_array=is_array
226+
potential_fmt, is_array=is_array
227227
).from_filename(filename)
228228
except (TransformFileError, FileNotFoundError) as err:
229229
errors.append((potential_fmt, err))
@@ -316,6 +316,11 @@ def __init__(self, transforms, reference=None):
316316
)
317317
self._inverse = np.linalg.inv(self._matrix)
318318

319+
def __iter__(self):
320+
"""Enable iterating over the series of transforms."""
321+
for _m in self.matrix:
322+
yield Affine(_m, reference=self._reference)
323+
319324
def __getitem__(self, i):
320325
"""Enable indexed access to the series of matrices."""
321326
return Affine(self.matrix[i, ...], reference=self._reference)
@@ -436,6 +441,7 @@ def apply(
436441
The data imaged after resampling to reference space.
437442
438443
"""
444+
439445
if reference is not None and isinstance(reference, (str, Path)):
440446
reference = _nbload(str(reference))
441447

@@ -446,40 +452,49 @@ def apply(
446452
if isinstance(spatialimage, (str, Path)):
447453
spatialimage = _nbload(str(spatialimage))
448454

449-
data = np.squeeze(np.asanyarray(spatialimage.dataobj))
450-
output_dtype = output_dtype or data.dtype
455+
# Avoid opening the data array just yet
456+
input_dtype = get_obj_dtype(spatialimage.dataobj)
457+
output_dtype = output_dtype or input_dtype
451458

452-
ycoords = self.map(_ref.ndcoords.T)
453-
targets = ImageGrid(spatialimage).index( # data should be an image
454-
_as_homogeneous(np.vstack(ycoords), dim=_ref.ndim)
455-
)
459+
# Prepare physical coordinates of input (grid, points)
460+
xcoords = _ref.ndcoords.astype("f4").T
456461

457-
if data.ndim == 4:
458-
if len(self) != data.shape[-1]:
459-
raise ValueError(
460-
"Attempting to apply %d transforms on a file with "
461-
"%d timepoints" % (len(self), data.shape[-1])
462-
)
463-
targets = targets.reshape((len(self), -1, targets.shape[-1]))
464-
resampled = np.stack(
465-
[
466-
ndi.map_coordinates(
467-
data[..., t],
468-
targets[t, ..., : _ref.ndim].T,
469-
output=output_dtype,
470-
order=order,
471-
mode=mode,
472-
cval=cval,
473-
prefilter=prefilter,
474-
)
475-
for t in range(data.shape[-1])
476-
],
477-
axis=0,
462+
# Invert target's (moving) affine once
463+
ras2vox = ~Affine(spatialimage.affine)
464+
465+
if spatialimage.ndim == 4 and (len(self) != spatialimage.shape[-1]):
466+
raise ValueError(
467+
"Attempting to apply %d transforms on a file with "
468+
"%d timepoints" % (len(self), spatialimage.shape[-1])
478469
)
479-
elif data.ndim in (2, 3):
480-
resampled = ndi.map_coordinates(
481-
data,
482-
targets[..., : _ref.ndim].T,
470+
471+
# Order F ensures individual volumes are contiguous in memory
472+
# Also matches NIfTI, making final save more efficient
473+
resampled = np.zeros(
474+
(xcoords.shape[0], len(self)), dtype=output_dtype, order="F"
475+
)
476+
477+
dataobj = (
478+
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
479+
if spatialimage.ndim in (2, 3)
480+
else None
481+
)
482+
483+
for t, xfm_t in enumerate(self):
484+
# Map the input coordinates on to timepoint t of the target (moving)
485+
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]
486+
487+
# Calculate corresponding voxel coordinates
488+
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
489+
490+
# Interpolate
491+
resampled[..., t] = ndi.map_coordinates(
492+
(
493+
dataobj
494+
if dataobj is not None
495+
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
496+
),
497+
yvoxels.T,
483498
output=output_dtype,
484499
order=order,
485500
mode=mode,
@@ -488,10 +503,8 @@ def apply(
488503
)
489504

490505
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
491-
newdata = resampled.reshape((len(self), *_ref.shape))
492-
moved = spatialimage.__class__(
493-
np.moveaxis(newdata, 0, -1), _ref.affine, spatialimage.header
494-
)
506+
newdata = resampled.reshape(_ref.shape + (len(self),))
507+
moved = spatialimage.__class__(newdata, _ref.affine, spatialimage.header)
495508
moved.header.set_data_dtype(output_dtype)
496509
return moved
497510

0 commit comments

Comments
 (0)