Skip to content

Commit f1efba1

Browse files
authored
Merge pull request #247 from nipy/codex/investigate-4d-dataset-resampling-issue
FIX: Broken 4D resampling
2 parents a94e577 + 4e159c2 commit f1efba1

File tree

2 files changed

+77
-20
lines changed

2 files changed

+77
-20
lines changed

nitransforms/resampling.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import asyncio
1212
from os import cpu_count
13+
from contextlib import suppress
1314
from functools import partial
1415
from pathlib import Path
1516
from typing import Callable, TypeVar, Union
@@ -108,12 +109,17 @@ async def _apply_serial(
108109
semaphore = asyncio.Semaphore(max_concurrent)
109110

110111
for t in range(n_resamplings):
111-
xfm_t = transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t]
112+
xfm_t = (
113+
transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t]
114+
)
112115

113-
if targets is None:
114-
targets = ImageGrid(spatialimage).index( # data should be an image
116+
targets_t = (
117+
ImageGrid(spatialimage).index(
115118
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim)
116119
)
120+
if targets is None
121+
else targets[t, ...]
122+
)
117123

118124
data_t = (
119125
data
@@ -127,7 +133,7 @@ async def _apply_serial(
127133
partial(
128134
ndi.map_coordinates,
129135
data_t,
130-
targets,
136+
targets_t,
131137
output=output[..., t],
132138
order=order,
133139
mode=mode,
@@ -255,11 +261,22 @@ def apply(
255261
dim=_ref.ndim,
256262
)
257263
)
258-
elif xfm_nvols == 1:
259-
targets = ImageGrid(spatialimage).index( # data should be an image
260-
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
264+
else:
265+
# Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
266+
targets = (
267+
ImageGrid(spatialimage).index(
268+
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
269+
)
270+
if targets is None
271+
else targets
261272
)
262273

274+
if targets.ndim == 3:
275+
targets = np.rollaxis(targets, targets.ndim - 1, 0)
276+
else:
277+
assert targets.ndim == 2
278+
targets = targets[np.newaxis, ...]
279+
263280
if serialize_4d:
264281
data = (
265282
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
@@ -294,17 +311,24 @@ def apply(
294311
else:
295312
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
296313

297-
if targets is None:
298-
targets = ImageGrid(spatialimage).index( # data should be an image
299-
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
300-
)
301-
314+
if data_nvols == 1 and xfm_nvols == 1:
315+
targets = np.squeeze(targets)
316+
assert targets.ndim == 2
302317
# Cast 3D data into 4D if 4D nonsequential transform
303-
if data_nvols == 1 and xfm_nvols > 1:
318+
elif data_nvols == 1 and xfm_nvols > 1:
304319
data = data[..., np.newaxis]
305320

306-
if transform.ndim == 4:
307-
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
321+
if xfm_nvols > 1:
322+
assert targets.ndim == 3
323+
n_time, n_dim, n_vox = targets.shape
324+
# Reshape to (3, n_time x n_vox)
325+
ijk_targets = np.rollaxis(targets, 0, 2).reshape((n_dim, -1))
326+
time_row = np.repeat(np.arange(n_time), n_vox)[None, :]
327+
328+
# Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
329+
# t is the slowest-changing axis, so we put it first
330+
targets = np.vstack((time_row, ijk_targets))
331+
data = np.rollaxis(data, data.ndim - 1, 0)
308332

309333
resampled = ndi.map_coordinates(
310334
data,
@@ -323,11 +347,19 @@ def apply(
323347
)
324348
hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype())
325349

326-
moved = spatialimage.__class__(
327-
resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)),
328-
_ref.affine,
329-
hdr,
330-
)
350+
if serialize_4d:
351+
resampled = resampled.reshape(
352+
_ref.shape
353+
if n_resamplings == 1
354+
else _ref.shape + (resampled.shape[-1],)
355+
)
356+
else:
357+
resampled = resampled.reshape((-1, *_ref.shape))
358+
resampled = np.rollaxis(resampled, 0, resampled.ndim)
359+
with suppress(ValueError):
360+
resampled = np.squeeze(resampled, axis=3)
361+
362+
moved = spatialimage.__class__(resampled, _ref.affine, hdr)
331363
return moved
332364

333365
output_dtype = output_dtype or input_dtype

nitransforms/tests/test_resampling.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,3 +363,28 @@ def test_LinearTransformsMapping_apply(
363363
reference=testdata_path / "sbref.nii.gz",
364364
serialize_nvols=2 if serialize_4d else np.inf,
365365
)
366+
367+
368+
@pytest.mark.parametrize("serialize_4d", [True, False])
369+
def test_apply_4d(serialize_4d):
370+
"""Regression test for per-volume transforms with serialized resampling."""
371+
nvols = 9
372+
shape = (10, 5, 5)
373+
base = np.zeros(shape, dtype=np.float32)
374+
base[9, 2, 2] = 1
375+
img = nb.Nifti1Image(np.stack([base] * nvols, axis=-1), np.eye(4))
376+
377+
transforms = []
378+
for i in range(nvols):
379+
mat = np.eye(4)
380+
mat[0, 3] = i
381+
transforms.append(nitl.Affine(mat))
382+
383+
extraparams = {} if serialize_4d else {"serialize_nvols": nvols + 1}
384+
385+
xfm = nitl.LinearTransformsMapping(transforms, reference=img)
386+
387+
moved = apply(xfm, img, order=0, **extraparams)
388+
data = np.asanyarray(moved.dataobj)
389+
idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)]
390+
assert idxs == [(9 - i, 2, 2) for i in range(nvols)]

0 commit comments

Comments
 (0)