Skip to content

Commit 896d46b

Browse files
committed
Add nonlinear X5 tests
1 parent 4662d29 commit 896d46b

File tree

3 files changed

+120
-2
lines changed

3 files changed

+120
-2
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ A new major release with critical updates.
44
The new release includes a critical hotfix for 4D resamplings.
55
The second major improvement is the inclusion of a first implementation of the X5 format (BIDS).
66
The X5 implementation is currently restricted to reading/writing of linear transforms.
7+
It now supports nonlinear transforms as well.
78

89
CHANGES
910
-------
1011
* FIX: Broken 4D resampling by @oesteban in https://github.com/nipy/nitransforms/pull/247
1112
* ENH: Loading of X5 (linear) transforms by @oesteban in https://github.com/nipy/nitransforms/pull/243
1213
* ENH: Implement X5 representation and output to filesystem by @oesteban in https://github.com/nipy/nitransforms/pull/241
14+
* ENH: Support reading and writing of nonlinear transforms in X5
1315
* DOC: Fix references to ``os.PathLike`` by @oesteban in https://github.com/nipy/nitransforms/pull/242
1416
* MNT: Increase coverage by testing edge cases and adding docstrings by @oesteban in https://github.com/nipy/nitransforms/pull/248
1517
* MNT: Refactor io/lta to reduce one partial line by @oesteban in https://github.com/nipy/nitransforms/pull/246

nitransforms/nonlinear.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""Nonlinear transforms."""
1010
import warnings
1111
from functools import partial
12+
from collections import namedtuple
1213
import numpy as np
1314

1415
from nitransforms import io
@@ -227,17 +228,54 @@ def __eq__(self, other):
227228
warnings.warn("Fields are equal, but references do not match.")
228229
return _eq
229230

231+
def to_x5(self, metadata=None):
232+
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
233+
from ._version import __version__
234+
from .io.x5 import X5Domain, X5Transform
235+
236+
metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})
237+
238+
domain = None
239+
if (reference := self.reference) is not None:
240+
domain = X5Domain(
241+
grid=True,
242+
size=getattr(reference, "shape", (0, 0, 0)),
243+
mapping=reference.affine,
244+
coordinates="cartesian",
245+
)
246+
247+
kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)
248+
249+
return X5Transform(
250+
type="nonlinear",
251+
subtype="densefield",
252+
representation="displacements",
253+
metadata=metadata,
254+
transform=self._deltas,
255+
dimension_kinds=kinds,
256+
domain=domain,
257+
)
258+
230259
@classmethod
231260
def from_filename(cls, filename, fmt="X5"):
232261
_factory = {
233262
"afni": io.afni.AFNIDisplacementsField,
234263
"itk": io.itk.ITKDisplacementsField,
235264
"fsl": io.fsl.FSLDisplacementsField,
265+
"X5": None,
236266
}
237-
if fmt not in _factory:
267+
fmt = fmt.upper()
268+
if fmt not in {k.upper() for k in _factory}:
238269
raise NotImplementedError(f"Unsupported format <{fmt}>")
239270

240-
return cls(_factory[fmt].from_filename(filename))
271+
if fmt == "X5":
272+
from .io.x5 import from_filename as load_x5
273+
x5_xfm = load_x5(filename)[0]
274+
Domain = namedtuple("Domain", "affine shape")
275+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
276+
return cls(x5_xfm.transform, is_deltas=True, reference=reference)
277+
278+
return cls(_factory[fmt.lower()].from_filename(filename))
241279

242280

243281
load = DenseFieldTransform.from_filename
@@ -293,6 +331,39 @@ def to_field(self, reference=None, dtype="float32"):
293331
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
294332
)
295333

334+
def to_x5(self, metadata=None):
335+
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
336+
from ._version import __version__
337+
from .io.x5 import X5Transform, X5Domain
338+
339+
metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})
340+
341+
domain = None
342+
if (reference := self.reference) is not None:
343+
domain = X5Domain(
344+
grid=True,
345+
size=getattr(reference, "shape", (0, 0, 0)),
346+
mapping=reference.affine,
347+
coordinates="cartesian",
348+
)
349+
350+
meta = metadata | {
351+
"KnotsAffine": self._knots.affine.tolist(),
352+
"KnotsShape": self._knots.shape,
353+
}
354+
355+
kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)
356+
357+
return X5Transform(
358+
type="nonlinear",
359+
subtype="bspline",
360+
representation="coefficients",
361+
metadata=meta,
362+
transform=self._coeffs,
363+
dimension_kinds=kinds,
364+
domain=domain,
365+
)
366+
296367
def map(self, x, inverse=False):
297368
r"""
298369
Apply the transformation to a list of physical coordinate points.

nitransforms/tests/test_nonlinear.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
BSplineFieldTransform,
1515
DenseFieldTransform,
1616
)
17+
from nitransforms import io
1718
from ..io.itk import ITKDisplacementsField
1819

1920

@@ -119,3 +120,47 @@ def test_bspline(tmp_path, testdata_path):
119120
).mean()
120121
< 0.2
121122
)
123+
124+
125+
def test_densefield_x5_roundtrip(tmp_path):
126+
"""Ensure dense field transforms roundtrip via X5."""
127+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
128+
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))
129+
130+
xfm = DenseFieldTransform(disp, reference=ref)
131+
132+
node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
133+
assert node.type == "nonlinear"
134+
assert node.subtype == "densefield"
135+
assert node.representation == "displacements"
136+
assert node.domain.size == ref.shape
137+
assert node.metadata["GeneratedBy"] == "pytest"
138+
139+
fname = tmp_path / "test.x5"
140+
io.x5.to_filename(fname, [node])
141+
142+
xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")
143+
diff = xfm2._deltas - xfm._deltas
144+
coords = xfm.reference.ndcoords.T.reshape(xfm._deltas.shape)
145+
assert np.allclose(diff, coords)
146+
assert xfm2.reference.shape == ref.shape
147+
assert np.allclose(xfm2.reference.affine, ref.affine)
148+
149+
150+
def test_bspline_to_x5(tmp_path):
151+
"""Check BSpline transforms export to X5."""
152+
coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4))
153+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
154+
155+
xfm = BSplineFieldTransform(coeff, reference=ref)
156+
node = xfm.to_x5(metadata={"tool": "pytest"})
157+
assert node.type == "nonlinear"
158+
assert node.subtype == "bspline"
159+
assert node.representation == "coefficients"
160+
assert node.metadata["tool"] == "pytest"
161+
162+
fname = tmp_path / "bspline.x5"
163+
io.x5.to_filename(fname, [node])
164+
node2 = io.x5.from_filename(fname)[0]
165+
assert np.allclose(node2.transform, node.transform)
166+
assert node2.metadata["tool"] == "pytest"

0 commit comments

Comments
 (0)