Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 52 additions & 20 deletions nitransforms/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import asyncio
from os import cpu_count
from contextlib import suppress
from functools import partial
from pathlib import Path
from typing import Callable, TypeVar, Union
Expand Down Expand Up @@ -108,12 +109,17 @@ async def _apply_serial(
semaphore = asyncio.Semaphore(max_concurrent)

for t in range(n_resamplings):
xfm_t = transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t]
xfm_t = (
transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t]
)

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
targets_t = (
ImageGrid(spatialimage).index(
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim)
)
if targets is None
else targets[t, ...]
)

data_t = (
data
Expand All @@ -127,7 +133,7 @@ async def _apply_serial(
partial(
ndi.map_coordinates,
data_t,
targets,
targets_t,
output=output[..., t],
order=order,
mode=mode,
Expand Down Expand Up @@ -255,11 +261,22 @@ def apply(
dim=_ref.ndim,
)
)
elif xfm_nvols == 1:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
else:
# Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
targets = (
ImageGrid(spatialimage).index(
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
)
if targets is None
else targets
)

if targets.ndim == 3:
targets = np.rollaxis(targets, targets.ndim - 1, 0)
else:
assert targets.ndim == 2
targets = targets[np.newaxis, ...]

if serialize_4d:
data = (
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
Expand Down Expand Up @@ -294,17 +311,24 @@ def apply(
else:
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
)

if data_nvols == 1 and xfm_nvols == 1:
targets = np.squeeze(targets)
assert targets.ndim == 2
# Cast 3D data into 4D if 4D nonsequential transform
if data_nvols == 1 and xfm_nvols > 1:
elif data_nvols == 1 and xfm_nvols > 1:
data = data[..., np.newaxis]

if transform.ndim == 4:
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
if xfm_nvols > 1:
assert targets.ndim == 3
n_time, n_dim, n_vox = targets.shape
# Reshape to (3, n_time x n_vox)
ijk_targets = np.rollaxis(targets, 0, 2).reshape((n_dim, -1))
time_row = np.repeat(np.arange(n_time), n_vox)[None, :]

# Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
# t is the slowest-changing axis, so we put it first
targets = np.vstack((time_row, ijk_targets))
data = np.rollaxis(data, data.ndim - 1, 0)

resampled = ndi.map_coordinates(
data,
Expand All @@ -323,11 +347,19 @@ def apply(
)
hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype())

moved = spatialimage.__class__(
resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)),
_ref.affine,
hdr,
)
if serialize_4d:
resampled = resampled.reshape(
_ref.shape
if n_resamplings == 1
else _ref.shape + (resampled.shape[-1],)
)
else:
resampled = resampled.reshape((-1, *_ref.shape))
resampled = np.rollaxis(resampled, 0, resampled.ndim)
with suppress(ValueError):
resampled = np.squeeze(resampled, axis=3)

moved = spatialimage.__class__(resampled, _ref.affine, hdr)
return moved

output_dtype = output_dtype or input_dtype
Expand Down
25 changes: 25 additions & 0 deletions nitransforms/tests/test_resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,28 @@ def test_LinearTransformsMapping_apply(
reference=testdata_path / "sbref.nii.gz",
serialize_nvols=2 if serialize_4d else np.inf,
)


@pytest.mark.parametrize("serialize_4d", [True, False])
def test_apply_4d(serialize_4d):
"""Regression test for per-volume transforms with serialized resampling."""
nvols = 9
shape = (10, 5, 5)
base = np.zeros(shape, dtype=np.float32)
base[9, 2, 2] = 1
img = nb.Nifti1Image(np.stack([base] * nvols, axis=-1), np.eye(4))

transforms = []
for i in range(nvols):
mat = np.eye(4)
mat[0, 3] = i
transforms.append(nitl.Affine(mat))

extraparams = {} if serialize_4d else {"serialize_nvols": nvols + 1}

xfm = nitl.LinearTransformsMapping(transforms, reference=img)

moved = apply(xfm, img, order=0, **extraparams)
data = np.asanyarray(moved.dataobj)
idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)]
assert idxs == [(9 - i, 2, 2) for i in range(nvols)]
Loading