Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions nitransforms/io/x5.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,53 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
# "AdditionalParameters", data=node.additional_parameters
# )
return fname


def from_filename(fname: str | Path) -> List[X5Transform]:
"""Read a list of :class:`X5Transform` objects from an X5 HDF5 file."""
try:
with h5py.File(str(fname), "r") as in_file:
if in_file.attrs.get("Format") != "X5":
raise TypeError("Input file is not in X5 format")

tg = in_file["TransformGroup"]
return [
_read_x5_group(node)
for _, node in sorted(tg.items(), key=lambda kv: int(kv[0]))
]
except OSError as err:
if "file signature not found" in err.args[0]:
raise TypeError("Input file is not HDF5.")

raise # pragma: no cover


def _read_x5_group(node) -> X5Transform:
x5 = X5Transform(
type=node.attrs["Type"],
transform=np.asarray(node["Transform"]),
subtype=node.attrs.get("SubType"),
representation=node.attrs.get("Representation"),
metadata=json.loads(node.attrs["Metadata"])
if "Metadata" in node.attrs
else None,
dimension_kinds=[
k.decode() if isinstance(k, bytes) else k
for k in node["DimensionKinds"][()]
],
domain=None,
inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None,
jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None,
array_length=int(node.attrs.get("ArrayLength", 1)),
)

if "Domain" in node:
dgrp = node["Domain"]
x5.domain = X5Domain(
grid=bool(int(np.asarray(dgrp["Grid"]))),
size=tuple(np.asarray(dgrp["Size"])),
mapping=np.asarray(dgrp["Mapping"]),
coordinates=dgrp.attrs.get("Coordinates"),
)

return x5
35 changes: 32 additions & 3 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""Linear transforms."""

import warnings
from collections import namedtuple
import numpy as np
from pathlib import Path

Expand All @@ -27,7 +28,12 @@
EQUALITY_TOL,
)
from nitransforms.io import get_linear_factory, TransformFileError
from nitransforms.io.x5 import X5Transform, X5Domain, to_filename as save_x5
from nitransforms.io.x5 import (
X5Transform,
X5Domain,
to_filename as save_x5,
from_filename as load_x5,
)


class Affine(TransformBase):
Expand Down Expand Up @@ -174,8 +180,29 @@ def ndim(self):
return self._matrix.ndim + 1

@classmethod
def from_filename(cls, filename, fmt=None, reference=None, moving=None):
def from_filename(
cls, filename, fmt=None, reference=None, moving=None, x5_position=0
):
"""Create an affine from a transform file."""

if fmt and fmt.upper() == "X5":
x5_xfm = load_x5(filename)[x5_position]
Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping
if (
x5_xfm.domain
and not x5_xfm.domain.grid
and len(x5_xfm.domain.size) == 3
): # pragma: no cover
raise NotImplementedError(
"Only 3D regularly gridded domains are supported"
)
elif x5_xfm.domain:
# Override reference
Domain = namedtuple("Domain", "affine shape")
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)

return Transform(x5_xfm.transform, reference=reference)

fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")

if fmt is not None and not Path(filename).exists():
Expand Down Expand Up @@ -265,7 +292,9 @@ def to_filename(self, filename, fmt="X5", moving=None, x5_inverse=False):
if fmt.upper() == "X5":
return save_x5(filename, [self.to_x5(store_inverse=x5_inverse)])

writer = get_linear_factory(fmt, is_array=isinstance(self, LinearTransformsMapping))
writer = get_linear_factory(
fmt, is_array=isinstance(self, LinearTransformsMapping)
)

if fmt.lower() in ("itk", "ants", "elastix"):
writer.from_ras(self.matrix).to_filename(filename)
Expand Down
18 changes: 16 additions & 2 deletions nitransforms/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ def test_linear_to_x5(tmpdir, store_inverse):

aff.to_filename("export1.x5", x5_inverse=store_inverse)

# Test round trip
assert aff == nitl.Affine.from_filename("export1.x5", fmt="X5")

# Test with Domain
img = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="float32"), np.eye(4))
img_path = Path(tmpdir) / "ref.nii.gz"
Expand All @@ -275,21 +278,32 @@ def test_linear_to_x5(tmpdir, store_inverse):
assert node.domain.size == aff.reference.shape
aff.to_filename("export2.x5", x5_inverse=store_inverse)

# Test round trip
assert aff == nitl.Affine.from_filename("export2.x5", fmt="X5")

# Test with Jacobian
node.jacobian = np.zeros((2, 2, 2), dtype="float32")
io.x5.to_filename("export3.x5", [node])


def test_mapping_to_x5():
@pytest.mark.parametrize("store_inverse", [True, False])
def test_mapping_to_x5(tmp_path, store_inverse):
mats = [
np.eye(4),
np.array([[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]]),
]
mapping = nitl.LinearTransformsMapping(mats)
node = mapping.to_x5()
node = mapping.to_x5(
metadata={"GeneratedBy": "FreeSurfer 8"}, store_inverse=store_inverse
)
assert node.array_length == 2
assert node.transform.shape == (2, 4, 4)

mapping.to_filename(tmp_path / "export1.x5", x5_inverse=store_inverse)

# Test round trip
assert mapping == nitl.Affine.from_filename(tmp_path / "export1.x5", fmt="X5")


def test_mulmat_operator(testdata_path):
"""Check the @ operator."""
Expand Down
38 changes: 37 additions & 1 deletion nitransforms/tests/test_x5.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import pytest
from h5py import File as H5File

from ..io.x5 import X5Transform, X5Domain, to_filename
from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename


def test_x5_transform_defaults():
Expand Down Expand Up @@ -39,3 +40,38 @@ def test_to_filename(tmp_path):
assert "0" in grp
assert grp["0"].attrs["Type"] == "linear"
assert grp["0"].attrs["ArrayLength"] == 1


def test_from_filename_roundtrip(tmp_path):
domain = X5Domain(grid=False, size=(5, 5, 5), mapping=np.eye(4))
node = X5Transform(
type="linear",
transform=np.eye(4),
dimension_kinds=("space", "space", "space", "vector"),
domain=domain,
metadata={"foo": "bar"},
inverse=np.eye(4),
)
fname = tmp_path / "test.x5"
to_filename(fname, [node])

x5_list = from_filename(fname)
assert len(x5_list) == 1
x5 = x5_list[0]
assert x5.type == node.type
assert np.allclose(x5.transform, node.transform)
assert x5.dimension_kinds == list(node.dimension_kinds)
assert x5.domain.grid == domain.grid
assert x5.domain.size == tuple(domain.size)
assert np.allclose(x5.domain.mapping, domain.mapping)
assert x5.metadata == node.metadata
assert np.allclose(x5.inverse, node.inverse)


def test_from_filename_invalid(tmp_path):
fname = tmp_path / "invalid.h5"
with H5File(fname, "w") as f:
f.attrs["Format"] = "NOTX5"

with pytest.raises(TypeError):
from_filename(fname)
Loading