Skip to content

Commit 8e83ccd

Browse files
committed
Merge branch 'fix/267' into fix/itk-displacements-field-round-trip
2 parents d9b1e12 + 664552b commit 8e83ccd

File tree

7 files changed

+242
-99
lines changed

7 files changed

+242
-99
lines changed

nitransforms/base.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,30 +202,26 @@ def inverse(self):
202202
def ndindex(self):
203203
"""List the indexes corresponding to the space grid."""
204204
if self._ndindex is None:
205-
indexes = tuple([np.arange(s) for s in self._shape])
206-
self._ndindex = np.array(np.meshgrid(*indexes, indexing="ij")).reshape(
207-
self._ndim, self._npoints
208-
)
205+
indexes = np.mgrid[
206+
0:self._shape[0], 0:self._shape[1], 0:self._shape[2]
207+
]
208+
self._ndindex = indexes.reshape((indexes.shape[0], -1)).T
209209
return self._ndindex
210210

211211
@property
212212
def ndcoords(self):
213213
"""List the physical coordinates of this gridded space samples."""
214214
if self._coords is None:
215-
self._coords = np.tensordot(
216-
self._affine,
217-
np.vstack((self.ndindex, np.ones((1, self._npoints)))),
218-
axes=1,
219-
)[:3, ...]
215+
self._coords = self.ras(self.ndindex)
220216
return self._coords
221217

222218
def ras(self, ijk):
223219
"""Get RAS+ coordinates from input indexes."""
224-
return _apply_affine(ijk, self._affine, self._ndim)
220+
return _apply_affine(ijk, self._affine, self._ndim).T
225221

226222
def index(self, x):
227223
"""Get the image array's indexes corresponding to coordinates."""
228-
return _apply_affine(x, self._inverse, self._ndim)
224+
return _apply_affine(x, self._inverse, self._ndim).T
229225

230226
def _to_hdf5(self, group):
231227
group.attrs["Type"] = "image"

nitransforms/nonlinear.py

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -65,50 +65,47 @@ def __init__(self, field=None, is_deltas=True, reference=None):
6565
<DenseFieldTransform[3D] (57, 67, 56)>
6666
6767
"""
68+
6869
if field is None and reference is None:
69-
raise TransformError("DenseFieldTransforms require a spatial reference")
70+
raise TransformError("cannot initialize field")
7071

7172
super().__init__()
7273

73-
self._is_deltas = is_deltas
74+
if field is not None:
75+
field = _ensure_image(field)
76+
# Extract data if nibabel object otherwise assume numpy array
77+
_data = np.squeeze(
78+
np.asanyarray(field.dataobj)
79+
if hasattr(field, "dataobj")
80+
else field.copy()
81+
)
7482

7583
try:
7684
self.reference = ImageGrid(reference if reference is not None else field)
7785
except AttributeError:
7886
raise TransformError(
79-
"Field must be a spatial image if reference is not provided"
87+
"field must be a spatial image if reference is not provided"
8088
if reference is None
81-
else "Reference is not a spatial image"
89+
else "reference is not a spatial image"
8290
)
8391

8492
fieldshape = (*self.reference.shape, self.reference.ndim)
85-
if field is not None:
86-
field = _ensure_image(field)
87-
self._field = np.squeeze(
88-
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
89-
)
90-
if fieldshape != self._field.shape:
91-
raise TransformError(
92-
f"Shape of the field ({'x'.join(str(i) for i in self._field.shape)}) "
93-
f"doesn't match that of the reference({'x'.join(str(i) for i in fieldshape)})"
94-
)
95-
else:
96-
self._field = np.zeros(fieldshape, dtype="float32")
97-
self._is_deltas = True
98-
99-
if self._field.shape[-1] != self.ndim:
93+
if field is None:
94+
_data = np.zeros(fieldshape)
95+
elif fieldshape != _data.shape:
10096
raise TransformError(
101-
"The number of components of the field (%d) does not match "
102-
"the number of dimensions (%d)" % (self._field.shape[-1], self.ndim)
97+
f"Shape of the field ({'x'.join(str(i) for i in _data.shape)}) "
98+
f"doesn't match that of the reference({'x'.join(str(i) for i in fieldshape)})"
10399
)
104100

101+
self._is_deltas = is_deltas
102+
self._field = self.reference.ndcoords.reshape(fieldshape)
103+
105104
if self.is_deltas:
106-
self._deltas = (
107-
self._field.copy()
108-
) # IMPORTANT: you don't want to update deltas
109-
# Convert from displacements (deltas) to deformations fields
110-
# (just add its origin to each delta vector)
111-
self._field += self.reference.ndcoords.T.reshape(fieldshape)
105+
self._deltas = _data.copy()
106+
self._field += self._deltas
107+
else:
108+
self._field = _data.copy()
112109

113110
def __repr__(self):
114111
"""Beautify the python representation."""
@@ -153,7 +150,7 @@ def map(self, x, inverse=False):
153150
... test_dir / "someones_displacement_field.nii.gz",
154151
... is_deltas=False,
155152
... )
156-
>>> xfm.map([-6.5, -36., -19.5]).tolist()
153+
>>> xfm.map([[-6.5, -36., -19.5]]).tolist()
157154
[[0.0, -0.47516798973083496, 0.0]]
158155
159156
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
@@ -170,8 +167,8 @@ def map(self, x, inverse=False):
170167
... test_dir / "someones_displacement_field.nii.gz",
171168
... is_deltas=True,
172169
... )
173-
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
174-
[[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
170+
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist() # doctest: +ELLIPSIS
171+
[[-6.5, -36.475..., -19.5], [-1.0, -42.038..., -11.25]]
175172
176173
>>> np.array_str(
177174
... xfm.map([[-6.7, -36.3, -19.2], [-1., -41.5, -11.25]]),
@@ -185,19 +182,19 @@ def map(self, x, inverse=False):
185182
if inverse is True:
186183
raise NotImplementedError
187184

188-
ijk = self.reference.index(x)
185+
ijk = self.reference.index(np.array(x, dtype="float32"))
189186
indexes = np.round(ijk).astype("int")
187+
ongrid = np.where(np.linalg.norm(ijk - indexes, axis=1) < 1e-3)[0]
190188

191-
import pdb; pdb.set_trace()
192-
if np.all(np.abs(ijk - indexes) < 1e-3):
193-
indexes = tuple(tuple(i) for i in indexes)
194-
return self._field[indexes]
189+
if ongrid.size == np.shape(x)[0]:
190+
# return self._field[*indexes.T, :] # From Python 3.11
191+
return self._field[tuple(indexes.T) + (np.s_[:],)]
195192

196-
new_map = np.vstack(
193+
mapped_coords = np.vstack(
197194
tuple(
198195
map_coordinates(
199196
self._field[..., i],
200-
ijk,
197+
ijk.T,
201198
order=3,
202199
mode="constant",
203200
cval=np.nan,
@@ -208,8 +205,8 @@ def map(self, x, inverse=False):
208205
).T
209206

210207
# Set NaN values back to the original coordinates value = no displacement
211-
new_map[np.isnan(new_map)] = np.array(x)[np.isnan(new_map)]
212-
return new_map
208+
mapped_coords[np.isnan(mapped_coords)] = np.array(x)[np.isnan(mapped_coords)]
209+
return mapped_coords
213210

214211
def __matmul__(self, b):
215212
"""

nitransforms/resampling.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def apply(
253253
serialize_4d = n_resamplings >= serialize_nvols
254254

255255
targets = None
256-
ref_ndcoords = _ref.ndcoords.T
256+
ref_ndcoords = _ref.ndcoords
257257
if hasattr(transform, "to_field") and callable(transform.to_field):
258258
targets = ImageGrid(spatialimage).index(
259259
_as_homogeneous(
@@ -271,11 +271,8 @@ def apply(
271271
else targets
272272
)
273273

274-
if targets.ndim == 3:
275-
targets = np.rollaxis(targets, targets.ndim - 1, 0)
276-
else:
277-
assert targets.ndim == 2
278-
targets = targets[np.newaxis, ...]
274+
if targets.ndim == 2:
275+
targets = targets.T[np.newaxis, ...]
279276

280277
if serialize_4d:
281278
data = (
@@ -290,6 +287,9 @@ def apply(
290287
(len(ref_ndcoords), n_resamplings), dtype=input_dtype, order="F"
291288
)
292289

290+
if targets.ndim == 3:
291+
targets = np.rollaxis(targets, targets.ndim - 1, 1)
292+
293293
resampled = asyncio.run(
294294
_apply_serial(
295295
data,
@@ -311,6 +311,9 @@ def apply(
311311
else:
312312
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
313313

314+
if targets.ndim == 3:
315+
targets = np.rollaxis(targets, targets.ndim - 1, 0)
316+
314317
if data_nvols == 1 and xfm_nvols == 1:
315318
targets = np.squeeze(targets)
316319
assert targets.ndim == 2
@@ -320,15 +323,19 @@ def apply(
320323

321324
if xfm_nvols > 1:
322325
assert targets.ndim == 3
323-
n_time, n_dim, n_vox = targets.shape
326+
327+
# Targets must have shape (n_dim x n_time x n_vox)
328+
n_dim, n_time, n_vox = targets.shape
324329
# Reshape to (3, n_time x n_vox)
325-
ijk_targets = np.rollaxis(targets, 0, 2).reshape((n_dim, -1))
330+
ijk_targets = targets.reshape((n_dim, -1))
326331
time_row = np.repeat(np.arange(n_time), n_vox)[None, :]
327332

328333
# Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
329334
# t is the slowest-changing axis, so we put it first
330335
targets = np.vstack((time_row, ijk_targets))
331336
data = np.rollaxis(data, data.ndim - 1, 0)
337+
else:
338+
targets = targets.T
332339

333340
resampled = ndi.map_coordinates(
334341
data,

nitransforms/tests/test_base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,24 @@ def test_ImageGrid(get_testdata, image_orientation):
5555

5656
assert np.allclose(np.squeeze(img.ras(ijk[0])), xyz[0])
5757
assert np.allclose(np.round(img.index(xyz[0])), ijk[0])
58-
assert np.allclose(img.ras(ijk).T, xyz)
59-
assert np.allclose(np.round(img.index(xyz)).T, ijk)
58+
assert np.allclose(img.ras(ijk), xyz)
59+
assert np.allclose(np.round(img.index(xyz)), ijk)
6060

6161
# nd index / coords
6262
idxs = img.ndindex
6363
coords = img.ndcoords
6464
assert len(idxs.shape) == len(coords.shape) == 2
65-
assert idxs.shape[0] == coords.shape[0] == img.ndim == 3
66-
assert idxs.shape[1] == coords.shape[1] == img.npoints == np.prod(im.shape)
65+
assert idxs.shape[1] == coords.shape[1] == img.ndim == 3
66+
assert idxs.shape[0] == coords.shape[0] == img.npoints == np.prod(im.shape)
6767

6868
img2 = ImageGrid(img)
6969
assert img2 == img
7070
assert (img2 != img) is False
7171

72+
# Test indexing round trip
73+
np.testing.assert_allclose(img.ndcoords, img.ras(img.ndindex))
74+
np.testing.assert_allclose(img.ndindex, np.round(img.index(img.ndcoords)))
75+
7276

7377
def test_ImageGrid_utils(tmpdir, testdata_path, get_testdata):
7478
"""Check that images can be objects or paths and equality."""

nitransforms/tests/test_io.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from nibabel.affines import from_matvec
1818
from scipy.io import loadmat
1919
from nitransforms.linear import Affine
20-
from nitransforms import nonlinear as nitnl
20+
from nitransforms.nonlinear import DenseFieldTransform, BSplineFieldTransform
2121
from nitransforms.io import (
2222
afni,
2323
fsl,
2424
lta as fs,
2525
itk,
26-
x5
26+
x5,
2727
)
2828
from nitransforms.io.lta import (
2929
VolumeGeometry as VG,
@@ -773,7 +773,7 @@ def test_densefield_x5_roundtrip(tmp_path, is_deltas):
773773
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
774774
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))
775775

776-
xfm = nitnl.DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref)
776+
xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref)
777777

778778
node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
779779
assert node.type == "nonlinear"
@@ -785,7 +785,7 @@ def test_densefield_x5_roundtrip(tmp_path, is_deltas):
785785
fname = tmp_path / "test.x5"
786786
x5.to_filename(fname, [node])
787787

788-
xfm2 = nitnl.DenseFieldTransform.from_filename(fname, fmt="X5")
788+
xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")
789789

790790
assert xfm2.reference.shape == ref.shape
791791
assert np.allclose(xfm2.reference.affine, ref.affine)
@@ -797,7 +797,7 @@ def test_bspline_to_x5(tmp_path):
797797
coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4))
798798
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
799799

800-
xfm = nitnl.BSplineFieldTransform(coeff, reference=ref)
800+
xfm = BSplineFieldTransform(coeff, reference=ref)
801801
node = xfm.to_x5(metadata={"tool": "pytest"})
802802
assert node.type == "nonlinear"
803803
assert node.subtype == "bspline"
@@ -807,7 +807,7 @@ def test_bspline_to_x5(tmp_path):
807807
fname = tmp_path / "bspline.x5"
808808
x5.to_filename(fname, [node])
809809

810-
xfm2 = nitnl.BSplineFieldTransform.from_filename(fname, fmt="X5")
810+
xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5")
811811
assert np.allclose(xfm._coeffs, xfm2._coeffs)
812812
assert xfm2.reference.shape == ref.shape
813813
assert np.allclose(xfm2.reference.affine, ref.affine)

0 commit comments

Comments
 (0)