Skip to content

Commit e4b93aa

Browse files
committed
ENH: More comprehensive implementation of ITK affines I/O
1 parent 45a0a2f commit e4b93aa

File tree

6 files changed

+257
-61
lines changed

6 files changed

+257
-61
lines changed

nitransforms/io/base.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,40 @@
11
"""Read/write linear transforms."""
2-
import numpy as np
3-
from nibabel.wrapstruct import LabeledWrapStruct as LWS
2+
from scipy.io.matlab.miobase import get_matfile_version
3+
from scipy.io.matlab.mio4 import MatFile4Reader # , MatFile4Writer
4+
from scipy.io.matlab.mio5 import MatFile5Reader # , MatFile5Writer
45

6+
from ..patched import LabeledWrapStruct
57

6-
class LabeledWrapStruct(LWS):
7-
def __setitem__(self, item, value):
8-
self._structarr[item] = np.asanyarray(value)
8+
9+
class TransformFileError(Exception):
10+
"""A custom exception for transform files."""
911

1012

1113
class StringBasedStruct(LabeledWrapStruct):
14+
"""File data structure from text files."""
15+
1216
def __init__(self,
1317
binaryblock=None,
1418
endianness=None,
1519
check=True):
16-
if binaryblock is not None and getattr(binaryblock, 'dtype',
17-
None) == self.dtype:
20+
"""Create a data structure based off of a string."""
21+
_dtype = getattr(binaryblock, 'dtype', None)
22+
if binaryblock is not None and _dtype == self.dtype:
1823
self._structarr = binaryblock.copy()
1924
return
2025
super(StringBasedStruct, self).__init__(binaryblock, endianness, check)
2126

2227
def __array__(self):
28+
"""Return the internal structure array."""
2329
return self._structarr
30+
31+
32+
def _read_mat(byte_stream):
33+
mjv, _ = get_matfile_version(byte_stream)
34+
if mjv == 0:
35+
reader = MatFile4Reader(byte_stream)
36+
elif mjv == 1:
37+
reader = MatFile5Reader(byte_stream)
38+
elif mjv == 2:
39+
raise TransformFileError('Please use HDF reader for matlab v7.3 files')
40+
return reader.get_variables()

nitransforms/io/itk.py

Lines changed: 138 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Read/write ITK transforms."""
22
import numpy as np
3-
from .base import StringBasedStruct
3+
from scipy.io import savemat as _save_mat
4+
from nibabel.affines import from_matvec
5+
from .base import StringBasedStruct, _read_mat
6+
7+
LPS = np.diag([-1, -1, 1, 1])
48

59

610
class ITKLinearTransform(StringBasedStruct):
@@ -13,20 +17,24 @@ class ITKLinearTransform(StringBasedStruct):
1317
('offset', 'f4', 3), # Center of rotation
1418
])
1519
dtype = template_dtype
20+
# files_types = (('string', '.tfm'), ('binary', '.mat'))
21+
# valid_exts = ('.tfm', '.mat')
1622

17-
def __init__(self):
23+
def __init__(self, parameters=None, offset=None):
1824
"""Initialize with default offset and index."""
1925
super().__init__()
20-
self.structarr['offset'] = [0, 0, 0]
2126
self.structarr['index'] = 1
27+
self.structarr['offset'] = offset or [0, 0, 0]
2228
self.structarr['parameters'] = np.eye(4)
29+
if parameters is not None:
30+
self.structarr['parameters'] = parameters
2331

2432
def __str__(self):
2533
"""Generate a string representation."""
2634
sa = self.structarr
2735
lines = [
2836
'#Transform {:d}'.format(sa['index']),
29-
'Transform: MatrixOffsetTransformBase_double_3_3',
37+
'Transform: AffineTransform_float_3_3',
3038
'Parameters: {}'.format(' '.join(
3139
['%g' % p
3240
for p in sa['parameters'][:3, :3].reshape(-1).tolist() +
@@ -36,6 +44,33 @@ def __str__(self):
3644
]
3745
return '\n'.join(lines)
3846

47+
def to_filename(self, filename):
48+
"""Store this transform to a file with the appropriate format."""
49+
if str(filename).endswith('.mat'):
50+
sa = self.structarr
51+
affine = np.array(np.hstack((
52+
sa['parameters'][:3, :3].reshape(-1),
53+
sa['parameters'][:3, 3]))[..., np.newaxis], dtype='f4')
54+
fixed = np.array(sa['offset'][..., np.newaxis], dtype='f4')
55+
mdict = {
56+
'AffineTransform_float_3_3': affine,
57+
'fixed': fixed,
58+
}
59+
_save_mat(str(filename), mdict, format='4')
60+
return
61+
62+
with open(str(filename), 'w') as f:
63+
f.write(self.to_string())
64+
65+
def to_ras(self):
66+
"""Return a nitransforms' internal RAS matrix."""
67+
sa = self.structarr
68+
matrix = sa['parameters']
69+
offset = sa['offset']
70+
c_neg = from_matvec(np.eye(3), offset * -1.0)
71+
c_pos = from_matvec(np.eye(3), offset)
72+
return LPS.dot(c_pos.dot(matrix.dot(c_neg.dot(LPS))))
73+
3974
def to_string(self, banner=None):
4075
"""Convert to a string directly writeable to file."""
4176
string = '%s'
@@ -48,9 +83,47 @@ def to_string(self, banner=None):
4883
return string % self
4984

5085
@classmethod
51-
def from_string(klass, string):
86+
def from_binary(cls, byte_stream, index=None):
87+
"""Read the struct from a matlab binary file."""
88+
mdict = _read_mat(byte_stream)
89+
return cls.from_matlab_dict(mdict, index=index)
90+
91+
@classmethod
92+
def from_fileobj(cls, fileobj, check=True):
93+
"""Read the struct from a file object."""
94+
if fileobj.name.endswith('.mat'):
95+
return cls.from_binary(fileobj)
96+
return cls.from_string(fileobj.read())
97+
98+
@classmethod
99+
def from_matlab_dict(cls, mdict, index=None):
100+
"""Read the struct from a matlab dictionary."""
101+
tf = cls()
102+
sa = tf.structarr
103+
if index is not None:
104+
raise NotImplementedError
105+
106+
sa['index'] = 1
107+
parameters = np.eye(4, dtype='f4')
108+
parameters[:3, :3] = mdict['AffineTransform_float_3_3'][:-3].reshape((3, 3))
109+
parameters[:3, 3] = mdict['AffineTransform_float_3_3'][-3:].flatten()
110+
sa['parameters'] = parameters
111+
sa['offset'] = mdict['fixed'].flatten()
112+
return tf
113+
114+
@classmethod
115+
def from_ras(cls, ras, index=0):
116+
"""Create an ITK affine from a nitransform's RAS+ matrix."""
117+
tf = cls()
118+
sa = tf.structarr
119+
sa['index'] = index + 1
120+
sa['parameters'] = LPS.dot(ras.dot(LPS))
121+
return tf
122+
123+
@classmethod
124+
def from_string(cls, string):
52125
"""Read the struct from string."""
53-
tf = klass()
126+
tf = cls()
54127
sa = tf.structarr
55128
lines = [l for l in string.splitlines()
56129
if l.strip()]
@@ -61,19 +134,14 @@ def from_string(klass, string):
61134
parameters = np.eye(4, dtype='f4')
62135
sa['index'] = int(lines[0][lines[0].index('T'):].split()[1])
63136
sa['offset'] = np.genfromtxt([lines[3].split(':')[-1].encode()],
64-
dtype=klass.dtype['offset'])
137+
dtype=cls.dtype['offset'])
65138
vals = np.genfromtxt([lines[2].split(':')[-1].encode()],
66139
dtype='f4')
67140
parameters[:3, :3] = vals[:-3].reshape((3, 3))
68141
parameters[:3, 3] = vals[-3:]
69142
sa['parameters'] = parameters
70143
return tf
71144

72-
@classmethod
73-
def from_fileobj(klass, fileobj, check=True):
74-
"""Read the struct from a file object."""
75-
return klass.from_string(fileobj.read())
76-
77145

78146
class ITKLinearTransformArray(StringBasedStruct):
79147
"""A string-based structure for series of ITK linear transforms."""
@@ -89,33 +157,80 @@ def __init__(self,
89157
check=True):
90158
"""Initialize with (optionally) a list of transforms."""
91159
super().__init__(binaryblock, endianness, check)
92-
self._xforms = []
93-
for i, mat in enumerate(xforms or []):
94-
xfm = ITKLinearTransform()
95-
xfm['parameters'] = mat
96-
xfm['index'] = i + 1
97-
self._xforms.append(xfm)
160+
self.xforms = [ITKLinearTransform(parameters=mat)
161+
for mat in xforms or []]
162+
163+
@property
164+
def xforms(self):
165+
"""Get the list of internal ITKLinearTransforms."""
166+
return self._xforms
167+
168+
@xforms.setter
169+
def xforms(self, value):
170+
self._xforms = value
171+
172+
# Update indexes
173+
for i, val in enumerate(self._xforms):
174+
val['index'] = i + 1
98175

99176
def __getitem__(self, idx):
100177
"""Allow dictionary access to the transforms."""
101178
if idx == 'xforms':
102179
return self._xforms
103180
if idx == 'nxforms':
104181
return len(self._xforms)
105-
return super().__getitem__(idx)
182+
raise KeyError(idx)
183+
184+
def to_filename(self, filename):
185+
"""Store this transform to a file with the appropriate format."""
186+
if str(filename).endswith('.mat'):
187+
raise NotImplementedError
188+
189+
with open(str(filename), 'w') as f:
190+
f.write(self.to_string())
191+
192+
def to_ras(self):
193+
"""Return a nitransforms' internal RAS matrix."""
194+
return np.stack([xfm.to_ras() for xfm in self._xforms])
106195

107196
def to_string(self):
108197
"""Convert to a string directly writeable to file."""
109198
strings = []
110-
for i, xfm in enumerate(self._xforms):
199+
for i, xfm in enumerate(self.xforms):
111200
xfm.structarr['index'] = i + 1
112201
strings.append(xfm.to_string())
113202
return '\n'.join(strings)
114203

115204
@classmethod
116-
def from_string(klass, string):
205+
def from_binary(cls, byte_stream):
206+
"""Read the struct from a matlab binary file."""
207+
mdict = _read_mat(byte_stream)
208+
nxforms = mdict['fixed'].shape[0]
209+
210+
_self = cls()
211+
_self.xforms = [ITKLinearTransform.from_matlab_dict(mdict, i)
212+
for i in range(nxforms)]
213+
return _self
214+
215+
@classmethod
216+
def from_fileobj(cls, fileobj, check=True):
217+
"""Read the struct from a file object."""
218+
if fileobj.name.endswith('.mat'):
219+
return cls.from_binary(fileobj)
220+
return cls.from_string(fileobj.read())
221+
222+
@classmethod
223+
def from_ras(cls, ras):
224+
"""Create an ITK affine from a nitransform's RAS+ matrix."""
225+
_self = cls()
226+
_self.xforms = [ITKLinearTransform.from_ras(ras[i, ...], i)
227+
for i in range(ras.shape[0])]
228+
return _self
229+
230+
@classmethod
231+
def from_string(cls, string):
117232
"""Read the struct from string."""
118-
_self = klass()
233+
_self = cls()
119234
lines = [l.strip() for l in string.splitlines()
120235
if l.strip()]
121236

@@ -124,11 +239,6 @@ def from_string(klass, string):
124239

125240
string = '\n'.join(lines[1:])
126241
for xfm in string.split('#')[1:]:
127-
_self._xforms.append(ITKLinearTransform.from_string(
242+
_self.xforms.append(ITKLinearTransform.from_string(
128243
'#%s' % xfm))
129244
return _self
130-
131-
@classmethod
132-
def from_fileobj(klass, fileobj, check=True):
133-
"""Read the struct from a file object."""
134-
return klass.from_string(fileobj.read())

nitransforms/linear.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313

1414
from nibabel.loadsave import load as loadimg
15-
from nibabel.affines import from_matvec, voxel_sizes, obliquity
15+
from nibabel.affines import voxel_sizes, obliquity
1616
from .base import TransformBase, _as_homogeneous, EQUALITY_TOL
1717
from .patched import shape_zoom_affine
1818
from . import io
@@ -140,10 +140,8 @@ def _to_hdf5(self, x5_root):
140140
def to_filename(self, filename, fmt='X5', moving=None):
141141
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
142142
if fmt.lower() in ['itk', 'ants', 'elastix']:
143-
itkobj = io.itk.ITKLinearTransformArray(
144-
xforms=[LPS.dot(m.dot(LPS)) for m in self.matrix])
145-
with open(filename, 'w') as f:
146-
f.write(itkobj.to_string())
143+
itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix)
144+
itkobj.to_filename(filename)
147145
return filename
148146

149147
if fmt.lower() == 'afni':
@@ -235,19 +233,11 @@ def to_filename(self, filename, fmt='X5', moving=None):
235233

236234
def load(filename, fmt='X5', reference=None):
237235
"""Load a linear transform."""
238-
if fmt.lower() in ['itk', 'ants', 'elastix', 'nifty']:
236+
if fmt.lower() in ('itk', 'ants', 'elastix'):
239237
with open(filename) as itkfile:
240238
itkxfm = io.itk.ITKLinearTransformArray.from_fileobj(
241239
itkfile)
242-
243-
matlist = []
244-
for xfm in itkxfm['xforms']:
245-
matrix = xfm['parameters']
246-
offset = xfm['offset']
247-
c_neg = from_matvec(np.eye(3), offset * -1.0)
248-
c_pos = from_matvec(np.eye(3), offset)
249-
matlist.append(LPS.dot(c_pos.dot(matrix.dot(c_neg.dot(LPS)))))
250-
matrix = np.stack(matlist)
240+
matrix = itkxfm.to_ras()
251241
# elif fmt.lower() == 'afni':
252242
# parameters = LPS.dot(self.matrix.dot(LPS))
253243
# parameters = parameters[:3, :].reshape(-1).tolist()

nitransforms/patched.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from nibabel.wrapstruct import LabeledWrapStruct as LWS
23

34

45
def shape_zoom_affine(shape, zooms, x_flip=True, y_flip=False):
@@ -63,3 +64,8 @@ def shape_zoom_affine(shape, zooms, x_flip=True, y_flip=False):
6364
aff[:3, :3] = np.diag(zooms)
6465
aff[:3, -1] = -origin * zooms
6566
return aff
67+
68+
69+
class LabeledWrapStruct(LWS):
70+
def __setitem__(self, item, value):
71+
self._structarr[item] = np.asanyarray(value)

0 commit comments

Comments
 (0)