Skip to content

Commit 79e5cad

Browse files
committed
enh: integrating @jmarabotto's code
1 parent 23daabb commit 79e5cad

File tree

1 file changed

+45
-39
lines changed

1 file changed

+45
-39
lines changed

nitransforms/resampling.py

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,21 @@
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Resampling utilities."""
10-
from warnings import warn
10+
1111
from pathlib import Path
1212
import numpy as np
1313
from nibabel.loadsave import load as _nbload
1414
from nibabel.arrayproxy import get_obj_dtype
1515
from scipy import ndimage as ndi
1616

17-
from nitransforms.linear import Affine, LinearTransformsMapping
1817
from nitransforms.base import (
1918
ImageGrid,
2019
TransformError,
2120
SpatialReference,
2221
_as_homogeneous,
2322
)
2423

25-
SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8
24+
SERIALIZE_VOLUME_WINDOW_WIDTH: int = 8
2625
"""Minimum number of volumes to automatically serialize 4D transforms."""
2726

2827

@@ -96,58 +95,67 @@ def apply(
9695
if isinstance(spatialimage, (str, Path)):
9796
spatialimage = _nbload(str(spatialimage))
9897

99-
data = np.asanyarray(spatialimage.dataobj)
100-
data_nvols = 1 if data.ndim < 4 else data.shape[-1]
98+
# Avoid opening the data array just yet
99+
input_dtype = get_obj_dtype(spatialimage.dataobj)
100+
output_dtype = output_dtype or input_dtype
101101

102+
# Number of transformations
103+
data_nvols = 1 if spatialimage.ndim < 4 else spatialimage.shape[-1]
102104
xfm_nvols = len(transform)
103105

104-
if data_nvols == 1 and xfm_nvols > 1:
105-
data = data[..., np.newaxis]
106-
elif data_nvols != xfm_nvols:
106+
if data_nvols != xfm_nvols and min(data_nvols, xfm_nvols) > 1:
107107
raise ValueError(
108108
"The fourth dimension of the data does not match the transform's shape."
109109
)
110110

111-
serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf
112-
serialize_4d = max(data_nvols, xfm_nvols) >= serialize_nvols
111+
serialize_nvols = (
112+
serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf
113+
)
114+
n_resamplings = max(data_nvols, xfm_nvols)
115+
serialize_4d = n_resamplings >= serialize_nvols
116+
117+
targets = None
118+
if hasattr(transform, "to_field") and callable(transform.to_field):
119+
targets = ImageGrid(spatialimage).index(
120+
_as_homogeneous(
121+
transform.to_field(reference=reference).map(_ref.ndcoords.T),
122+
dim=_ref.ndim,
123+
)
124+
)
125+
elif xfm_nvols == 1:
126+
targets = ImageGrid(spatialimage).index( # data should be an image
127+
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
128+
)
113129

114130
if serialize_4d:
115-
# Avoid opening the data array just yet
116-
input_dtype = get_obj_dtype(spatialimage.dataobj)
117-
output_dtype = output_dtype or input_dtype
118-
119-
# Prepare physical coordinates of input (grid, points)
120-
xcoords = _ref.ndcoords.astype("f4").T
121-
122-
# Invert target's (moving) affine once
123-
ras2vox = ~Affine(spatialimage.affine)
124-
dataobj = (
131+
data = (
125132
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
126-
if spatialimage.ndim in (2, 3)
133+
if data_nvols == 1
127134
else None
128135
)
129136

130137
# Order F ensures individual volumes are contiguous in memory
131138
# Also matches NIfTI, making final save more efficient
132139
resampled = np.zeros(
133-
(xcoords.shape[0], len(transform)), dtype=output_dtype, order="F"
140+
(spatialimage.size, len(transform)), dtype=output_dtype, order="F"
134141
)
135142

136-
for t, xfm_t in enumerate(transform):
137-
# Map the input coordinates on to timepoint t of the target (moving)
138-
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]
143+
for t in range(n_resamplings):
144+
xfm_t = transform if n_resamplings == 1 else transform[t]
139145

140-
# Calculate corresponding voxel coordinates
141-
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
146+
if targets is None:
147+
targets = ImageGrid(spatialimage).index( # data should be an image
148+
_as_homogeneous(xfm_t.map(_ref.ndcoords.T), dim=_ref.ndim)
149+
)
142150

143151
# Interpolate
144152
resampled[..., t] = ndi.map_coordinates(
145153
(
146-
dataobj
147-
if dataobj is not None
154+
data
155+
if data is not None
148156
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
149157
),
150-
yvoxels.T,
158+
targets,
151159
output=output_dtype,
152160
order=order,
153161
mode=mode,
@@ -156,19 +164,17 @@ def apply(
156164
)
157165

158166
else:
159-
# For model-based nonlinear transforms, generate the corresponding dense field
160-
if hasattr(transform, "to_field") and callable(transform.to_field):
161-
targets = ImageGrid(spatialimage).index(
162-
_as_homogeneous(
163-
transform.to_field(reference=reference).map(_ref.ndcoords.T),
164-
dim=_ref.ndim,
165-
)
166-
)
167-
else:
167+
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
168+
169+
if targets is None:
168170
targets = ImageGrid(spatialimage).index( # data should be an image
169171
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
170172
)
171173

174+
# Cast 3D data into 4D if 4D nonsequential transform
175+
if data_nvols == 1 and xfm_nvols > 1:
176+
data = data[..., np.newaxis]
177+
172178
if transform.ndim == 4:
173179
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
174180

0 commit comments

Comments
 (0)