Skip to content

Commit ea8cda9

Browse files
committed
fix: revising map method and tests
1 parent 6104b47 commit ea8cda9

File tree

7 files changed

+143
-160
lines changed

7 files changed

+143
-160
lines changed

nitransforms/base.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,15 @@ def ndcoords(self):
7878
)[:3, ...]
7979
return self._coords
8080

81-
def index(self, coordinates):
82-
"""Get the image array's indexes corresponding to coordinates."""
83-
coordinates = np.array(coordinates)
84-
ncoords = coordinates.shape[-1]
85-
coordinates = np.vstack((coordinates, np.ones((1, ncoords))))
81+
def ras(self, ijk):
82+
"""Get RAS+ coordinates from input indexes."""
83+
ras = self._affine.dot(_as_homogeneous(ijk).T)[:3, ...]
84+
return ras.T
8685

87-
# Back to grid coordinates
88-
return np.tensordot(self._inverse, coordinates, axes=1)[:3, ...]
86+
def index(self, x):
87+
"""Get the image array's indexes corresponding to coordinates."""
88+
ijk = self._inverse.dot(_as_homogeneous(x).T)[:3, ...]
89+
return ijk.T
8990

9091
def _to_hdf5(self, group):
9192
group.attrs['Type'] = 'image'
@@ -114,9 +115,9 @@ def __eq__(self, other):
114115
return False
115116
return np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL)
116117

117-
def __call__(self, x):
118+
def __call__(self, x, inverse=False, index=0):
118119
"""Apply y = f(x)."""
119-
return self.map(x)
120+
return self.map(x, inverse=inverse, index=index)
120121

121122
@property
122123
def reference(self):
@@ -171,31 +172,47 @@ def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
171172
if output_dtype is None:
172173
output_dtype = moving_data.dtype
173174

174-
moving_grid = ImageGrid(moving)
175-
176-
def _map_indexes(ijk):
177-
return moving_grid.inverse.dot(self.map(self.reference.affine.dot(ijk)))
178-
179175
moved = ndi.geometric_transform(
180176
moving_data,
181-
mapping=_map_indexes,
177+
mapping=self._map_index,
182178
output_shape=self.reference.shape,
183179
output=output_dtype,
184180
order=order,
185181
mode=mode,
186182
cval=cval,
187183
prefilter=prefilter,
188-
extra_keywords={'moving': moving},
184+
extra_keywords={'moving': ImageGrid(moving)},
189185
)
190186

191187
moved_image = moving.__class__(moved, self.reference.affine, moving.header)
192188
moved_image.header.set_data_dtype(output_dtype)
193189
return moved_image
194190

195-
def map(self, x):
196-
"""Apply y = f(x)."""
191+
def map(self, x, inverse=False, index=0):
192+
r"""
193+
Apply :math:`y = f(x)`.
194+
195+
Parameters
196+
----------
197+
x : N x D numpy.ndarray
198+
Input RAS+ coordinates (i.e., physical coordinates).
199+
inverse : bool
200+
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
201+
index : int, optional
202+
Transformation index
203+
204+
Returns
205+
-------
206+
y : N x D numpy.ndarray
207+
Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).
208+
209+
"""
197210
raise NotImplementedError
198211

212+
def _map_index(self, ijk, moving):
213+
x = self.reference.ras(_as_homogeneous(ijk))
214+
return moving.index(self.map(x))
215+
199216
def to_filename(self, filename, fmt='X5'):
200217
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
201218
with h5py.File(filename, 'w') as out_file:
@@ -209,3 +226,24 @@ def to_filename(self, filename, fmt='X5'):
209226
def _to_hdf5(self, x5_root):
210227
"""Serialize this object into the x5 file format."""
211228
raise NotImplementedError
229+
230+
231+
def _as_homogeneous(xyz, dtype='float32'):
232+
"""
233+
Convert 2D and 3D coordinates into homogeneous coordinates.
234+
235+
Examples
236+
--------
237+
>>> _as_homogeneous((4, 5), dtype='int8').tolist()
238+
[[4, 5, 1]]
239+
240+
>>> _as_homogeneous((4, 5, 6),dtype='int8').tolist()
241+
[[4, 5, 6, 1]]
242+
243+
>>> _as_homogeneous([(1, 2, 3), (4, 5, 6)]).tolist()
244+
[[1.0, 2.0, 3.0, 1.0], [4.0, 5.0, 6.0, 1.0]]
245+
246+
247+
"""
248+
xyz = np.atleast_2d(np.array(xyz, dtype=dtype))
249+
return np.hstack((xyz, np.ones((xyz.shape[0], 1), dtype=dtype)))

nitransforms/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def data_path():
3939

4040

4141
@pytest.fixture
42-
def get_data():
42+
def get_testdata():
4343
"""Generate data in the requested orientation."""
4444
global _data
4545

nitransforms/linear.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
143143
nvols = moving.shape[3]
144144

145145
movaff = moving.affine
146-
movingdata = moving.get_data()
146+
movingdata = np.asanyarray(moving.dataobj)
147147
if nvols == 1:
148148
movingdata = movingdata[..., np.newaxis]
149149

@@ -178,39 +178,42 @@ def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
178178

179179
return moved_image
180180

181-
def map(self, coords, index=0, forward=True):
182-
"""
183-
Apply y = f(x), where x is the argument `coords`.
181+
def map(self, x, inverse=False, index=0):
182+
r"""
183+
Apply :math:`y = f(x)`.
184184
185185
Parameters
186186
----------
187-
coords : array_like
188-
RAS coordinates to map
187+
x : N x D numpy.ndarray
188+
Input RAS+ coordinates (i.e., physical coordinates).
189+
inverse : bool
190+
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
189191
index : int, optional
190192
Transformation index
191-
forward: bool, optional
192-
Direction of mapping. Default is set to ``True``. If ``False``,
193-
the inverse transformation is applied.
194193
195194
Returns
196195
-------
197-
out: ndarray
198-
Transformed coordinates
196+
y : N x D numpy.ndarray
197+
Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).
199198
200199
Examples
201200
--------
202201
>>> xfm = Affine([[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]])
203-
>>> xfm((0,0,0))
202+
>>> xfm.map((0,0,0))
204203
array([1, 2, 3])
205204
206-
>>> xfm((0,0,0), forward=False)
205+
>>> xfm.map((0,0,0), inverse=True)
207206
array([-1., -2., -3.])
208207
209208
"""
210-
coords = np.array(coords)
209+
coords = np.array(x)
211210
if coords.shape[0] == self._matrix[index].shape[0] - 1:
212211
coords = np.append(coords, [1])
213-
affine = self._matrix[index] if forward else np.linalg.inv(self._matrix[index])
212+
affine = self._matrix[index]
213+
214+
if inverse is True:
215+
affine = np.linalg.inv(self._matrix[index])
216+
214217
return affine.dot(coords)[:-1]
215218

216219
def _map_voxel(self, index, nindex=0, moving=None):

nitransforms/nonlinear.py

Lines changed: 34 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Nonlinear transforms."""
10+
import sys
1011
import numpy as np
1112
from scipy import ndimage as ndi
1213
# from gridbspline.maths import cubic
@@ -21,12 +22,11 @@ class DeformationFieldTransform(TransformBase):
2122
"""Represents a dense field of displacements (one vector per voxel)."""
2223

2324
__slots__ = ['_field', '_moving', '_moving_space']
24-
__s = (slice(None), )
2525

2626
def __init__(self, field, reference=None):
2727
"""Create a dense deformation field transform."""
2828
super(DeformationFieldTransform, self).__init__()
29-
self._field = field.get_data()
29+
self._field = np.asanyarray(field.dataobj)
3030

3131
ndim = self._field.ndim - 1
3232
if len(self._field.shape[:-1]) != ndim:
@@ -51,110 +51,46 @@ def __init__(self, field, reference=None):
5151

5252
self.reference = reference
5353

54-
def _cache_moving(self, moving):
55-
# Check whether input (moving) space is cached
56-
moving_space = ImageGrid(moving)
57-
if self._moving_space == moving_space:
58-
return
54+
def map(self, x, inverse=False, index=0):
55+
r"""
56+
Apply :math:`y = f(x)`.
5957
60-
# Generate grid of pixel indexes (ijk)
61-
ndim = self._field.ndim - 1
62-
if ndim == 2:
63-
grid = np.meshgrid(
64-
np.arange(self._field.shape[0]),
65-
np.arange(self._field.shape[1]),
66-
indexing='ij')
67-
elif ndim == 3:
68-
grid = np.meshgrid(
69-
np.arange(self._field.shape[0]),
70-
np.arange(self._field.shape[1]),
71-
np.arange(self._field.shape[2]),
72-
indexing='ij')
73-
else:
74-
raise ValueError('Wrong dimensions (%d)' % ndim)
75-
76-
grid = np.array(grid)
77-
flatgrid = grid.reshape(ndim, -1)
78-
79-
# Calculate physical coords of all voxels (xyz)
80-
flatxyz = np.tensordot(
81-
self.reference.affine,
82-
np.vstack((flatgrid, np.ones((1, flatgrid.shape[1])))),
83-
axes=1
84-
)
85-
86-
# Add field
87-
newxyz = flatxyz + np.vstack((
88-
np.moveaxis(self._field, -1, 0).reshape(ndim, -1),
89-
np.zeros((1, flatgrid.shape[1]))))
90-
91-
# Back to grid coordinates
92-
newijk = np.tensordot(np.linalg.inv(moving.affine),
93-
newxyz, axes=1)
94-
95-
# Reshape as grid
96-
self._moving = np.moveaxis(
97-
newijk[0:3, :].reshape((ndim, ) + self._field.shape[:-1]),
98-
0, -1)
99-
100-
self._moving_space = moving_space
58+
Parameters
59+
----------
60+
x : N x D numpy.ndarray
61+
Input RAS+ coordinates (i.e., physical coordinates).
62+
inverse : bool
63+
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
64+
index : int, optional
65+
Transformation index
10166
102-
def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
103-
output_dtype=None):
104-
"""
105-
Resample the ``moving`` image applying the deformation field.
67+
Returns
68+
-------
69+
y : N x D numpy.ndarray
70+
Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).
10671
10772
Examples
10873
--------
109-
>>> ref = nb.load(testfile)
110-
>>> refdata = ref.get_fdata()
111-
>>> np.allclose(refdata, 0)
112-
True
113-
114-
>>> refdata[5, 5, 5] = 1 # Set a one in the middle voxel
115-
>>> moving = nb.Nifti1Image(refdata, ref.affine, ref.header)
116-
>>> field = np.zeros(tuple(list(ref.shape) + [3]))
74+
>>> field = np.zeros((10, 10, 10, 3))
11775
>>> field[..., 0] = 4.0
118-
>>> fieldimg = nb.Nifti1Image(field, ref.affine, ref.header)
76+
>>> fieldimg = nb.Nifti1Image(field, np.diag([2., 2., 2., 1.]))
11977
>>> xfm = DeformationFieldTransform(fieldimg)
120-
>>> resampled = xfm.resample(moving, order=0).get_fdata()
121-
>>> resampled[1, 5, 5]
122-
1.0
78+
>>> xfm([4.0, 4.0, 4.0]).tolist()
79+
[[8.0, 4.0, 4.0]]
12380
124-
"""
125-
self._cache_moving(moving)
126-
return super(DeformationFieldTransform, self).resample(
127-
moving, order=order, mode=mode, cval=cval, prefilter=prefilter)
128-
129-
def _map_voxel(self, index, moving=None):
130-
"""Apply ijk' = f_ijk((i, j, k)), equivalent to the above with indexes."""
131-
return tuple(self._moving[index + self.__s])
81+
>>> xfm([[4.0, 4.0, 4.0], [8, 2, 10]]).tolist()
82+
[[8.0, 4.0, 4.0], [12.0, 2.0, 10.0]]
13283
133-
def map(self, x, order=3, mode='mirror', cval=0.0, prefilter=True):
134-
"""Apply y = f(x), where x is the argument `coords`."""
135-
coordinates = np.array(x)
136-
# Extract shapes and dimensions, then flatten
137-
ndim = coordinates.shape[-1]
138-
output_shape = coordinates.shape[:-1]
139-
flatcoord = np.moveaxis(coordinates, -1, 0).reshape(ndim, -1)
140-
141-
# Convert coordinates to voxel indices
142-
ijk = np.tensordot(
143-
np.linalg.inv(self.reference.affine),
144-
np.vstack((flatcoord, np.ones((1, flatcoord.shape[1])))),
145-
axes=1)
146-
deltas = ndi.map_coordinates(
147-
self._field,
148-
ijk,
149-
order=order,
150-
mode=mode,
151-
cval=cval,
152-
prefilter=prefilter)
153-
154-
deltas = np.moveaxis(deltas[0:3, :].reshape((ndim, ) + output_shape),
155-
0, -1)
156-
157-
return coordinates + deltas
84+
"""
85+
if inverse is True:
86+
raise NotImplementedError
87+
ijk = self.reference.index(x)
88+
indexes = np.round(ijk).astype('int')
89+
if np.any(np.abs(ijk - indexes) > 0.05):
90+
print('Some coordinates are off-grid of the displacements field.',
91+
file=sys.stderr)
92+
indexes = tuple([tuple(i) for i in indexes.T])
93+
return x + self._field[indexes]
15894

15995

16096
class BSplineFieldTransform(TransformBase):
@@ -174,7 +110,7 @@ def __init__(self, reference, coefficients, order=3):
174110
'Number of components of the coefficients does '
175111
'not match the number of dimensions')
176112

177-
self._coeffs = coefficients.get_data()
113+
self._coeffs = np.asanyarray(coefficients.dataobj)
178114
self._knots = ImageGrid(four_to_three(coefficients)[0])
179115
self._cache_moving()
180116

nitransforms/tests/test_base.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,21 @@
66

77

88
@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
9-
def test_ImageGrid(get_data, image_orientation):
9+
def test_ImageGrid(get_testdata, image_orientation):
1010
"""Check the grid object."""
11-
im = get_data[image_orientation]
11+
im = get_testdata[image_orientation]
1212

1313
img = ImageGrid(im)
14-
assert np.all(img.affine == np.linalg.inv(img.inverse))
14+
assert np.allclose(img.affine, np.linalg.inv(img.inverse))
15+
16+
# Test ras2vox and vox2ras conversions
17+
ijk = [[10, 10, 10], [40, 4, 20], [0, 0, 0], [s - 1 for s in im.shape[:3]]]
18+
xyz = [img._affine.dot(idx + [1])[:-1] for idx in ijk]
19+
20+
assert np.allclose(img.ras(ijk[0]), xyz[0])
21+
assert np.allclose(np.round(img.index(xyz[0])), ijk[0])
22+
assert np.allclose(img.ras(ijk), xyz)
23+
assert np.allclose(np.round(img.index(xyz)), ijk)
1524

1625
# nd index / coords
1726
idxs = img.ndindex

0 commit comments

Comments
 (0)