Skip to content

Commit 3e04551

Browse files
committed
Add nonlinear X5 tests
1 parent 5f051df commit 3e04551

File tree

3 files changed

+120
-2
lines changed

3 files changed

+120
-2
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ A new major release with critical updates.
44
The new release includes a critical hotfix for 4D resamplings.
55
The second major improvement is the inclusion of a first implementation of the X5 format (BIDS).
66
The X5 implementation is currently restricted to reading/writing of linear transforms.
7+
It now supports nonlinear transforms as well.
78

89
CHANGES
910
-------
1011
* FIX: Broken 4D resampling by @oesteban in https://github.com/nipy/nitransforms/pull/247
1112
* ENH: Loading of X5 (linear) transforms by @oesteban in https://github.com/nipy/nitransforms/pull/243
1213
* ENH: Implement X5 representation and output to filesystem by @oesteban in https://github.com/nipy/nitransforms/pull/241
14+
* ENH: Support reading and writing of nonlinear transforms in X5
1315
* DOC: Fix references to ``os.PathLike`` by @oesteban in https://github.com/nipy/nitransforms/pull/242
1416
* MNT: Increase coverage by testing edge cases and adding docstrings by @oesteban in https://github.com/nipy/nitransforms/pull/248
1517
* MNT: Refactor io/lta to reduce one partial line by @oesteban in https://github.com/nipy/nitransforms/pull/246

nitransforms/nonlinear.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""Nonlinear transforms."""
1010
import warnings
1111
from functools import partial
12+
from collections import namedtuple
1213
import numpy as np
1314

1415
from nitransforms import io
@@ -229,17 +230,54 @@ def __eq__(self, other):
229230
warnings.warn("Fields are equal, but references do not match.")
230231
return _eq
231232

233+
def to_x5(self, metadata=None):
234+
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
235+
from ._version import __version__
236+
from .io.x5 import X5Domain, X5Transform
237+
238+
metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})
239+
240+
domain = None
241+
if (reference := self.reference) is not None:
242+
domain = X5Domain(
243+
grid=True,
244+
size=getattr(reference, "shape", (0, 0, 0)),
245+
mapping=reference.affine,
246+
coordinates="cartesian",
247+
)
248+
249+
kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)
250+
251+
return X5Transform(
252+
type="nonlinear",
253+
subtype="densefield",
254+
representation="displacements",
255+
metadata=metadata,
256+
transform=self._deltas,
257+
dimension_kinds=kinds,
258+
domain=domain,
259+
)
260+
232261
@classmethod
233262
def from_filename(cls, filename, fmt="X5"):
234263
_factory = {
235264
"afni": io.afni.AFNIDisplacementsField,
236265
"itk": io.itk.ITKDisplacementsField,
237266
"fsl": io.fsl.FSLDisplacementsField,
267+
"X5": None,
238268
}
239-
if fmt not in _factory:
269+
fmt = fmt.upper()
270+
if fmt not in {k.upper() for k in _factory}:
240271
raise NotImplementedError(f"Unsupported format <{fmt}>")
241272

242-
return cls(_factory[fmt].from_filename(filename))
273+
if fmt == "X5":
274+
from .io.x5 import from_filename as load_x5
275+
x5_xfm = load_x5(filename)[0]
276+
Domain = namedtuple("Domain", "affine shape")
277+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
278+
return cls(x5_xfm.transform, is_deltas=True, reference=reference)
279+
280+
return cls(_factory[fmt.lower()].from_filename(filename))
243281

244282

245283
load = DenseFieldTransform.from_filename
@@ -295,6 +333,39 @@ def to_field(self, reference=None, dtype="float32"):
295333
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
296334
)
297335

336+
def to_x5(self, metadata=None):
337+
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
338+
from ._version import __version__
339+
from .io.x5 import X5Transform, X5Domain
340+
341+
metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})
342+
343+
domain = None
344+
if (reference := self.reference) is not None:
345+
domain = X5Domain(
346+
grid=True,
347+
size=getattr(reference, "shape", (0, 0, 0)),
348+
mapping=reference.affine,
349+
coordinates="cartesian",
350+
)
351+
352+
meta = metadata | {
353+
"KnotsAffine": self._knots.affine.tolist(),
354+
"KnotsShape": self._knots.shape,
355+
}
356+
357+
kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)
358+
359+
return X5Transform(
360+
type="nonlinear",
361+
subtype="bspline",
362+
representation="coefficients",
363+
metadata=meta,
364+
transform=self._coeffs,
365+
dimension_kinds=kinds,
366+
domain=domain,
367+
)
368+
298369
def map(self, x, inverse=False):
299370
r"""
300371
Apply the transformation to a list of physical coordinate points.

nitransforms/tests/test_nonlinear.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
BSplineFieldTransform,
1515
DenseFieldTransform,
1616
)
17+
from nitransforms import io
1718
from ..io.itk import ITKDisplacementsField
1819

1920

@@ -119,3 +120,47 @@ def test_bspline(tmp_path, testdata_path):
119120
).mean()
120121
< 0.2
121122
)
123+
124+
125+
def test_densefield_x5_roundtrip(tmp_path):
126+
"""Ensure dense field transforms roundtrip via X5."""
127+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
128+
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))
129+
130+
xfm = DenseFieldTransform(disp, reference=ref)
131+
132+
node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
133+
assert node.type == "nonlinear"
134+
assert node.subtype == "densefield"
135+
assert node.representation == "displacements"
136+
assert node.domain.size == ref.shape
137+
assert node.metadata["GeneratedBy"] == "pytest"
138+
139+
fname = tmp_path / "test.x5"
140+
io.x5.to_filename(fname, [node])
141+
142+
xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")
143+
diff = xfm2._deltas - xfm._deltas
144+
coords = xfm.reference.ndcoords.T.reshape(xfm._deltas.shape)
145+
assert np.allclose(diff, coords)
146+
assert xfm2.reference.shape == ref.shape
147+
assert np.allclose(xfm2.reference.affine, ref.affine)
148+
149+
150+
def test_bspline_to_x5(tmp_path):
151+
"""Check BSpline transforms export to X5."""
152+
coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4))
153+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
154+
155+
xfm = BSplineFieldTransform(coeff, reference=ref)
156+
node = xfm.to_x5(metadata={"tool": "pytest"})
157+
assert node.type == "nonlinear"
158+
assert node.subtype == "bspline"
159+
assert node.representation == "coefficients"
160+
assert node.metadata["tool"] == "pytest"
161+
162+
fname = tmp_path / "bspline.x5"
163+
io.x5.to_filename(fname, [node])
164+
node2 = io.x5.from_filename(fname)[0]
165+
assert np.allclose(node2.transform, node.transform)
166+
assert node2.metadata["tool"] == "pytest"

0 commit comments

Comments
 (0)