Skip to content

Commit fbb0451

Browse files
committed
fix: resolve some failing tests
1 parent e0bde09 commit fbb0451

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

nitransforms/nonlinear.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ def __repr__(self):
9494
"""Beautify the python representation."""
9595
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
9696

97+
def __len__(self):
98+
"""Enable len() -- for compatibility, only len == 1 is supported."""
99+
return 1
100+
97101
@property
98102
def ndim(self):
99103
"""Get the dimensions of the transform."""

nitransforms/resampling.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def apply(
9797

9898
# Avoid opening the data array just yet
9999
input_dtype = get_obj_dtype(spatialimage.dataobj)
100-
output_dtype = output_dtype or input_dtype
101100

102101
# Number of transformations
103102
data_nvols = 1 if spatialimage.ndim < 4 else spatialimage.shape[-1]
@@ -115,16 +114,17 @@ def apply(
115114
serialize_4d = n_resamplings >= serialize_nvols
116115

117116
targets = None
117+
ref_ndcoords = _ref.ndcoords.T
118118
if hasattr(transform, "to_field") and callable(transform.to_field):
119119
targets = ImageGrid(spatialimage).index(
120120
_as_homogeneous(
121-
transform.to_field(reference=reference).map(_ref.ndcoords.T),
121+
transform.to_field(reference=reference).map(ref_ndcoords),
122122
dim=_ref.ndim,
123123
)
124124
)
125125
elif xfm_nvols == 1:
126126
targets = ImageGrid(spatialimage).index( # data should be an image
127-
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
127+
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
128128
)
129129

130130
if serialize_4d:
@@ -137,15 +137,15 @@ def apply(
137137
# Order F ensures individual volumes are contiguous in memory
138138
# Also matches NIfTI, making final save more efficient
139139
resampled = np.zeros(
140-
(spatialimage.size, len(transform)), dtype=output_dtype, order="F"
140+
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
141141
)
142142

143143
for t in range(n_resamplings):
144144
xfm_t = transform if n_resamplings == 1 else transform[t]
145145

146146
if targets is None:
147147
targets = ImageGrid(spatialimage).index( # data should be an image
148-
_as_homogeneous(xfm_t.map(_ref.ndcoords.T), dim=_ref.ndim)
148+
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
149149
)
150150

151151
# Interpolate
@@ -156,7 +156,6 @@ def apply(
156156
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
157157
),
158158
targets,
159-
output=output_dtype,
160159
order=order,
161160
mode=mode,
162161
cval=cval,
@@ -168,7 +167,7 @@ def apply(
168167

169168
if targets is None:
170169
targets = ImageGrid(spatialimage).index( # data should be an image
171-
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
170+
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
172171
)
173172

174173
# Cast 3D data into 4D if 4D nonsequential transform
@@ -181,7 +180,6 @@ def apply(
181180
resampled = ndi.map_coordinates(
182181
data,
183182
targets,
184-
output=output_dtype,
185183
order=order,
186184
mode=mode,
187185
cval=cval,
@@ -190,13 +188,14 @@ def apply(
190188

191189
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
192190
hdr = _ref.header.copy() if _ref.header is not None else spatialimage.header.__class__()
193-
hdr.set_data_dtype(output_dtype)
191+
hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype())
194192

195193
moved = spatialimage.__class__(
196-
resampled.reshape(_ref.shape if data.ndim < 4 else _ref.shape + (-1,)),
194+
resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)),
197195
_ref.affine,
198196
hdr,
199197
)
200198
return moved
201199

202-
return resampled
200+
output_dtype = output_dtype or input_dtype
201+
return resampled.astype(output_dtype)

nitransforms/tests/test_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Tests of the base module."""
22
import numpy as np
33
import nibabel as nb
4+
from nibabel.arrayproxy import get_obj_dtype
5+
46
import pytest
57
import h5py
68

@@ -97,7 +99,7 @@ def _to_hdf5(klass, x5_root):
9799
fname = testdata_path / "someones_anatomy.nii.gz"
98100

99101
img = nb.load(fname)
100-
imgdata = np.asanyarray(img.dataobj, dtype=img.get_data_dtype())
102+
imgdata = np.asanyarray(img.dataobj, dtype=get_obj_dtype(img.dataobj))
101103

102104
# Test identity transform - setting reference
103105
xfm = TransformBase()
@@ -111,7 +113,8 @@ def _to_hdf5(klass, x5_root):
111113
xfm = nitl.Affine()
112114
xfm.reference = fname
113115
moved = apply(xfm, fname, order=0)
114-
assert np.all(imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype()))
116+
117+
assert np.all(imgdata == np.asanyarray(moved.dataobj, dtype=get_obj_dtype(moved.dataobj)))
115118

116119
# Test ndim returned by affine
117120
assert nitl.Affine().ndim == 3

0 commit comments

Comments
 (0)