Skip to content

Commit 961a020

Browse files
committed
wip
1 parent 2c62fd9 commit 961a020

File tree

1 file changed

+79
-77
lines changed

1 file changed

+79
-77
lines changed

nitransforms/tests/test_nonlinear.py

Lines changed: 79 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -219,84 +219,8 @@ def test_densefield_map(get_testdata, image_orientation, ongrid):
219219
) < 0.5
220220

221221

222-
@pytest.mark.parametrize("is_deltas", [True, False])
223-
def test_densefield_oob_resampling(is_deltas):
224-
"""Ensure mapping outside the field returns input coordinates."""
225-
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
226-
227-
if is_deltas:
228-
field = nb.Nifti1Image(np.ones((2, 2, 2, 3), dtype="float32"), np.eye(4))
229-
else:
230-
grid = np.stack(
231-
np.meshgrid(*[np.arange(2) for _ in range(3)], indexing="ij"),
232-
axis=-1,
233-
).astype("float32")
234-
field = nb.Nifti1Image(grid + 1.0, np.eye(4))
235-
236-
xfm = DenseFieldTransform(field, is_deltas=is_deltas, reference=ref)
237-
238-
points = np.array([[-1.0, -1.0, -1.0], [0.5, 0.5, 0.5], [3.0, 3.0, 3.0]])
239-
mapped = xfm.map(points)
240-
241-
assert np.allclose(mapped[0], points[0])
242-
assert np.allclose(mapped[2], points[2])
243-
assert np.allclose(mapped[1], points[1] + 1)
244-
245-
246-
def test_bspline_map_gridpoints():
247-
"""BSpline mapping matches dense field on grid points."""
248-
ref = nb.Nifti1Image(np.zeros((5, 5, 5), dtype="uint8"), np.eye(4))
249-
coeff = nb.Nifti1Image(
250-
np.random.RandomState(0).rand(9, 9, 9, 3).astype("float32"), np.eye(4)
251-
)
252-
253-
bspline = BSplineFieldTransform(coeff, reference=ref)
254-
dense = bspline.to_field()
255-
256-
# Use a couple of voxel centers from the reference grid
257-
ijk = np.array([[1, 1, 1], [2, 3, 0]])
258-
pts = nb.affines.apply_affine(ref.affine, ijk)
259-
260-
assert np.allclose(bspline.map(pts), dense.map(pts), atol=1e-6)
261222

262-
263-
def test_bspline_map_manual():
264-
"""BSpline interpolation agrees with manual computation."""
265-
ref = nb.Nifti1Image(np.zeros((5, 5, 5), dtype="uint8"), np.eye(4))
266-
rng = np.random.RandomState(0)
267-
coeff = nb.Nifti1Image(rng.rand(9, 9, 9, 3).astype("float32"), np.eye(4))
268-
269-
bspline = BSplineFieldTransform(coeff, reference=ref)
270-
271-
from nitransforms.base import _as_homogeneous
272-
from nitransforms.interp.bspline import _cubic_bspline
273-
274-
def manual_map(x):
275-
ijk = (bspline._knots.inverse @ _as_homogeneous(x).squeeze())[:3]
276-
w_start = np.floor(ijk).astype(int) - 1
277-
w_end = w_start + 3
278-
w_start = np.maximum(w_start, 0)
279-
w_end = np.minimum(w_end, np.array(bspline._coeffs.shape[:3]) - 1)
280-
281-
window = []
282-
for i in range(w_start[0], w_end[0] + 1):
283-
for j in range(w_start[1], w_end[1] + 1):
284-
for k in range(w_start[2], w_end[2] + 1):
285-
window.append([i, j, k])
286-
window = np.array(window)
287-
288-
dist = np.abs(window - ijk)
289-
weights = _cubic_bspline(dist).prod(1)
290-
coeffs = bspline._coeffs[window[:, 0], window[:, 1], window[:, 2]]
291-
292-
return x + coeffs.T @ weights
293-
294-
pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]])
295-
expected = np.vstack([manual_map(p) for p in pts])
296-
assert np.allclose(bspline.map(pts), expected, atol=1e-6)
297-
298-
299-
def test_densefield_map_against_ants(testdata_path, tmp_path):
223+
def test_densefield_map_vs_ants(testdata_path, tmp_path):
300224
"""Map points with DenseFieldTransform and compare to ANTs."""
301225
warpfile = (
302226
testdata_path
@@ -419,3 +343,81 @@ def test_constant_field_vs_ants(tmp_path, get_testdata, image_orientation, gridp
419343
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
420344

421345
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
346+
347+
348+
349+
@pytest.mark.parametrize("is_deltas", [True, False])
350+
def test_densefield_oob_resampling(is_deltas):
351+
"""Ensure mapping outside the field returns input coordinates."""
352+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
353+
354+
if is_deltas:
355+
field = nb.Nifti1Image(np.ones((2, 2, 2, 3), dtype="float32"), np.eye(4))
356+
else:
357+
grid = np.stack(
358+
np.meshgrid(*[np.arange(2) for _ in range(3)], indexing="ij"),
359+
axis=-1,
360+
).astype("float32")
361+
field = nb.Nifti1Image(grid + 1.0, np.eye(4))
362+
363+
xfm = DenseFieldTransform(field, is_deltas=is_deltas, reference=ref)
364+
365+
points = np.array([[-1.0, -1.0, -1.0], [0.5, 0.5, 0.5], [3.0, 3.0, 3.0]])
366+
mapped = xfm.map(points)
367+
368+
assert np.allclose(mapped[0], points[0])
369+
assert np.allclose(mapped[2], points[2])
370+
assert np.allclose(mapped[1], points[1] + 1)
371+
372+
373+
def test_bspline_map_gridpoints():
374+
"""BSpline mapping matches dense field on grid points."""
375+
ref = nb.Nifti1Image(np.zeros((5, 5, 5), dtype="uint8"), np.eye(4))
376+
coeff = nb.Nifti1Image(
377+
np.random.RandomState(0).rand(9, 9, 9, 3).astype("float32"), np.eye(4)
378+
)
379+
380+
bspline = BSplineFieldTransform(coeff, reference=ref)
381+
dense = bspline.to_field()
382+
383+
# Use a couple of voxel centers from the reference grid
384+
ijk = np.array([[1, 1, 1], [2, 3, 0]])
385+
pts = nb.affines.apply_affine(ref.affine, ijk)
386+
387+
assert np.allclose(bspline.map(pts), dense.map(pts), atol=1e-6)
388+
389+
390+
def test_bspline_map_manual():
391+
"""BSpline interpolation agrees with manual computation."""
392+
ref = nb.Nifti1Image(np.zeros((5, 5, 5), dtype="uint8"), np.eye(4))
393+
rng = np.random.RandomState(0)
394+
coeff = nb.Nifti1Image(rng.rand(9, 9, 9, 3).astype("float32"), np.eye(4))
395+
396+
bspline = BSplineFieldTransform(coeff, reference=ref)
397+
398+
from nitransforms.base import _as_homogeneous
399+
from nitransforms.interp.bspline import _cubic_bspline
400+
401+
def manual_map(x):
402+
ijk = (bspline._knots.inverse @ _as_homogeneous(x).squeeze())[:3]
403+
w_start = np.floor(ijk).astype(int) - 1
404+
w_end = w_start + 3
405+
w_start = np.maximum(w_start, 0)
406+
w_end = np.minimum(w_end, np.array(bspline._coeffs.shape[:3]) - 1)
407+
408+
window = []
409+
for i in range(w_start[0], w_end[0] + 1):
410+
for j in range(w_start[1], w_end[1] + 1):
411+
for k in range(w_start[2], w_end[2] + 1):
412+
window.append([i, j, k])
413+
window = np.array(window)
414+
415+
dist = np.abs(window - ijk)
416+
weights = _cubic_bspline(dist).prod(1)
417+
coeffs = bspline._coeffs[window[:, 0], window[:, 1], window[:, 2]]
418+
419+
return x + coeffs.T @ weights
420+
421+
pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]])
422+
expected = np.vstack([manual_map(p) for p in pts])
423+
assert np.allclose(bspline.map(pts), expected, atol=1e-6)

0 commit comments

Comments
 (0)