Skip to content

Commit 7cc3c48

Browse files
committed
enh: update tests
1 parent 314b69f commit 7cc3c48

File tree

3 files changed

+113
-90
lines changed

3 files changed

+113
-90
lines changed

nitransforms/tests/test_io_itk.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
"""Test io module for ITK transforms."""
4+
5+
import pytest
6+
7+
import numpy as np
8+
import nibabel as nb
9+
10+
from nitransforms.base import TransformError
11+
from nitransforms.io.base import TransformFileError
12+
from nitransforms.io.itk import ITKDisplacementsField
13+
from nitransforms.nonlinear import (
14+
DenseFieldTransform,
15+
)
16+
17+
18+
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 3)])
19+
def test_itk_disp_load(size):
20+
"""Checks field sizes."""
21+
with pytest.raises(TransformFileError):
22+
ITKDisplacementsField.from_image(
23+
nb.Nifti1Image(np.zeros(size), np.eye(4), None)
24+
)
25+
26+
27+
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 2, 3), (20, 20, 20, 1, 4)])
28+
def test_displacements_bad_sizes(size):
29+
"""Checks field sizes."""
30+
with pytest.raises(TransformError):
31+
DenseFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
32+
33+
34+
def test_itk_disp_load_intent():
35+
"""Checks whether the NIfTI intent is fixed."""
36+
with pytest.warns(UserWarning):
37+
field = ITKDisplacementsField.from_image(
38+
nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), np.eye(4), None)
39+
)
40+
41+
assert field.header.get_intent()[0] == "vector"

nitransforms/tests/test_nonlinear.py

Lines changed: 23 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -9,39 +9,10 @@
99
import nibabel as nb
1010
from nitransforms.resampling import apply
1111
from nitransforms.base import TransformError
12-
from nitransforms.io.base import TransformFileError
1312
from nitransforms.nonlinear import (
1413
BSplineFieldTransform,
1514
DenseFieldTransform,
1615
)
17-
from nitransforms import io
18-
from ..io.itk import ITKDisplacementsField
19-
20-
21-
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 3)])
22-
def test_itk_disp_load(size):
23-
"""Checks field sizes."""
24-
with pytest.raises(TransformFileError):
25-
ITKDisplacementsField.from_image(
26-
nb.Nifti1Image(np.zeros(size), np.eye(4), None)
27-
)
28-
29-
30-
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 2, 3), (20, 20, 20, 1, 4)])
31-
def test_displacements_bad_sizes(size):
32-
"""Checks field sizes."""
33-
with pytest.raises(TransformError):
34-
DenseFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
35-
36-
37-
def test_itk_disp_load_intent():
38-
"""Checks whether the NIfTI intent is fixed."""
39-
with pytest.warns(UserWarning):
40-
field = ITKDisplacementsField.from_image(
41-
nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), np.eye(4), None)
42-
)
43-
44-
assert field.header.get_intent()[0] == "vector"
4516

4617

4718
def test_displacements_init():
@@ -96,76 +67,39 @@ def test_bsplines_references(testdata_path):
9667
)
9768

9869

70+
@pytest.mark.xfail(
71+
reason="Disable while #266 is developed.",
72+
strict=False,
73+
)
9974
def test_bspline(tmp_path, testdata_path):
75+
"""
76+
Cross-check B-Splines and deformation field.
77+
78+
This test is disabled and will be split into two separate tests.
79+
The current implementation will be moved into test_resampling.py,
80+
since that's what it actually tests.
81+
82+
In GH-266, this test will be re-implemented by testing the equivalence
83+
of the B-Spline and deformation field transforms by calling the
84+
transform's `map()` method on points.
85+
86+
"""
87+
assert True
88+
89+
90+
def test_map_bspline_vs_displacement(tmp_path, testdata_path):
10091
"""Cross-check B-Splines and deformation field."""
10192
os.chdir(str(tmp_path))
10293

10394
img_name = testdata_path / "someones_anatomy.nii.gz"
10495
disp_name = testdata_path / "someones_displacement_field.nii.gz"
10596
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"
10697

107-
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name)
98+
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name).to_field()
10899
dispxfm = DenseFieldTransform(disp_name)
109100

110-
out_disp = apply(dispxfm, img_name)
111-
out_bspl = apply(bsplxfm, img_name)
112-
113-
out_disp.to_filename("resampled_field.nii.gz")
114-
out_bspl.to_filename("resampled_bsplines.nii.gz")
115-
116-
assert (
117-
np.sqrt(
118-
(out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32"))
119-
** 2
120-
).mean()
121-
< 0.2
122-
)
123-
124-
125-
@pytest.mark.parametrize("is_deltas", [True, False])
126-
def test_densefield_x5_roundtrip(tmp_path, is_deltas):
127-
"""Ensure dense field transforms roundtrip via X5."""
128-
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
129-
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))
130-
131-
xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref)
132-
133-
node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
134-
assert node.type == "nonlinear"
135-
assert node.subtype == "densefield"
136-
assert node.representation == "displacements" if is_deltas else "deformations"
137-
assert node.domain.size == ref.shape
138-
assert node.metadata["GeneratedBy"] == "pytest"
139-
140-
fname = tmp_path / "test.x5"
141-
io.x5.to_filename(fname, [node])
142-
143-
xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")
144-
145-
assert xfm2.reference.shape == ref.shape
146-
assert np.allclose(xfm2.reference.affine, ref.affine)
147-
assert xfm == xfm2
148-
149-
150-
def test_bspline_to_x5(tmp_path):
151-
"""Check BSpline transforms export to X5."""
152-
coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4))
153-
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
154-
155-
xfm = BSplineFieldTransform(coeff, reference=ref)
156-
node = xfm.to_x5(metadata={"tool": "pytest"})
157-
assert node.type == "nonlinear"
158-
assert node.subtype == "bspline"
159-
assert node.representation == "coefficients"
160-
assert node.metadata["tool"] == "pytest"
161-
162-
fname = tmp_path / "bspline.x5"
163-
io.x5.to_filename(fname, [node])
164-
165-
xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5")
166-
assert np.allclose(xfm._coeffs, xfm2._coeffs)
167-
assert xfm2.reference.shape == ref.shape
168-
assert np.allclose(xfm2.reference.affine, ref.affine)
101+
# Interpolating the field should be reasonably similar
102+
np.testing.assert_allclose(dispxfm._field, bsplxfm._field, atol=1e-1, rtol=1e-4)
169103

170104

171105
@pytest.mark.parametrize("is_deltas", [True, False])

nitransforms/tests/test_x5.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
2+
import nibabel as nb
23
import pytest
34
from h5py import File as H5File
45

5-
from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename
6+
from nitransforms.nonlinear import DenseFieldTransform, BSplineFieldTransform
7+
from nitransforms.io.x5 import X5Transform, X5Domain, to_filename, from_filename
68

79

810
def test_x5_transform_defaults():
@@ -75,3 +77,49 @@ def test_from_filename_invalid(tmp_path):
7577

7678
with pytest.raises(TypeError):
7779
from_filename(fname)
80+
81+
82+
@pytest.mark.parametrize("is_deltas", [True, False])
83+
def test_densefield_x5_roundtrip(tmp_path, is_deltas):
84+
"""Ensure dense field transforms roundtrip via X5."""
85+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
86+
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))
87+
88+
xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref)
89+
90+
node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
91+
assert node.type == "nonlinear"
92+
assert node.subtype == "densefield"
93+
assert node.representation == "displacements" if is_deltas else "deformations"
94+
assert node.domain.size == ref.shape
95+
assert node.metadata["GeneratedBy"] == "pytest"
96+
97+
fname = tmp_path / "test.x5"
98+
to_filename(fname, [node])
99+
100+
xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")
101+
102+
assert xfm2.reference.shape == ref.shape
103+
assert np.allclose(xfm2.reference.affine, ref.affine)
104+
assert xfm == xfm2
105+
106+
107+
def test_bspline_to_x5(tmp_path):
108+
"""Check BSpline transforms export to X5."""
109+
coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4))
110+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
111+
112+
xfm = BSplineFieldTransform(coeff, reference=ref)
113+
node = xfm.to_x5(metadata={"tool": "pytest"})
114+
assert node.type == "nonlinear"
115+
assert node.subtype == "bspline"
116+
assert node.representation == "coefficients"
117+
assert node.metadata["tool"] == "pytest"
118+
119+
fname = tmp_path / "bspline.x5"
120+
to_filename(fname, [node])
121+
122+
xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5")
123+
assert np.allclose(xfm._coeffs, xfm2._coeffs)
124+
assert xfm2.reference.shape == ref.shape
125+
assert np.allclose(xfm2.reference.affine, ref.affine)

0 commit comments

Comments
 (0)