Skip to content

Commit 6064b8c

Browse files
Julien MarabottoJulien Marabotto
authored andcommitted
enh: draft implementation of serialize 4d
1 parent b922fa5 commit 6064b8c

File tree

1 file changed

+60
-27
lines changed

1 file changed

+60
-27
lines changed

nitransforms/resampling.py

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Resampling utilities."""
10-
from warnings import warn
1110
from pathlib import Path
1211
import numpy as np
1312
from nibabel.loadsave import load as _nbload
13+
from nibabel.arrayproxy import get_obj_dtype
1414
from scipy import ndimage as ndi
1515

16+
from nitransforms.linear import Affine, get
1617
from nitransforms.base import (
1718
ImageGrid,
1819
TransformError,
@@ -96,45 +97,77 @@ def apply(
9697

9798
data = np.asanyarray(spatialimage.dataobj)
9899
data_nvols = 1 if data.ndim < 4 else data.shape[-1]
99-
xfm_nvols = len(transforms)
100+
xfm_nvols = len(transform)
101+
assert xfm_nvols == transform.ndim == _ref.ndim
100102

101103
if data_nvols == 1 and xfm_nvols > 1:
102104
data = data[..., np.newaxis]
103105
elif data_nvols != xfm_nvols:
104106
raise ValueError(
105-
"The fourth dimension of the data does not match the tranform's shape."
107+
"The fourth dimension of the data does not match the transform's shape."
106108
)
107109

108110
serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf
109111
serialize_4d = max(data_nvols, xfm_nvols) > serialize_nvols
110112
if serialize_4d:
111-
warn("4D transforms serialization into 3D+t not implemented")
112-
113-
# For model-based nonlinear transforms, generate the corresponding dense field
114-
if hasattr(transform, "to_field") and callable(transform.to_field):
115-
targets = ImageGrid(spatialimage).index(
116-
_as_homogeneous(
117-
transform.to_field(reference=reference).map(_ref.ndcoords.T),
118-
dim=_ref.ndim,
113+
for t, xfm_t in enumerate(transform):
114+
ras2vox = ~Affine(spatialimage.affine)
115+
input_dtype = get_obj_dtype(spatialimage.dataobj)
116+
output_dtype = output_dtype or input_dtype
117+
118+
# Map the input coordinates on to timepoint t of the target (moving)
119+
xcoords = _ref.ndcoords.astype("f4").T
120+
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]
121+
122+
# Calculate corresponding voxel coordinates
123+
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
124+
125+
# Interpolate
126+
dataobj = (
127+
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
128+
if spatialimage.ndim in (2, 3)
129+
else None
119130
)
120-
)
131+
resampled[..., t] = ndi.map_coordinates(
132+
(
133+
dataobj
134+
if dataobj is not None
135+
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
136+
),
137+
yvoxels.T,
138+
output=output_dtype,
139+
order=order,
140+
mode=mode,
141+
cval=cval,
142+
prefilter=prefilter,
143+
)
144+
121145
else:
122-
targets = ImageGrid(spatialimage).index( # data should be an image
123-
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
124-
)
146+
# For model-based nonlinear transforms, generate the corresponding dense field
147+
if hasattr(transform, "to_field") and callable(transform.to_field):
148+
targets = ImageGrid(spatialimage).index(
149+
_as_homogeneous(
150+
transform.to_field(reference=reference).map(_ref.ndcoords.T),
151+
dim=_ref.ndim,
152+
)
153+
)
154+
else:
155+
targets = ImageGrid(spatialimage).index( # data should be an image
156+
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
157+
)
125158

126-
if transform.ndim == 4:
127-
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
128-
129-
resampled = ndi.map_coordinates(
130-
data,
131-
targets,
132-
output=output_dtype,
133-
order=order,
134-
mode=mode,
135-
cval=cval,
136-
prefilter=prefilter,
137-
)
159+
if transform.ndim == 4:
160+
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
161+
162+
resampled = ndi.map_coordinates(
163+
data,
164+
targets,
165+
output=output_dtype,
166+
order=order,
167+
mode=mode,
168+
cval=cval,
169+
prefilter=prefilter,
170+
)
138171

139172
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
140173
hdr = None

0 commit comments

Comments
 (0)