Skip to content

Commit d9b1e12

Browse files
committed
wip
1 parent 4df1194 commit d9b1e12

File tree

5 files changed

+179
-92
lines changed

5 files changed

+179
-92
lines changed

nitransforms/io/itk.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,9 @@ def from_image(cls, imgobj):
347347
warnings.warn("Incorrect intent identified.")
348348
hdr.set_intent("vector")
349349

350-
field = np.squeeze(np.asanyarray(imgobj.dataobj)).transpose(2, 1, 0, 3)
350+
field = np.squeeze(np.asanyarray(imgobj.dataobj))
351+
field[..., (0, 1)] *= 1.0
352+
field = field.transpose(2, 1, 0, 3)
351353
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
352354

353355
@classmethod
@@ -357,7 +359,9 @@ def to_image(cls, imgobj):
357359
hdr = imgobj.header.copy()
358360
hdr.set_intent("vector")
359361

360-
field = imgobj.get_fdata().transpose(2, 1, 0, 3)[..., None, :]
362+
field = imgobj.get_fdata()
363+
field = field.transpose(2, 1, 0, 3)[..., None, :]
364+
field[..., (0, 1)] *= 1.0
361365
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
362366

363367

nitransforms/nonlinear.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def map(self, x, inverse=False):
188188
ijk = self.reference.index(x)
189189
indexes = np.round(ijk).astype("int")
190190

191+
import pdb; pdb.set_trace()
191192
if np.all(np.abs(ijk - indexes) < 1e-3):
192193
indexes = tuple(tuple(i) for i in indexes)
193194
return self._field[indexes]

nitransforms/tests/test_io.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,8 @@ def test_itk_disp_load_intent():
710710

711711
# Added tests for displacements fields orientations (ANTs/ITK)
712712
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
713-
def test_itk_displacements(tmp_path, get_testdata, image_orientation):
713+
@pytest.mark.parametrize("field_is_random", [False, True])
714+
def test_itk_displacements(tmp_path, get_testdata, image_orientation, field_is_random):
714715
"""Exercise I/O of ITK displacements fields."""
715716

716717
nii = get_testdata[image_orientation]
@@ -719,13 +720,17 @@ def test_itk_displacements(tmp_path, get_testdata, image_orientation):
719720
shape = nii.shape
720721
ref_affine = nii.affine.copy()
721722

722-
field = np.hstack(
723-
(
724-
np.linspace(-50, 50, num=np.prod(shape)),
725-
np.linspace(-80, 80, num=np.prod(shape)),
726-
np.zeros(np.prod(shape)),
727-
)
728-
).reshape(shape + (3,))
723+
field = (
724+
np.hstack(
725+
(
726+
np.linspace(-50, 50, num=np.prod(shape)),
727+
np.linspace(-80, 80, num=np.prod(shape)),
728+
np.zeros(np.prod(shape)),
729+
)
730+
).reshape(shape + (3,))
731+
if not field_is_random
732+
else np.random.normal(size=shape + (3,))
733+
)
729734

730735
nit_nii = itk.ITKDisplacementsField.to_image(
731736
nb.Nifti1Image(field, ref_affine, None)

nitransforms/tests/test_nonlinear.py

Lines changed: 155 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33
"""Tests of nonlinear transforms."""
44

55
import os
6+
from subprocess import check_call
7+
import shutil
8+
9+
import SimpleITK as sitk
610
import pytest
711

812
import numpy as np
913
import nibabel as nb
14+
from nibabel.affines import from_matvec
1015
from nitransforms.resampling import apply
1116
from nitransforms.base import TransformError
1217
from nitransforms.io.base import TransformFileError
@@ -15,7 +20,7 @@
1520
DenseFieldTransform,
1621
)
1722
from nitransforms import io
18-
from ..io.itk import ITKDisplacementsField
23+
from nitransforms.io.itk import ITKDisplacementsField
1924

2025

2126
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 3)])
@@ -34,16 +39,6 @@ def test_displacements_bad_sizes(size):
3439
DenseFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
3540

3641

37-
def test_itk_disp_load_intent():
38-
"""Checks whether the NIfTI intent is fixed."""
39-
with pytest.warns(UserWarning):
40-
field = ITKDisplacementsField.from_image(
41-
nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), np.eye(4), None)
42-
)
43-
44-
assert field.header.get_intent()[0] == "vector"
45-
46-
4742
def test_displacements_init():
4843
identity1 = DenseFieldTransform(
4944
np.zeros((10, 10, 10, 3)),
@@ -67,6 +62,30 @@ def test_displacements_init():
6762
)
6863

6964

65+
@pytest.mark.parametrize("is_deltas", [True, False])
66+
def test_densefield_oob_resampling(is_deltas):
67+
"""Ensure mapping outside the field returns input coordinates."""
68+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
69+
70+
if is_deltas:
71+
field = nb.Nifti1Image(np.ones((2, 2, 2, 3), dtype="float32"), np.eye(4))
72+
else:
73+
grid = np.stack(
74+
np.meshgrid(*[np.arange(2) for _ in range(3)], indexing="ij"),
75+
axis=-1,
76+
).astype("float32")
77+
field = nb.Nifti1Image(grid + 1.0, np.eye(4))
78+
79+
xfm = DenseFieldTransform(field, is_deltas=is_deltas, reference=ref)
80+
81+
points = np.array([[-1.0, -1.0, -1.0], [0.5, 0.5, 0.5], [3.0, 3.0, 3.0]])
82+
mapped = xfm.map(points)
83+
84+
assert np.allclose(mapped[0], points[0])
85+
assert np.allclose(mapped[2], points[2])
86+
assert np.allclose(mapped[1], points[1] + 1)
87+
88+
7089
def test_bsplines_init():
7190
with pytest.raises(TransformError):
7291
BSplineFieldTransform(
@@ -122,76 +141,6 @@ def test_bspline(tmp_path, testdata_path):
122141
)
123142

124143

125-
@pytest.mark.parametrize("is_deltas", [True, False])
126-
def test_densefield_x5_roundtrip(tmp_path, is_deltas):
127-
"""Ensure dense field transforms roundtrip via X5."""
128-
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
129-
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))
130-
131-
xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref)
132-
133-
node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
134-
assert node.type == "nonlinear"
135-
assert node.subtype == "densefield"
136-
assert node.representation == "displacements" if is_deltas else "deformations"
137-
assert node.domain.size == ref.shape
138-
assert node.metadata["GeneratedBy"] == "pytest"
139-
140-
fname = tmp_path / "test.x5"
141-
io.x5.to_filename(fname, [node])
142-
143-
xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")
144-
145-
assert xfm2.reference.shape == ref.shape
146-
assert np.allclose(xfm2.reference.affine, ref.affine)
147-
assert xfm == xfm2
148-
149-
150-
def test_bspline_to_x5(tmp_path):
151-
"""Check BSpline transforms export to X5."""
152-
coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4))
153-
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
154-
155-
xfm = BSplineFieldTransform(coeff, reference=ref)
156-
node = xfm.to_x5(metadata={"tool": "pytest"})
157-
assert node.type == "nonlinear"
158-
assert node.subtype == "bspline"
159-
assert node.representation == "coefficients"
160-
assert node.metadata["tool"] == "pytest"
161-
162-
fname = tmp_path / "bspline.x5"
163-
io.x5.to_filename(fname, [node])
164-
165-
xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5")
166-
assert np.allclose(xfm._coeffs, xfm2._coeffs)
167-
assert xfm2.reference.shape == ref.shape
168-
assert np.allclose(xfm2.reference.affine, ref.affine)
169-
170-
171-
@pytest.mark.parametrize("is_deltas", [True, False])
172-
def test_densefield_oob_resampling(is_deltas):
173-
"""Ensure mapping outside the field returns input coordinates."""
174-
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
175-
176-
if is_deltas:
177-
field = nb.Nifti1Image(np.ones((2, 2, 2, 3), dtype="float32"), np.eye(4))
178-
else:
179-
grid = np.stack(
180-
np.meshgrid(*[np.arange(2) for _ in range(3)], indexing="ij"),
181-
axis=-1,
182-
).astype("float32")
183-
field = nb.Nifti1Image(grid + 1.0, np.eye(4))
184-
185-
xfm = DenseFieldTransform(field, is_deltas=is_deltas, reference=ref)
186-
187-
points = np.array([[-1.0, -1.0, -1.0], [0.5, 0.5, 0.5], [3.0, 3.0, 3.0]])
188-
mapped = xfm.map(points)
189-
190-
assert np.allclose(mapped[0], points[0])
191-
assert np.allclose(mapped[2], points[2])
192-
assert np.allclose(mapped[1], points[1] + 1)
193-
194-
195144
def test_bspline_map_gridpoints():
196145
"""BSpline mapping matches dense field on grid points."""
197146
ref = nb.Nifti1Image(np.zeros((5, 5, 5), dtype="uint8"), np.eye(4))
@@ -243,3 +192,128 @@ def manual_map(x):
243192
pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]])
244193
expected = np.vstack([manual_map(p) for p in pts])
245194
assert np.allclose(bspline.map(pts), expected, atol=1e-6)
195+
196+
197+
def test_densefield_map_against_ants(testdata_path, tmp_path):
198+
"""Map points with DenseFieldTransform and compare to ANTs."""
199+
warpfile = (
200+
testdata_path
201+
/ "regressions"
202+
/ ("01_ants_t1_to_mniComposite_DisplacementFieldTransform.nii.gz")
203+
)
204+
if not warpfile.exists():
205+
pytest.skip("Composite transform test data not available")
206+
207+
points = np.array(
208+
[
209+
[0.0, 0.0, 0.0],
210+
[1.0, 2.0, 3.0],
211+
[10.0, -10.0, 5.0],
212+
[-5.0, 7.0, -2.0],
213+
[-12.0, 12.0, 0.0],
214+
]
215+
)
216+
csvin = tmp_path / "points.csv"
217+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
218+
219+
csvout = tmp_path / "out.csv"
220+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
221+
exe = cmd.split()[0]
222+
if not shutil.which(exe):
223+
pytest.skip(f"Command {exe} not found on host")
224+
check_call(cmd, shell=True)
225+
226+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
227+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
228+
229+
xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
230+
mapped = xfm.map(points)
231+
232+
assert np.allclose(mapped, ants_pts, atol=1e-6)
233+
234+
235+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
236+
@pytest.mark.parametrize("gridpoints", [True, False])
237+
def test_constant_field_vs_ants(tmp_path, get_testdata, image_orientation, gridpoints):
238+
"""Create a constant displacement field and compare mappings."""
239+
240+
nii = get_testdata[image_orientation]
241+
242+
# Create a reference centered at the origin with various axis orders/flips
243+
shape = nii.shape
244+
ref_affine = nii.affine.copy()
245+
246+
field = np.hstack((
247+
np.zeros(np.prod(shape)),
248+
np.linspace(-80, 80, num=np.prod(shape)),
249+
np.linspace(-50, 50, num=np.prod(shape)),
250+
)).reshape(shape + (3, ))
251+
fieldnii = nb.Nifti1Image(field, ref_affine, None)
252+
253+
warpfile = tmp_path / "itk_transform.nii.gz"
254+
ITKDisplacementsField.to_filename(fieldnii, warpfile)
255+
256+
# Ensure direct (xfm) and ITK roundtrip (itk_xfm) are equivalent
257+
xfm = DenseFieldTransform(fieldnii)
258+
itk_xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
259+
260+
assert xfm == itk_xfm
261+
np.testing.assert_allclose(xfm.reference.affine, itk_xfm.reference.affine)
262+
np.testing.assert_allclose(ref_affine, itk_xfm.reference.affine)
263+
np.testing.assert_allclose(xfm.reference.shape, itk_xfm.reference.shape)
264+
np.testing.assert_allclose(xfm._field, itk_xfm._field)
265+
266+
points = (
267+
xfm.reference.ndcoords.T if gridpoints
268+
else np.array(
269+
[
270+
[0.0, 0.0, 0.0],
271+
[1.0, 2.0, 3.0],
272+
[10.0, -10.0, 5.0],
273+
[-5.0, 7.0, -2.0],
274+
[12.0, 0.0, -11.0],
275+
]
276+
)
277+
)
278+
279+
mapped = xfm.map(points)
280+
nit_deltas = mapped - points
281+
282+
if gridpoints:
283+
np.testing.assert_array_equal(field, nit_deltas.reshape(*shape, -1))
284+
285+
csvin = tmp_path / "points.csv"
286+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
287+
288+
csvout = tmp_path / "out.csv"
289+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
290+
exe = cmd.split()[0]
291+
if not shutil.which(exe):
292+
pytest.skip(f"Command {exe} not found on host")
293+
check_call(cmd, shell=True)
294+
295+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
296+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
297+
298+
# if gridpoints:
299+
# ants_field = ants_pts.reshape(shape + (3, ))
300+
# diff = xfm._field[..., 0] - ants_field[..., 0]
301+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
302+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
303+
304+
# diff = xfm._field[..., 1] - ants_field[..., 1]
305+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
306+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
307+
308+
# diff = xfm._field[..., 2] - ants_field[..., 2]
309+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
310+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
311+
312+
ants_deltas = ants_pts - points
313+
np.testing.assert_array_equal(nit_deltas, ants_deltas)
314+
np.testing.assert_array_equal(mapped, ants_pts)
315+
316+
diff = mapped - ants_pts
317+
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
318+
319+
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"

nitransforms/tests/test_resampling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ def test_displacements_field1(
188188

189189
xfm = nitnl.load(xfm_fname, fmt=sw_tool)
190190

191+
import pdb; pdb.set_trace()
192+
191193
# Then apply the transform and cross-check with software
192194
cmd = APPLY_NONLINEAR_CMD[sw_tool](
193195
transform=os.path.abspath(xfm_fname),
@@ -243,7 +245,7 @@ def test_displacements_field1(
243245
assert np.sqrt((diff[5:-5, 5:-5, 5:-5] ** 2).mean()) < 1e-6
244246

245247

246-
@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
248+
@pytest.mark.parametrize("sw_tool", ["afni"])
247249
def test_displacements_field2(tmp_path, testdata_path, sw_tool):
248250
"""Check a translation-only field on one or more axes, different image orientations."""
249251
os.chdir(str(tmp_path))
@@ -275,6 +277,7 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool):
275277
nt_moved = apply(xfm, img_fname, order=0)
276278
nt_moved.to_filename("nt_resampled.nii.gz")
277279
sw_moved.set_data_dtype(nt_moved.get_data_dtype())
280+
278281
diff = np.asanyarray(
279282
sw_moved.dataobj, dtype=sw_moved.get_data_dtype()
280283
) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())

0 commit comments

Comments
 (0)