diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 53750206..98ef4454 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -10,6 +10,7 @@ import asyncio from os import cpu_count +from contextlib import suppress from functools import partial from pathlib import Path from typing import Callable, TypeVar, Union @@ -108,12 +109,17 @@ async def _apply_serial( semaphore = asyncio.Semaphore(max_concurrent) for t in range(n_resamplings): - xfm_t = transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t] + xfm_t = ( + transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t] + ) - if targets is None: - targets = ImageGrid(spatialimage).index( # data should be an image + targets_t = ( + ImageGrid(spatialimage).index( _as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim) ) + if targets is None + else targets[t, ...] + ) data_t = ( data @@ -127,7 +133,7 @@ async def _apply_serial( partial( ndi.map_coordinates, data_t, - targets, + targets_t, output=output[..., t], order=order, mode=mode, @@ -255,11 +261,22 @@ def apply( dim=_ref.ndim, ) ) - elif xfm_nvols == 1: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) + else: + # Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints. + targets = ( + ImageGrid(spatialimage).index( + _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) + ) + if targets is None + else targets ) + if targets.ndim == 3: + targets = np.rollaxis(targets, targets.ndim - 1, 0) + else: + assert targets.ndim == 2 + targets = targets[np.newaxis, ...] + if serialize_4d: data = ( np.asanyarray(spatialimage.dataobj, dtype=input_dtype) @@ -294,17 +311,24 @@ def apply( else: data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype) - if targets is None: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) - ) - + if data_nvols == 1 and xfm_nvols == 1: + targets = np.squeeze(targets) + assert targets.ndim == 2 # Cast 3D data into 4D if 4D nonsequential transform - if data_nvols == 1 and xfm_nvols > 1: + elif data_nvols == 1 and xfm_nvols > 1: data = data[..., np.newaxis] - if transform.ndim == 4: - targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T + if xfm_nvols > 1: + assert targets.ndim == 3 + n_time, n_dim, n_vox = targets.shape + # Reshape to (3, n_time x n_vox) + ijk_targets = np.rollaxis(targets, 0, 2).reshape((n_dim, -1)) + time_row = np.repeat(np.arange(n_time), n_vox)[None, :] + + # Now targets is (4, n_vox x n_time), with indexes (t, i, j, k) + # t is the slowest-changing axis, so we put it first + targets = np.vstack((time_row, ijk_targets)) + data = np.rollaxis(data, data.ndim - 1, 0) resampled = ndi.map_coordinates( data, @@ -323,11 +347,19 @@ def apply( ) hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype()) - moved = spatialimage.__class__( - resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)), - _ref.affine, - hdr, - ) + if serialize_4d: + resampled = resampled.reshape( + _ref.shape + if n_resamplings == 1 + else _ref.shape + (resampled.shape[-1],) + ) + else: + resampled = resampled.reshape((-1, *_ref.shape)) + resampled = np.rollaxis(resampled, 0, resampled.ndim) + with suppress(ValueError): + resampled = np.squeeze(resampled, axis=3) + + moved = spatialimage.__class__(resampled, _ref.affine, hdr) return moved output_dtype = output_dtype or input_dtype diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py index 2384ad97..0e11df5b 100644 --- a/nitransforms/tests/test_resampling.py +++ b/nitransforms/tests/test_resampling.py @@ -363,3 +363,28 @@ def test_LinearTransformsMapping_apply( reference=testdata_path / "sbref.nii.gz", serialize_nvols=2 if serialize_4d else np.inf, ) + + +@pytest.mark.parametrize("serialize_4d", [True, False]) +def test_apply_4d(serialize_4d): + """Regression test for per-volume transforms with serialized resampling.""" + nvols = 9 + shape = (10, 5, 5) + base = np.zeros(shape, dtype=np.float32) + base[9, 2, 2] = 1 + img = nb.Nifti1Image(np.stack([base] * nvols, axis=-1), np.eye(4)) + + transforms = [] + for i in range(nvols): + mat = np.eye(4) + mat[0, 3] = i + transforms.append(nitl.Affine(mat)) + + extraparams = {} if serialize_4d else {"serialize_nvols": nvols + 1} + + xfm = nitl.LinearTransformsMapping(transforms, reference=img) + + moved = apply(xfm, img, order=0, **extraparams) + data = np.asanyarray(moved.dataobj) + idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)] + assert idxs == [(9 - i, 2, 2) for i in range(nvols)]