Skip to content

Commit a03c51f

Browse files
authored
Merge pull request #253 from nipy/codex/add-support-for-x5-transform-chains
ENH: X5 read/write support of ``TransformChain``
2 parents fc6fbbd + 73f3db1 commit a03c51f

File tree

6 files changed

+260
-71
lines changed

6 files changed

+260
-71
lines changed

nitransforms/io/x5.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import numpy as np
2727

2828

29-
@dataclass
29+
@dataclass(eq=True)
3030
class X5Domain:
3131
"""Domain information of a transform representing reference/moving spaces."""
3232

@@ -105,35 +105,7 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
105105
tg = out_file.create_group("TransformGroup")
106106
for i, node in enumerate(x5_list):
107107
g = tg.create_group(str(i))
108-
g.attrs["Type"] = node.type
109-
g.attrs["ArrayLength"] = node.array_length
110-
if node.subtype is not None:
111-
g.attrs["SubType"] = node.subtype
112-
if node.representation is not None:
113-
g.attrs["Representation"] = node.representation
114-
if node.metadata is not None:
115-
g.attrs["Metadata"] = json.dumps(node.metadata)
116-
g.create_dataset("Transform", data=node.transform)
117-
g.create_dataset(
118-
"DimensionKinds",
119-
data=np.asarray(node.dimension_kinds, dtype="S"),
120-
)
121-
if node.domain is not None:
122-
dgrp = g.create_group("Domain")
123-
dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0))
124-
dgrp.create_dataset("Size", data=np.asarray(node.domain.size))
125-
dgrp.create_dataset("Mapping", data=node.domain.mapping)
126-
if node.domain.coordinates is not None:
127-
dgrp.attrs["Coordinates"] = node.domain.coordinates
128-
129-
if node.inverse is not None:
130-
g.create_dataset("Inverse", data=node.inverse)
131-
if node.jacobian is not None:
132-
g.create_dataset("Jacobian", data=node.jacobian)
133-
if node.additional_parameters is not None:
134-
g.create_dataset(
135-
"AdditionalParameters", data=node.additional_parameters
136-
)
108+
_write_x5_group(g, node)
137109
return fname
138110

139111

@@ -188,3 +160,30 @@ def _read_x5_group(node) -> X5Transform:
188160
)
189161

190162
return x5
163+
164+
165+
def _write_x5_group(g, node: X5Transform):
166+
"""Write one :class:`X5Transform` element into an opened HDF5 group."""
167+
g.attrs["Type"] = node.type
168+
g.attrs["ArrayLength"] = node.array_length
169+
if node.subtype is not None:
170+
g.attrs["SubType"] = node.subtype
171+
if node.representation is not None:
172+
g.attrs["Representation"] = node.representation
173+
if node.metadata is not None:
174+
g.attrs["Metadata"] = json.dumps(node.metadata)
175+
g.create_dataset("Transform", data=node.transform)
176+
g.create_dataset("DimensionKinds", data=np.asarray(node.dimension_kinds, dtype="S"))
177+
if node.domain is not None:
178+
dgrp = g.create_group("Domain")
179+
dgrp.create_dataset("Grid", data=np.uint8(1 if node.domain.grid else 0))
180+
dgrp.create_dataset("Size", data=np.asarray(node.domain.size))
181+
dgrp.create_dataset("Mapping", data=node.domain.mapping)
182+
if node.domain.coordinates is not None:
183+
dgrp.attrs["Coordinates"] = node.domain.coordinates
184+
if node.inverse is not None:
185+
g.create_dataset("Inverse", data=node.inverse)
186+
if node.jacobian is not None:
187+
g.create_dataset("Jacobian", data=node.jacobian)
188+
if node.additional_parameters is not None:
189+
g.create_dataset("AdditionalParameters", data=node.additional_parameters)

nitransforms/linear.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,15 @@ def __eq__(self, other):
120120
>>> xfm2 = Affine(xfm1.matrix)
121121
>>> xfm1 == xfm2
122122
True
123+
>>> xfm1 == Affine()
124+
False
125+
>>> xfm1 == TransformBase()
126+
False
123127
124128
"""
129+
if not hasattr(other, "matrix"):
130+
return False
131+
125132
_eq = np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL)
126133
if _eq and self._reference != other._reference:
127134
warnings.warn("Affines are equal, but references do not match.")
@@ -186,22 +193,9 @@ def from_filename(
186193
"""Create an affine from a transform file."""
187194

188195
if fmt and fmt.upper() == "X5":
189-
x5_xfm = load_x5(filename)[x5_position]
190-
Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping
191-
if (
192-
x5_xfm.domain
193-
and not x5_xfm.domain.grid
194-
and len(x5_xfm.domain.size) == 3
195-
): # pragma: no cover
196-
raise NotImplementedError(
197-
"Only 3D regularly gridded domains are supported"
198-
)
199-
elif x5_xfm.domain:
200-
# Override reference
201-
Domain = namedtuple("Domain", "affine shape")
202-
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
203-
204-
return Transform(x5_xfm.transform, reference=reference)
196+
return from_x5(
197+
load_x5(filename), reference=reference, x5_position=x5_position
198+
)
205199

206200
fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")
207201

@@ -458,3 +452,20 @@ def load(filename, fmt=None, reference=None, moving=None):
458452
xfm = xfm[0]
459453

460454
return xfm
455+
456+
457+
def from_x5(x5_list, reference=None, x5_position=0):
458+
"""Create an affine from a list of :class:`~nitransforms.io.x5.X5Transform` objects."""
459+
460+
x5_xfm = x5_list[x5_position]
461+
Transform = Affine if x5_xfm.array_length == 1 else LinearTransformsMapping
462+
if (
463+
x5_xfm.domain and not x5_xfm.domain.grid and len(x5_xfm.domain.size) == 3
464+
): # pragma: no cover
465+
raise NotImplementedError("Only 3D regularly gridded domains are supported")
466+
elif x5_xfm.domain:
467+
# Override reference
468+
Domain = namedtuple("Domain", "affine shape")
469+
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
470+
471+
return Transform(x5_xfm.transform, reference=reference)

nitransforms/manip.py

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,26 @@
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Common interface for transforms."""
10+
11+
import os
1012
from collections.abc import Iterable
1113
import numpy as np
1214

13-
from .base import (
15+
import h5py
16+
from nitransforms.base import (
1417
TransformBase,
1518
TransformError,
1619
)
17-
from .linear import Affine
18-
from .nonlinear import DenseFieldTransform
20+
from nitransforms.io import itk, x5 as x5io
21+
from nitransforms.io.x5 import from_filename as load_x5
22+
from nitransforms.linear import ( # noqa: F401
23+
Affine,
24+
from_x5 as linear_from_x5,
25+
)
26+
from nitransforms.nonlinear import ( # noqa: F401
27+
DenseFieldTransform,
28+
from_x5 as nonlinear_from_x5,
29+
)
1930

2031

2132
class TransformChain(TransformBase):
@@ -183,18 +194,42 @@ def asaffine(self, indices=None):
183194
The indices of the values to extract.
184195
185196
"""
186-
affines = self.transforms if indices is None else np.take(self.transforms, indices)
197+
affines = (
198+
self.transforms if indices is None else np.take(self.transforms, indices)
199+
)
187200
retval = affines[0]
188201
for xfm in affines[1:]:
189202
retval = xfm @ retval
190203
return retval
191204

192205
@classmethod
193-
def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
206+
def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain=0):
194207
"""Load a transform file."""
195-
from .io import itk
196208

197209
retval = []
210+
if fmt and fmt.upper() == "X5":
211+
# Get list of X5 nodes and generate transforms
212+
xfm_list = [
213+
globals()[f"{node.type}_from_x5"]([node]) for node in load_x5(filename)
214+
]
215+
if not xfm_list:
216+
raise TransformError("Empty transform group")
217+
218+
if x5_chain is None:
219+
return xfm_list
220+
221+
with h5py.File(str(filename), "r") as f:
222+
chain_grp = f.get("TransformChain")
223+
if chain_grp is None:
224+
raise TransformError("X5 file contains no TransformChain")
225+
226+
chain_path = chain_grp[str(x5_chain)][()]
227+
chain_path = (
228+
chain_path.decode() if isinstance(chain_path, bytes) else chain_path
229+
)
230+
231+
return TransformChain([xfm_list[int(idx)] for idx in chain_path.split("/")])
232+
198233
if str(filename).endswith(".h5"):
199234
reference = None
200235
xforms = itk.ITKCompositeH5.from_filename(filename)
@@ -208,6 +243,48 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
208243

209244
raise NotImplementedError
210245

246+
def to_filename(self, filename, fmt="X5"):
247+
"""Store the transform chain in X5 format."""
248+
249+
if fmt.upper() != "X5":
250+
raise NotImplementedError("Only X5 format is supported for chains")
251+
252+
existing = (
253+
self.from_filename(filename, x5_chain=None)
254+
if os.path.exists(filename)
255+
else []
256+
)
257+
258+
xfm_chain = []
259+
new_xfms = []
260+
next_xfm_index = len(existing)
261+
for xfm in self.transforms:
262+
for eidx, existing_xfm in enumerate(existing):
263+
if xfm == existing_xfm:
264+
xfm_chain.append(eidx)
265+
break
266+
else:
267+
xfm_chain.append(next_xfm_index)
268+
new_xfms.append((next_xfm_index, xfm))
269+
existing.append(xfm)
270+
next_xfm_index += 1
271+
272+
mode = "r+" if os.path.exists(filename) else "w"
273+
with h5py.File(str(filename), mode) as f:
274+
if "Format" not in f.attrs:
275+
f.attrs["Format"] = "X5"
276+
f.attrs["Version"] = np.uint16(1)
277+
278+
tg = f.require_group("TransformGroup")
279+
for idx, node in new_xfms:
280+
g = tg.create_group(str(idx))
281+
x5io._write_x5_group(g, node.to_x5())
282+
283+
cg = f.require_group("TransformChain")
284+
cg.create_dataset(str(len(cg)), data="/".join(str(i) for i in xfm_chain))
285+
286+
return filename
287+
211288

212289
def _as_chain(x):
213290
"""Convert a value into a transform chain."""

0 commit comments

Comments
 (0)