Skip to content

Commit d092ea1

Browse files
sgiavasisoesteban
authored andcommitted
Set up I/O functions to be expandable, deferring some design decisions to the future. Started the beginning of some xfm_utils for quicker testing and validation.
(cherry picked from commit 7ba96c3)
1 parent 92e932c commit d092ea1

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed

nitransforms/io/base.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,137 @@
11
"""Read/write linear transforms."""
22
from pathlib import Path
33
import numpy as np
4+
import nibabel as nb
45
from nibabel import load as loadimg
56

7+
import h5py
8+
69
from ..patched import LabeledWrapStruct
710

811

12+
def get_xfm_filetype(xfm_file):
13+
path = Path(xfm_file)
14+
ext = path.suffix
15+
if ext == '.gz' and path.name.endswith('.nii.gz'):
16+
return 'nifti'
17+
18+
file_types = {
19+
'.nii': 'nifti',
20+
'.h5': 'hdf5',
21+
'.x5': 'x5',
22+
'.txt': 'txt',
23+
'.mat': 'txt'
24+
}
25+
return file_types.get(ext, 'unknown')
26+
27+
def gather_fields(x5=None, hdf5=None, nifti=None, shape=None, affine=None, header=None):
28+
xfm_fields = {
29+
"x5": x5,
30+
"hdf5": hdf5,
31+
"nifti": nifti,
32+
"header": header,
33+
"shape": shape,
34+
"affine": affine
35+
}
36+
return xfm_fields
37+
38+
def load_nifti(nifti_file):
39+
nifti_xfm = nb.load(nifti_file)
40+
xfm_data = nifti_xfm.get_fdata()
41+
shape = nifti_xfm.shape
42+
affine = nifti_xfm.affine
43+
header = getattr(nifti_xfm, "header", None)
44+
return gather_fields(nifti=xfm_data, shape=shape, affine=affine, header=header)
45+
46+
def load_hdf5(hdf5_file):
47+
storage = {}
48+
49+
def get_hdf5_items(name, x5_root):
50+
if isinstance(x5_root, h5py.Dataset):
51+
storage[name] = {
52+
'type': 'dataset',
53+
'attrs': dict(x5_root.attrs),
54+
'shape': x5_root.shape,
55+
'data': x5_root[()]
56+
}
57+
elif isinstance(x5_root, h5py.Group):
58+
storage[name] = {
59+
'type': 'group',
60+
'attrs': dict(x5_root.attrs),
61+
'members': {}
62+
}
63+
64+
with h5py.File(hdf5_file, 'r') as f:
65+
f.visititems(get_hdf5_items)
66+
if storage:
67+
hdf5_storage = {'hdf5': storage}
68+
return hdf5_storage
69+
70+
def load_x5(x5_file):
71+
load_hdf5(x5_file)
72+
73+
def load_mat(mat_file):
74+
affine_matrix = np.loadtxt(mat_file)
75+
affine = nb.affines.from_matvec(affine_matrix[:,:3], affine_matrix[:,3])
76+
return gather_fields(affine=affine)
77+
78+
def xfm_loader(xfm_file):
79+
loaders = {
80+
'nifti': load_nifti,
81+
'hdf5': load_hdf5,
82+
'x5': load_x5,
83+
'txt': load_mat,
84+
'mat': load_mat
85+
}
86+
xfm_filetype = get_xfm_filetype(xfm_file)
87+
loader = loaders.get(xfm_filetype)
88+
if loader is None:
89+
raise ValueError(f"Unsupported file type: {xfm_filetype}")
90+
return loader(xfm_file)
91+
92+
def to_filename(self, filename, fmt="X5"):
93+
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
94+
with h5py.File(filename, "w") as out_file:
95+
out_file.attrs["Format"] = "X5"
96+
out_file.attrs["Version"] = np.uint16(1)
97+
root = out_file.create_group("/0")
98+
self._to_hdf5(root)
99+
100+
return filename
101+
102+
def _to_hdf5(self, x5_root):
103+
"""Serialize this object into the x5 file format."""
104+
transform_group = x5_root.create_group("TransformGroup")
105+
106+
"""Group '0' containing Affine transform"""
107+
transform_0 = transform_group.create_group("0")
108+
109+
transform_0.attrs["Type"] = "Affine"
110+
transform_0.create_dataset("Transform", data=self._affine)
111+
transform_0.create_dataset("Inverse", data=np.linalg.inv(self._affine))
112+
113+
metadata = {"key": "value"}
114+
transform_0.attrs["Metadata"] = str(metadata)
115+
116+
"""sub-group 'Domain' contained within group '0' """
117+
domain_group = transform_0.create_group("Domain")
118+
#domain_group.attrs["Grid"] = self._grid
119+
#domain_group.create_dataset("Size", data=_as_homogeneous(self._reference.shape))
120+
#domain_group.create_dataset("Mapping", data=self.mapping)
121+
122+
def _from_x5(self, x5_root):
123+
variables = {}
124+
125+
x5_root.visititems(lambda name, x5_root: loader(name, x5_root, variables))
126+
127+
_transform = variables["TransformGroup/0/Transform"]
128+
_inverse = variables["TransformGroup/0/Inverse"]
129+
_size = variables["TransformGroup/0/Domain/Size"]
130+
_mapping = variables["TransformGroup/0/Domain/Mapping"]
131+
132+
return _transform, _inverse, _size, _map
133+
134+
9135
class TransformIOError(IOError):
10136
"""General I/O exception while reading/writing transforms."""
11137

0 commit comments

Comments
 (0)