Skip to content

Commit 5031627

Browse files
authored
Merge pull request #276 from nipy/fix/re-enable-tests
FIX: ``ImageGrid._coords`` was somehow overwritten + re-enable tests
2 parents 8185fb4 + 434f526 commit 5031627

File tree

6 files changed

+78
-35
lines changed

6 files changed

+78
-35
lines changed

nitransforms/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,14 @@ def ndindex(self):
206206
0:self._shape[0], 0:self._shape[1], 0:self._shape[2]
207207
]
208208
self._ndindex = indexes.reshape((indexes.shape[0], -1)).T
209-
return self._ndindex
209+
return self._ndindex.copy() # Return copies to disallow alteration
210210

211211
@property
212212
def ndcoords(self):
213213
"""List the physical coordinates of this gridded space samples."""
214214
if self._coords is None:
215215
self._coords = self.ras(self.ndindex)
216-
return self._coords
216+
return self._coords.copy() # Return copies to disallow alteration
217217

218218
def ras(self, ijk):
219219
"""Get RAS+ coordinates from input indexes."""

nitransforms/nonlinear.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,34 @@ def __eq__(self, other):
257257
warnings.warn("Fields are equal, but references do not match.")
258258
return _eq
259259

260+
def to_filename(self, filename, fmt="X5", moving=None, x5_inverse=False):
261+
"""Store the transform in the designated format."""
262+
263+
if fmt.upper() == "X5":
264+
raise TypeError("Please use .to_x5()")
265+
266+
field = nb.Nifti1Image(
267+
self._deltas if self.is_deltas else self._field,
268+
self.reference.affine,
269+
None,
270+
)
271+
272+
if fmt.lower() == "afni":
273+
from nitransforms.io.afni import AFNIDisplacementsField as FieldIOType
274+
275+
elif fmt.lower() in ("itk", "ants", "elastix"):
276+
from nitransforms.io.itk import ITKDisplacementsField as FieldIOType
277+
278+
elif fmt.lower() == "fsl":
279+
from nitransforms.io.fsl import FSLDisplacementsField as FieldIOType
280+
281+
else:
282+
raise NotImplementedError(
283+
f"Dense field of type '{fmt}' cannot be converted."
284+
)
285+
286+
FieldIOType.to_image(field).to_filename(filename)
287+
260288
def to_x5(self, metadata=None):
261289
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
262290
metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})

nitransforms/resampling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def apply(
254254

255255
targets = None
256256
ref_ndcoords = _ref.ndcoords
257+
257258
if hasattr(transform, "to_field") and callable(transform.to_field):
258259
targets = ImageGrid(spatialimage).index(
259260
_as_homogeneous(

nitransforms/tests/test_base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,15 @@ def test_ImageGrid(get_testdata, image_orientation):
6565
assert idxs.shape[1] == coords.shape[1] == img.ndim == 3
6666
assert idxs.shape[0] == coords.shape[0] == img.npoints == np.prod(im.shape)
6767

68+
# Test indexing round trip
69+
np.testing.assert_allclose(coords, img.ras(idxs))
70+
np.testing.assert_allclose(idxs, img.index(coords), rtol=1e-3, atol=1e-3)
71+
72+
# Test equality
6873
img2 = ImageGrid(img)
6974
assert img2 == img
7075
assert (img2 != img) is False
7176

72-
# Test indexing round trip
73-
np.testing.assert_allclose(img.ndcoords, img.ras(img.ndindex))
74-
np.testing.assert_allclose(img.ndindex, np.round(img.index(img.ndcoords)))
75-
7677

7778
def test_ImageGrid_utils(tmpdir, testdata_path, get_testdata):
7879
"""Check that images can be objects or paths and equality."""

nitransforms/tests/test_nonlinear.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,30 @@ def test_displacements_init():
4040
)
4141

4242

43+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
44+
@pytest.mark.parametrize("axis", [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
45+
def test_displacements_to_filename(tmp_path, get_testdata, image_orientation, axis):
46+
"""Exercise to_filename."""
47+
48+
nii = get_testdata[image_orientation]
49+
fieldmap = np.zeros((*nii.shape[:3], 3), dtype="float32")
50+
fieldmap[..., axis] = -10.0
51+
52+
xfm = DenseFieldTransform(
53+
fieldmap,
54+
reference=nii,
55+
)
56+
xfm.to_filename(tmp_path / "warp_itk.nii.gz", fmt="itk")
57+
xfm.to_filename(tmp_path / "warp_afni.nii.gz", fmt="afni")
58+
xfm.to_filename(tmp_path / "warp_fsl.nii.gz", fmt="fsl")
59+
60+
with pytest.raises(NotImplementedError):
61+
xfm.to_filename(tmp_path / "warp_freesurfer.nii.gz", fmt="fs")
62+
63+
with pytest.raises(TypeError):
64+
xfm.to_filename(tmp_path / "warp.x5", fmt="X5")
65+
66+
4367
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 2, 3), (20, 20, 20, 1, 4)])
4468
def test_displacements_bad_sizes(size):
4569
"""Checks field sizes."""

nitransforms/tests/test_resampling.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,6 @@ def test_apply_linear_transform(
149149
assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL_LINEAR
150150

151151

152-
@pytest.mark.xfail(
153-
reason="Disable while #266 is developed.",
154-
strict=False,
155-
)
156152
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
157153
@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
158154
@pytest.mark.parametrize("axis", [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
@@ -174,29 +170,24 @@ def test_displacements_field1(
174170
nii.to_filename("reference.nii.gz")
175171
msk.to_filename("mask.nii.gz")
176172

177-
fieldmap = np.zeros(
178-
(*nii.shape[:3], 1, 3) if sw_tool != "fsl" else (*nii.shape[:3], 3),
179-
dtype="float32",
180-
)
173+
fieldmap = np.zeros((*nii.shape[:3], 3), dtype="float32")
181174
fieldmap[..., axis] = -10.0
182175

183-
_hdr = nii.header.copy()
184-
if sw_tool in ("itk",):
185-
_hdr.set_intent("vector")
186-
_hdr.set_data_dtype("float32")
187-
176+
# Generate a transform file for the particular software
188177
xfm_fname = "warp.nii.gz"
189-
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
190-
field.to_filename(xfm_fname)
191-
192-
xfm = nitnl.load(xfm_fname, fmt=sw_tool)
178+
xfm = nitnl.DenseFieldTransform(
179+
fieldmap,
180+
reference=nii,
181+
)
182+
xfm.to_filename(xfm_fname, fmt=sw_tool)
193183

184+
tool_output = tmp_path / f"{sw_tool}_brainmask.nii.gz"
194185
# Then apply the transform and cross-check with software
195186
cmd = APPLY_NONLINEAR_CMD[sw_tool](
196187
transform=os.path.abspath(xfm_fname),
197188
reference=tmp_path / "mask.nii.gz",
198189
moving=tmp_path / "mask.nii.gz",
199-
output=tmp_path / "resampled_brainmask.nii.gz",
190+
output=tool_output,
200191
extra="--output-data-type uchar" if sw_tool == "itk" else "",
201192
)
202193

@@ -208,26 +199,28 @@ def test_displacements_field1(
208199
# resample mask
209200
exit_code = check_call([cmd], shell=True)
210201
assert exit_code == 0
211-
sw_moved_mask = nb.load("resampled_brainmask.nii.gz")
202+
sw_moved_mask = np.asanyarray(nb.load(tool_output).dataobj, dtype=bool)
212203
nt_moved_mask = apply(xfm, msk, order=0)
213-
nt_moved_mask.set_data_dtype(msk.get_data_dtype())
214-
diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj)
215-
216-
assert np.sqrt((diff**2).mean()) < RMSE_TOL_LINEAR
204+
nt_moved_mask.to_filename(tmp_path / "nit_brainmask.nii.gz")
217205
brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool)
206+
percent_diff = (sw_moved_mask != brainmask)[5:-5, 5:-5, 5:-5].sum() / brainmask.size
207+
208+
assert percent_diff < 1e-8, (
209+
f"Resampled masks differed by {percent_diff * 100:0.2f}%."
210+
)
218211

219212
# Then apply the transform and cross-check with software
220213
cmd = APPLY_NONLINEAR_CMD[sw_tool](
221214
transform=os.path.abspath(xfm_fname),
222215
reference=tmp_path / "reference.nii.gz",
223216
moving=tmp_path / "reference.nii.gz",
224-
output=tmp_path / "resampled.nii.gz",
217+
output=tmp_path / f"{sw_tool}_resampled.nii.gz",
225218
extra="--output-data-type uchar" if sw_tool == "itk" else "",
226219
)
227220

228221
exit_code = check_call([cmd], shell=True)
229222
assert exit_code == 0
230-
sw_moved = nb.load("resampled.nii.gz")
223+
sw_moved = nb.load(f"{sw_tool}_resampled.nii.gz")
231224

232225
nt_moved = apply(xfm, nii, order=0)
233226
nt_moved.set_data_dtype(nii.get_data_dtype())
@@ -240,10 +233,6 @@ def test_displacements_field1(
240233
assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL_LINEAR
241234

242235

243-
@pytest.mark.xfail(
244-
reason="Disable while #266 is developed.",
245-
strict=False,
246-
)
247236
@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
248237
def test_displacements_field2(tmp_path, testdata_path, sw_tool):
249238
"""Check a translation-only field on one or more axes, different image orientations."""

0 commit comments

Comments
 (0)