diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index 463a1336..a86a8554 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -136,3 +136,53 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]): # "AdditionalParameters", data=node.additional_parameters # ) return fname + + +def from_filename(fname: str | Path) -> List[X5Transform]: + """Read a list of :class:`X5Transform` objects from an X5 HDF5 file.""" + try: + with h5py.File(str(fname), "r") as in_file: + if in_file.attrs.get("Format") != "X5": + raise TypeError("Input file is not in X5 format") + + tg = in_file["TransformGroup"] + return [ + _read_x5_group(node) + for _, node in sorted(tg.items(), key=lambda kv: int(kv[0])) + ] + except OSError as err: + if "file signature not found" in err.args[0]: + raise TypeError("Input file is not HDF5.") + + raise # pragma: no cover + + +def _read_x5_group(node) -> X5Transform: + x5 = X5Transform( + type=node.attrs["Type"], + transform=np.asarray(node["Transform"]), + subtype=node.attrs.get("SubType"), + representation=node.attrs.get("Representation"), + metadata=json.loads(node.attrs["Metadata"]) + if "Metadata" in node.attrs + else None, + dimension_kinds=[ + k.decode() if isinstance(k, bytes) else k + for k in node["DimensionKinds"][()] + ], + domain=None, + inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None, + jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None, + array_length=int(node.attrs.get("ArrayLength", 1)), + ) + + if "Domain" in node: + dgrp = node["Domain"] + x5.domain = X5Domain( + grid=bool(int(np.asarray(dgrp["Grid"]))), + size=tuple(np.asarray(dgrp["Size"])), + mapping=np.asarray(dgrp["Mapping"]), + coordinates=dgrp.attrs.get("Coordinates"), + ) + + return x5 diff --git a/nitransforms/linear.py b/nitransforms/linear.py index cf8f8465..8797a1c8 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -9,6 +9,7 @@ """Linear transforms.""" import warnings +from collections import namedtuple import numpy as np from pathlib import Path @@ -27,7 +28,12 @@ EQUALITY_TOL, ) from nitransforms.io import get_linear_factory, TransformFileError -from nitransforms.io.x5 import X5Transform, X5Domain, to_filename as save_x5 +from nitransforms.io.x5 import ( + X5Transform, + X5Domain, + to_filename as save_x5, + from_filename as load_x5, +) class Affine(TransformBase): @@ -174,8 +180,29 @@ def ndim(self): return self._matrix.ndim + 1 @classmethod - def from_filename(cls, filename, fmt=None, reference=None, moving=None): + def from_filename( + cls, filename, fmt=None, reference=None, moving=None, x5_position=0 + ): """Create an affine from a transform file.""" + + if fmt and fmt.upper() == "X5": + x5_xfm = load_x5(filename)[x5_position] + Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping + if ( + x5_xfm.domain + and not x5_xfm.domain.grid + and len(x5_xfm.domain.size) == 3 + ): # pragma: no cover + raise NotImplementedError( + "Only 3D regularly gridded domains are supported" + ) + elif x5_xfm.domain: + # Override reference + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + + return Transform(x5_xfm.transform, reference=reference) + fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl") if fmt is not None and not Path(filename).exists(): @@ -265,7 +292,9 @@ def to_filename(self, filename, fmt="X5", moving=None, x5_inverse=False): if fmt.upper() == "X5": return save_x5(filename, [self.to_x5(store_inverse=x5_inverse)]) - writer = get_linear_factory(fmt, is_array=isinstance(self, LinearTransformsMapping)) + writer = get_linear_factory( + fmt, is_array=isinstance(self, LinearTransformsMapping) + ) if fmt.lower() in ("itk", "ants", "elastix"): writer.from_ras(self.matrix).to_filename(filename) diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 32634c61..d1e5e47e 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -265,6 +265,9 @@ def test_linear_to_x5(tmpdir, store_inverse): aff.to_filename("export1.x5", x5_inverse=store_inverse) + # Test round trip + assert aff == nitl.Affine.from_filename("export1.x5", fmt="X5") + # Test with Domain img = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="float32"), np.eye(4)) img_path = Path(tmpdir) / "ref.nii.gz" @@ -275,21 +278,32 @@ def test_linear_to_x5(tmpdir, store_inverse): assert node.domain.size == aff.reference.shape aff.to_filename("export2.x5", x5_inverse=store_inverse) + # Test round trip + assert aff == nitl.Affine.from_filename("export2.x5", fmt="X5") + # Test with Jacobian node.jacobian = np.zeros((2, 2, 2), dtype="float32") io.x5.to_filename("export3.x5", [node]) -def test_mapping_to_x5(): +@pytest.mark.parametrize("store_inverse", [True, False]) +def test_mapping_to_x5(tmp_path, store_inverse): mats = [ np.eye(4), np.array([[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]]), ] mapping = nitl.LinearTransformsMapping(mats) - node = mapping.to_x5() + node = mapping.to_x5( + metadata={"GeneratedBy": "FreeSurfer 8"}, store_inverse=store_inverse + ) assert node.array_length == 2 assert node.transform.shape == (2, 4, 4) + mapping.to_filename(tmp_path / "export1.x5", x5_inverse=store_inverse) + + # Test round trip + assert mapping == nitl.Affine.from_filename(tmp_path / "export1.x5", fmt="X5") + def test_mulmat_operator(testdata_path): """Check the @ operator.""" diff --git a/nitransforms/tests/test_x5.py b/nitransforms/tests/test_x5.py index 8502a387..89b49e06 100644 --- a/nitransforms/tests/test_x5.py +++ b/nitransforms/tests/test_x5.py @@ -1,7 +1,8 @@ import numpy as np +import pytest from h5py import File as H5File -from ..io.x5 import X5Transform, X5Domain, to_filename +from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename def test_x5_transform_defaults(): @@ -39,3 +40,38 @@ def test_to_filename(tmp_path): assert "0" in grp assert grp["0"].attrs["Type"] == "linear" assert grp["0"].attrs["ArrayLength"] == 1 + + +def test_from_filename_roundtrip(tmp_path): + domain = X5Domain(grid=False, size=(5, 5, 5), mapping=np.eye(4)) + node = X5Transform( + type="linear", + transform=np.eye(4), + dimension_kinds=("space", "space", "space", "vector"), + domain=domain, + metadata={"foo": "bar"}, + inverse=np.eye(4), + ) + fname = tmp_path / "test.x5" + to_filename(fname, [node]) + + x5_list = from_filename(fname) + assert len(x5_list) == 1 + x5 = x5_list[0] + assert x5.type == node.type + assert np.allclose(x5.transform, node.transform) + assert x5.dimension_kinds == list(node.dimension_kinds) + assert x5.domain.grid == domain.grid + assert x5.domain.size == tuple(domain.size) + assert np.allclose(x5.domain.mapping, domain.mapping) + assert x5.metadata == node.metadata + assert np.allclose(x5.inverse, node.inverse) + + +def test_from_filename_invalid(tmp_path): + fname = tmp_path / "invalid.h5" + with H5File(fname, "w") as f: + f.attrs["Format"] = "NOTX5" + + with pytest.raises(TypeError): + from_filename(fname)