Skip to content

Commit b6a9eec

Browse files
authored
Merge pull request #2 from sgiavasis/enh/x5-implementation
🔧 Loaders and util for H5 and NIfTI transforms
2 parents f51ca65 + e0e5e14 commit b6a9eec

File tree

3 files changed

+209
-65
lines changed

3 files changed

+209
-65
lines changed

nitransforms/base.py

Lines changed: 43 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,48 @@ def __ne__(self, other):
178178
class TransformBase:
179179
"""Abstract image class to represent transforms."""
180180

181-
__slots__ = ("_reference", "_ndim",)
182-
183-
def __init__(self, reference=None):
181+
__slots__ = ("_reference", "_ndim", "_affine", "_shape", "_header",
182+
"_grid", "_mapping", "_hdf5_dct", "_x5_dct")
183+
184+
x5_struct = {
185+
'TransformGroup/0': {
186+
'Type': None,
187+
'Transform': None,
188+
'Metadata': None,
189+
'Inverse': None
190+
},
191+
'TransformGroup/0/Domain': {
192+
'Grid': None,
193+
'Size': None,
194+
'Mapping': None
195+
},
196+
'TransformGroup/1': {},
197+
'TransformChain': {}
198+
}
199+
200+
def __init__(self, x5=None, hdf5=None, nifti=None, shape=None, affine=None,
201+
header=None, reference=None):
184202
"""Instantiate a transform."""
203+
185204
self._reference = None
186205
if reference:
187206
self.reference = reference
188207

208+
if nifti is not None:
209+
self._x5_dct = self.init_x5_structure(nifti)
210+
elif hdf5:
211+
self.update_x5_structure(hdf5)
212+
elif x5:
213+
self.update_x5_structure(x5)
214+
215+
self._shape = shape
216+
self._affine = affine
217+
self._header = header
218+
219+
# TO-DO
220+
self._grid = None
221+
self._mapping = None
222+
189223
def __call__(self, x, inverse=False):
190224
"""Apply y = f(x)."""
191225
return self.map(x, inverse=inverse)
@@ -222,6 +256,12 @@ def ndim(self):
222256
"""Access the dimensions of the reference space."""
223257
raise TypeError("TransformBase has no dimensions")
224258

259+
def init_x5_structure(self, xfm_data=None):
260+
self.x5_struct['TransformGroup/0/Transform'] = xfm_data
261+
262+
def update_x5_structure(self, hdf5_struct=None):
263+
self.x5_struct.update(hdf5_struct)
264+
225265
def apply(
226266
self,
227267
spatialimage,
@@ -338,65 +378,6 @@ def map(self, x, inverse=False):
338378
"""
339379
return x
340380

341-
def to_filename(self, filename, fmt="X5"):
342-
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
343-
with h5py.File(filename, "w") as out_file:
344-
out_file.attrs["Format"] = "X5"
345-
out_file.attrs["Version"] = np.uint16(1)
346-
root = out_file.create_group("/0")
347-
self._to_hdf5(root)
348-
349-
return filename
350-
351-
def _to_hdf5(self, x5_root):
352-
"""Serialize this object into the x5 file format."""
353-
transform_group = x5_root.create_group("TransformGroup")
354-
355-
"""Group '0' containing Affine transform"""
356-
transform_0 = transform_group.create_group("0")
357-
358-
transform_0.attrs["Type"] = "Affine"
359-
transform_0.create_dataset("Transform", data=self._matrix)
360-
transform_0.create_dataset("Inverse", data=np.linalg.inv(self._matrix))
361-
362-
metadata = {"key": "value"}
363-
transform_0.attrs["Metadata"] = str(metadata)
364-
365-
"""sub-group 'Domain' contained within group '0' """
366-
domain_group = transform_0.create_group("Domain")
367-
domain_group.attrs["Grid"] = self.grid
368-
domain_group.create_dataset("Size", data=_as_homogeneous(self._reference.shape))
369-
domain_group.create_dataset("Mapping", data=self.map)
370-
371-
raise NotImplementedError
372-
373-
def read_x5(self, x5_root):
374-
variables = {}
375-
with h5py.File(x5_root, "r") as f:
376-
f.visititems(lambda filename, x5_root: self._from_hdf5(filename, x5_root, variables))
377-
378-
_transform = variables["TransformGroup/0/Transform"]
379-
_inverse = variables["TransformGroup/0/Inverse"]
380-
_size = variables["TransformGroup/0/Domain/Size"]
381-
_map = variables["TransformGroup/0/Domain/Mapping"]
382-
383-
return _transform, _inverse, _size, _map
384-
385-
def _from_hdf5(self, name, x5_root, storage):
386-
if isinstance(x5_root, h5py.Dataset):
387-
storage[name] = {
388-
'type': 'dataset',
389-
'attrs': dict(x5_root.attrs),
390-
'shape': x5_root.shape,
391-
'data': x5_root[()] # Read the data
392-
}
393-
elif isinstance(x5_root, h5py.Group):
394-
storage[name] = {
395-
'type': 'group',
396-
'attrs': dict(x5_root.attrs),
397-
'members': {}
398-
}
399-
400381

401382
def _as_homogeneous(xyz, dtype="float32", dim=3):
402383
"""

nitransforms/cli.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
import os
33
from textwrap import dedent
44

5+
from nitransforms.io.base import xfm_loader
6+
from nitransforms.linear import load as linload
7+
from nitransforms.nonlinear import load as nlinload
58

6-
from .linear import load as linload
7-
from .nonlinear import load as nlinload
9+
from nitransforms.base import TransformBase
810

11+
import pprint
912

1013
def cli_apply(pargs):
1114
"""
@@ -45,9 +48,27 @@ def cli_apply(pargs):
4548
cval=pargs.cval,
4649
prefilter=pargs.prefilter,
4750
)
48-
moved.to_filename(pargs.out or f"nt_{os.path.basename(pargs.moving)}")
51+
#moved.to_filename(pargs.out or f"nt_{os.path.basename(pargs.moving)}")
4952

5053

54+
def cli_xfm_util(pargs):
55+
"""
56+
"""
57+
58+
xfm_data = xfm_loader(pargs.transform)
59+
xfm_x5 = TransformBase(**xfm_data)
60+
61+
if pargs.info:
62+
pprint.pprint(xfm_x5.x5_struct)
63+
print(f"Shape:\n{xfm_x5._shape}")
64+
print(f"Affine:\n{xfm_x5._affine}")
65+
66+
if pargs.x5:
67+
filename = f"{os.path.basename(pargs.transform).split('.')[0]}.x5"
68+
xfm_x5.to_filename(filename)
69+
print(f"Writing out {filename}")
70+
71+
5172
def get_parser():
5273
desc = dedent(
5374
"""
@@ -56,6 +77,7 @@ def get_parser():
5677
Commands:
5778
5879
apply Apply a transformation to an image
80+
xfm_util Assorted transform utilities
5981
6082
For command specific information, use 'nt <command> -h'.
6183
"""
@@ -120,6 +142,17 @@ def _add_subparser(name, description):
120142
help="Determines if the image's data array is prefiltered with a spline filter before "
121143
"interpolation (default: True)",
122144
)
145+
146+
xfm_util = _add_subparser("xfm_util", cli_xfm_util.__doc__)
147+
xfm_util.set_defaults(func=cli_xfm_util)
148+
xfm_util.add_argument("transform", help="The transform file")
149+
xfm_util.add_argument("--info",
150+
action="store_true",
151+
help="Get information about the transform")
152+
xfm_util.add_argument("--x5",
153+
action="store_true",
154+
help="Convert transform to .x5 file format.")
155+
123156
return parser, subparsers
124157

125158

@@ -133,3 +166,7 @@ def main(pargs=None):
133166
subparser = subparsers.choices[pargs.command]
134167
subparser.print_help()
135168
raise (e)
169+
170+
171+
if __name__ == "__main__":
172+
main()

nitransforms/io/base.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,137 @@
11
"""Read/write linear transforms."""
22
from pathlib import Path
33
import numpy as np
4+
import nibabel as nb
45
from nibabel import load as loadimg
56

7+
import h5py
8+
69
from ..patched import LabeledWrapStruct
710

811

12+
def get_xfm_filetype(xfm_file):
13+
path = Path(xfm_file)
14+
ext = path.suffix
15+
if ext == '.gz' and path.name.endswith('.nii.gz'):
16+
return 'nifti'
17+
18+
file_types = {
19+
'.nii': 'nifti',
20+
'.h5': 'hdf5',
21+
'.x5': 'x5',
22+
'.txt': 'txt',
23+
'.mat': 'txt'
24+
}
25+
return file_types.get(ext, 'unknown')
26+
27+
def gather_fields(x5=None, hdf5=None, nifti=None, shape=None, affine=None, header=None):
28+
xfm_fields = {
29+
"x5": x5,
30+
"hdf5": hdf5,
31+
"nifti": nifti,
32+
"header": header,
33+
"shape": shape,
34+
"affine": affine
35+
}
36+
return xfm_fields
37+
38+
def load_nifti(nifti_file):
39+
nifti_xfm = nb.load(nifti_file)
40+
xfm_data = nifti_xfm.get_fdata()
41+
shape = nifti_xfm.shape
42+
affine = nifti_xfm.affine
43+
header = getattr(nifti_xfm, "header", None)
44+
return gather_fields(nifti=xfm_data, shape=shape, affine=affine, header=header)
45+
46+
def load_hdf5(hdf5_file):
47+
storage = {}
48+
49+
def get_hdf5_items(name, x5_root):
50+
if isinstance(x5_root, h5py.Dataset):
51+
storage[name] = {
52+
'type': 'dataset',
53+
'attrs': dict(x5_root.attrs),
54+
'shape': x5_root.shape,
55+
'data': x5_root[()]
56+
}
57+
elif isinstance(x5_root, h5py.Group):
58+
storage[name] = {
59+
'type': 'group',
60+
'attrs': dict(x5_root.attrs),
61+
'members': {}
62+
}
63+
64+
with h5py.File(hdf5_file, 'r') as f:
65+
f.visititems(get_hdf5_items)
66+
if storage:
67+
hdf5_storage = {'hdf5': storage}
68+
return hdf5_storage
69+
70+
def load_x5(x5_file):
71+
load_hdf5(x5_file)
72+
73+
def load_mat(mat_file):
74+
affine_matrix = np.loadtxt(mat_file)
75+
affine = nb.affines.from_matvec(affine_matrix[:,:3], affine_matrix[:,3])
76+
return gather_fields(affine=affine)
77+
78+
def xfm_loader(xfm_file):
79+
loaders = {
80+
'nifti': load_nifti,
81+
'hdf5': load_hdf5,
82+
'x5': load_x5,
83+
'txt': load_mat,
84+
'mat': load_mat
85+
}
86+
xfm_filetype = get_xfm_filetype(xfm_file)
87+
loader = loaders.get(xfm_filetype)
88+
if loader is None:
89+
raise ValueError(f"Unsupported file type: {xfm_filetype}")
90+
return loader(xfm_file)
91+
92+
def to_filename(self, filename, fmt="X5"):
93+
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
94+
with h5py.File(filename, "w") as out_file:
95+
out_file.attrs["Format"] = "X5"
96+
out_file.attrs["Version"] = np.uint16(1)
97+
root = out_file.create_group("/0")
98+
self._to_hdf5(root)
99+
100+
return filename
101+
102+
def _to_hdf5(self, x5_root):
103+
"""Serialize this object into the x5 file format."""
104+
transform_group = x5_root.create_group("TransformGroup")
105+
106+
"""Group '0' containing Affine transform"""
107+
transform_0 = transform_group.create_group("0")
108+
109+
transform_0.attrs["Type"] = "Affine"
110+
transform_0.create_dataset("Transform", data=self._affine)
111+
transform_0.create_dataset("Inverse", data=np.linalg.inv(self._affine))
112+
113+
metadata = {"key": "value"}
114+
transform_0.attrs["Metadata"] = str(metadata)
115+
116+
"""sub-group 'Domain' contained within group '0' """
117+
domain_group = transform_0.create_group("Domain")
118+
#domain_group.attrs["Grid"] = self._grid
119+
#domain_group.create_dataset("Size", data=_as_homogeneous(self._reference.shape))
120+
#domain_group.create_dataset("Mapping", data=self.mapping)
121+
122+
def _from_x5(self, x5_root):
123+
variables = {}
124+
125+
x5_root.visititems(lambda name, x5_root: loader(name, x5_root, variables))
126+
127+
_transform = variables["TransformGroup/0/Transform"]
128+
_inverse = variables["TransformGroup/0/Inverse"]
129+
_size = variables["TransformGroup/0/Domain/Size"]
130+
_mapping = variables["TransformGroup/0/Domain/Mapping"]
131+
132+
return _transform, _inverse, _size, _map
133+
134+
9135
class TransformIOError(IOError):
10136
"""General I/O exception while reading/writing transforms."""
11137

0 commit comments

Comments
 (0)