Skip to content

Commit 3e15102

Browse files
committed
test: compare composite h5 mapping with ANTs
1 parent 0ae771d commit 3e15102

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

nitransforms/tests/test_io.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from nibabel.affines import from_matvec
1616
from scipy.io import loadmat
1717
from nitransforms.linear import Affine
18+
from nitransforms import nonlinear as nitnl, linear as nitl
1819
from nitransforms.io import (
1920
afni,
2021
fsl,
@@ -778,3 +779,57 @@ def test_itk_h5_field_order_fortran(tmp_path):
778779
expected = np.moveaxis(field, 0, -1)
779780
expected[..., (0, 1)] *= -1
780781
assert np.allclose(img.get_fdata(), expected)
782+
783+
784+
def test_composite_h5_map_against_ants(tmp_path):
785+
"""Map points with NiTransforms and compare to ANTs."""
786+
shape = (2, 2, 2)
787+
disp = np.zeros(shape + (3,), dtype=float)
788+
disp += np.array([0.2, -0.3, 0.4])
789+
790+
params = np.moveaxis(disp, -1, 0).reshape(-1, order="F")
791+
fixed = np.array(
792+
list(shape) + [0, 0, 0] + [1, 1, 1] + list(np.eye(3).ravel()), dtype=float
793+
)
794+
fname = tmp_path / "test.h5"
795+
with H5File(fname, "w") as f:
796+
grp = f.create_group("TransformGroup")
797+
grp.create_group("0")["TransformType"] = np.array(
798+
[b"CompositeTransform_double_3_3"]
799+
)
800+
g1 = grp.create_group("1")
801+
g1["TransformType"] = np.array([b"DisplacementFieldTransform_float_3_3"])
802+
g1["TransformFixedParameters"] = fixed
803+
g1["TransformParameters"] = params
804+
g2 = grp.create_group("2")
805+
g2["TransformType"] = np.array([b"AffineTransform_double_3_3"])
806+
g2["TransformFixedParameters"] = np.zeros(3, dtype=float)
807+
g2["TransformParameters"] = np.array(
808+
[1, 0, 0, 0, 1, 0, 0, 0, 1, 0.1, 0.2, -0.1], dtype=float
809+
)
810+
811+
points = np.array([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
812+
csvin = tmp_path / "points.csv"
813+
with open(csvin, "w") as f:
814+
f.write("x,y,z\n")
815+
for row in points:
816+
f.write(",".join(map(str, row)) + "\n")
817+
818+
csvout = tmp_path / "out.csv"
819+
cmd = (
820+
f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {fname}"
821+
)
822+
exe = cmd.split()[0]
823+
if not shutil.which(exe):
824+
pytest.skip(f"Command {exe} not found on host")
825+
check_call(cmd, shell=True)
826+
827+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
828+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
829+
830+
xforms = itk.ITKCompositeH5.from_filename(fname)
831+
dfield = nitnl.DenseFieldTransform(xforms[0])
832+
affine = nitl.Affine(xforms[1].to_ras())
833+
mapped = (affine @ dfield).map(points)
834+
835+
assert np.allclose(mapped, ants_pts, atol=1e-6)

0 commit comments

Comments
 (0)