Skip to content

Commit 0b13408

Browse files
committed
enh: read X5 transform files
1 parent 33d4af8 commit 0b13408

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

nitransforms/io/x5.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import numpy as np
2727

28+
from .base import TransformFileError
29+
2830

2931
@dataclass
3032
class X5Domain:
@@ -136,3 +138,50 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
136138
# "AdditionalParameters", data=node.additional_parameters
137139
# )
138140
return fname
141+
142+
143+
def from_filename(fname: str | Path) -> List[X5Transform]:
144+
"""Read a list of :class:`X5Transform` objects from an X5 HDF5 file."""
145+
try:
146+
with h5py.File(str(fname), "r") as in_file:
147+
if in_file.attrs.get("Format") != "X5":
148+
raise TransformFileError("Input file is not in X5 format")
149+
150+
tg = in_file["TransformGroup"]
151+
return [
152+
_read_x5_group(node)
153+
for _, node in sorted(tg.items(), key=lambda kv: int(kv[0]))
154+
]
155+
except OSError as exc: # pragma: no cover - in case h5py not installed
156+
raise TransformFileError(str(exc)) from exc
157+
158+
159+
def _read_x5_group(node) -> X5Transform:
160+
x5 = X5Transform(
161+
type=node.attrs["Type"],
162+
transform=np.asarray(node["Transform"]),
163+
subtype=node.attrs.get("SubType"),
164+
representation=node.attrs.get("Representation"),
165+
metadata=json.loads(node.attrs["Metadata"])
166+
if "Metadata" in node.attrs
167+
else None,
168+
dimension_kinds=[
169+
k.decode() if isinstance(k, bytes) else k
170+
for k in node["DimensionKinds"][()]
171+
],
172+
domain=None,
173+
inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None,
174+
jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None,
175+
array_length=int(node.attrs.get("ArrayLength", 1)),
176+
)
177+
178+
if "Domain" in node:
179+
dgrp = node["Domain"]
180+
x5.domain = X5Domain(
181+
grid=bool(int(np.asarray(dgrp["Grid"]))),
182+
size=tuple(np.asarray(dgrp["Size"])),
183+
mapping=np.asarray(dgrp["Mapping"]),
184+
coordinates=dgrp.attrs.get("Coordinates"),
185+
)
186+
187+
return x5

0 commit comments

Comments
 (0)