Skip to content

Commit 9f916e4

Browse files
committed
Add constant field comparison test
1 parent c2ec9d9 commit 9f916e4

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

nitransforms/tests/test_nonlinear.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import nibabel as nb
12+
from nibabel.affines import from_matvec
1213
from nitransforms.resampling import apply
1314
from nitransforms.base import TransformError
1415
from nitransforms.io.base import TransformFileError
@@ -273,3 +274,41 @@ def test_densefield_map_against_ants(data_path, tmp_path):
273274
mapped = xfm.map(points)
274275

275276
assert np.allclose(mapped, ants_pts, atol=1e-6)
277+
278+
279+
def test_constant_field_vs_ants(tmp_path):
280+
"""Create a constant displacement field and compare mappings."""
281+
282+
# Create a reference centered at the origin
283+
shape = (5, 5, 5)
284+
ref_affine = from_matvec(np.eye(3), -(np.array(shape) - 1) / 2)
285+
286+
field = np.zeros(shape + (3,), dtype="float32")
287+
field[..., 0] = -5
288+
field[..., 1] = 0
289+
field[..., 2] = 5
290+
291+
field_img = nb.Nifti1Image(field, ref_affine)
292+
itk_img = ITKDisplacementsField.to_image(field_img)
293+
warpfile = tmp_path / "const_disp.nii.gz"
294+
itk_img.to_filename(warpfile)
295+
296+
points = np.array([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
297+
csvin = tmp_path / "points.csv"
298+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
299+
300+
csvout = tmp_path / "out.csv"
301+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
302+
exe = cmd.split()[0]
303+
if not shutil.which(exe):
304+
pytest.skip(f"Command {exe} not found on host")
305+
check_call(cmd, shell=True)
306+
307+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
308+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
309+
310+
xfm = DenseFieldTransform(warpfile)
311+
mapped = xfm.map(points)
312+
313+
assert not np.allclose(mapped, ants_pts, atol=1e-6)
314+
assert np.allclose(mapped - ants_pts, [-10.0, 0.0, 0.0])

0 commit comments

Comments
 (0)