Skip to content

Commit 82f58c1

Browse files
authored
Merge pull request #243 from nipy/codex/add-support-for-affine-and-lineartransformsmapping
ENH: Loading of X5 (linear) transforms
2 parents 33d91ad + b02077e commit 82f58c1

File tree

4 files changed

+135
-6
lines changed

4 files changed

+135
-6
lines changed

nitransforms/io/x5.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,53 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
136136
# "AdditionalParameters", data=node.additional_parameters
137137
# )
138138
return fname
139+
140+
141+
def from_filename(fname: str | Path) -> List[X5Transform]:
142+
"""Read a list of :class:`X5Transform` objects from an X5 HDF5 file."""
143+
try:
144+
with h5py.File(str(fname), "r") as in_file:
145+
if in_file.attrs.get("Format") != "X5":
146+
raise TypeError("Input file is not in X5 format")
147+
148+
tg = in_file["TransformGroup"]
149+
return [
150+
_read_x5_group(node)
151+
for _, node in sorted(tg.items(), key=lambda kv: int(kv[0]))
152+
]
153+
except OSError as err:
154+
if "file signature not found" in err.args[0]:
155+
raise TypeError("Input file is not HDF5.")
156+
157+
raise # pragma: no cover
158+
159+
160+
def _read_x5_group(node) -> X5Transform:
161+
x5 = X5Transform(
162+
type=node.attrs["Type"],
163+
transform=np.asarray(node["Transform"]),
164+
subtype=node.attrs.get("SubType"),
165+
representation=node.attrs.get("Representation"),
166+
metadata=json.loads(node.attrs["Metadata"])
167+
if "Metadata" in node.attrs
168+
else None,
169+
dimension_kinds=[
170+
k.decode() if isinstance(k, bytes) else k
171+
for k in node["DimensionKinds"][()]
172+
],
173+
domain=None,
174+
inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None,
175+
jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None,
176+
array_length=int(node.attrs.get("ArrayLength", 1)),
177+
)
178+
179+
if "Domain" in node:
180+
dgrp = node["Domain"]
181+
x5.domain = X5Domain(
182+
grid=bool(int(np.asarray(dgrp["Grid"]))),
183+
size=tuple(np.asarray(dgrp["Size"])),
184+
mapping=np.asarray(dgrp["Mapping"]),
185+
coordinates=dgrp.attrs.get("Coordinates"),
186+
)
187+
188+
return x5

nitransforms/linear.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""Linear transforms."""
1010

1111
import warnings
12+
from collections import namedtuple
1213
import numpy as np
1314
from pathlib import Path
1415

@@ -27,7 +28,12 @@
2728
EQUALITY_TOL,
2829
)
2930
from nitransforms.io import get_linear_factory, TransformFileError
30-
from nitransforms.io.x5 import X5Transform, X5Domain, to_filename as save_x5
31+
from nitransforms.io.x5 import (
32+
X5Transform,
33+
X5Domain,
34+
to_filename as save_x5,
35+
from_filename as load_x5,
36+
)
3137

3238

3339
class Affine(TransformBase):
@@ -174,8 +180,29 @@ def ndim(self):
174180
return self._matrix.ndim + 1
175181

176182
@classmethod
177-
def from_filename(cls, filename, fmt=None, reference=None, moving=None):
183+
def from_filename(
184+
cls, filename, fmt=None, reference=None, moving=None, x5_position=0
185+
):
178186
"""Create an affine from a transform file."""
187+
188+
if fmt and fmt.upper() == "X5":
189+
x5_xfm = load_x5(filename)[x5_position]
190+
Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping
191+
if (
192+
x5_xfm.domain
193+
and not x5_xfm.domain.grid
194+
and len(x5_xfm.domain.size) == 3
195+
): # pragma: no cover
196+
raise NotImplementedError(
197+
"Only 3D regularly gridded domains are supported"
198+
)
199+
elif x5_xfm.domain:
200+
# Override reference
201+
Domain = namedtuple("Domain", "affine shape")
202+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
203+
204+
return Transform(x5_xfm.transform, reference=reference)
205+
179206
fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")
180207

181208
if fmt is not None and not Path(filename).exists():
@@ -265,7 +292,9 @@ def to_filename(self, filename, fmt="X5", moving=None, x5_inverse=False):
265292
if fmt.upper() == "X5":
266293
return save_x5(filename, [self.to_x5(store_inverse=x5_inverse)])
267294

268-
writer = get_linear_factory(fmt, is_array=isinstance(self, LinearTransformsMapping))
295+
writer = get_linear_factory(
296+
fmt, is_array=isinstance(self, LinearTransformsMapping)
297+
)
269298

270299
if fmt.lower() in ("itk", "ants", "elastix"):
271300
writer.from_ras(self.matrix).to_filename(filename)

nitransforms/tests/test_linear.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,9 @@ def test_linear_to_x5(tmpdir, store_inverse):
265265

266266
aff.to_filename("export1.x5", x5_inverse=store_inverse)
267267

268+
# Test round trip
269+
assert aff == nitl.Affine.from_filename("export1.x5", fmt="X5")
270+
268271
# Test with Domain
269272
img = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="float32"), np.eye(4))
270273
img_path = Path(tmpdir) / "ref.nii.gz"
@@ -275,21 +278,32 @@ def test_linear_to_x5(tmpdir, store_inverse):
275278
assert node.domain.size == aff.reference.shape
276279
aff.to_filename("export2.x5", x5_inverse=store_inverse)
277280

281+
# Test round trip
282+
assert aff == nitl.Affine.from_filename("export2.x5", fmt="X5")
283+
278284
# Test with Jacobian
279285
node.jacobian = np.zeros((2, 2, 2), dtype="float32")
280286
io.x5.to_filename("export3.x5", [node])
281287

282288

283-
def test_mapping_to_x5():
289+
@pytest.mark.parametrize("store_inverse", [True, False])
290+
def test_mapping_to_x5(tmp_path, store_inverse):
284291
mats = [
285292
np.eye(4),
286293
np.array([[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]]),
287294
]
288295
mapping = nitl.LinearTransformsMapping(mats)
289-
node = mapping.to_x5()
296+
node = mapping.to_x5(
297+
metadata={"GeneratedBy": "FreeSurfer 8"}, store_inverse=store_inverse
298+
)
290299
assert node.array_length == 2
291300
assert node.transform.shape == (2, 4, 4)
292301

302+
mapping.to_filename(tmp_path / "export1.x5", x5_inverse=store_inverse)
303+
304+
# Test round trip
305+
assert mapping == nitl.Affine.from_filename(tmp_path / "export1.x5", fmt="X5")
306+
293307

294308
def test_mulmat_operator(testdata_path):
295309
"""Check the @ operator."""

nitransforms/tests/test_x5.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
2+
import pytest
23
from h5py import File as H5File
34

4-
from ..io.x5 import X5Transform, X5Domain, to_filename
5+
from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename
56

67

78
def test_x5_transform_defaults():
@@ -39,3 +40,38 @@ def test_to_filename(tmp_path):
3940
assert "0" in grp
4041
assert grp["0"].attrs["Type"] == "linear"
4142
assert grp["0"].attrs["ArrayLength"] == 1
43+
44+
45+
def test_from_filename_roundtrip(tmp_path):
46+
domain = X5Domain(grid=False, size=(5, 5, 5), mapping=np.eye(4))
47+
node = X5Transform(
48+
type="linear",
49+
transform=np.eye(4),
50+
dimension_kinds=("space", "space", "space", "vector"),
51+
domain=domain,
52+
metadata={"foo": "bar"},
53+
inverse=np.eye(4),
54+
)
55+
fname = tmp_path / "test.x5"
56+
to_filename(fname, [node])
57+
58+
x5_list = from_filename(fname)
59+
assert len(x5_list) == 1
60+
x5 = x5_list[0]
61+
assert x5.type == node.type
62+
assert np.allclose(x5.transform, node.transform)
63+
assert x5.dimension_kinds == list(node.dimension_kinds)
64+
assert x5.domain.grid == domain.grid
65+
assert x5.domain.size == tuple(domain.size)
66+
assert np.allclose(x5.domain.mapping, domain.mapping)
67+
assert x5.metadata == node.metadata
68+
assert np.allclose(x5.inverse, node.inverse)
69+
70+
71+
def test_from_filename_invalid(tmp_path):
72+
fname = tmp_path / "invalid.h5"
73+
with H5File(fname, "w") as f:
74+
f.attrs["Format"] = "NOTX5"
75+
76+
with pytest.raises(TypeError):
77+
from_filename(fname)

0 commit comments

Comments
 (0)