Skip to content

Commit 838d385

Browse files
authored
Merge pull request #79 from oesteban/enh/4-read-itk-h5
ENH: Read (and apply) ITK/ANTs' composite HDF5 transforms
2 parents 439354a + 3185b1d commit 838d385

File tree

6 files changed

+315
-149
lines changed

6 files changed

+315
-149
lines changed

nitransforms/base.py

Lines changed: 3 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Common interface for transforms."""
1010
from pathlib import Path
11-
from collections.abc import Iterable
1211
import numpy as np
1312
import h5py
1413
import warnings
@@ -168,10 +167,10 @@ def __ne__(self, other):
168167
return not self == other
169168

170169

171-
class TransformBase(object):
170+
class TransformBase:
172171
"""Abstract image class to represent transforms."""
173172

174-
__slots__ = ['_reference']
173+
__slots__ = ('_reference', )
175174

176175
def __init__(self, reference=None):
177176
"""Instantiate a transform."""
@@ -191,13 +190,11 @@ def __add__(self, b):
191190
-------
192191
>>> T1 = TransformBase()
193192
>>> added = T1 + TransformBase()
194-
>>> isinstance(added, TransformChain)
195-
True
196-
197193
>>> len(added.transforms)
198194
2
199195
200196
"""
197+
from .manip import TransformChain
201198
return TransformChain(transforms=[self, b])
202199

203200
@property
@@ -322,127 +319,6 @@ def _to_hdf5(self, x5_root):
322319
raise NotImplementedError
323320

324321

325-
class TransformChain(TransformBase):
326-
"""Implements the concatenation of transforms."""
327-
328-
__slots__ = ['_transforms']
329-
330-
def __init__(self, transforms=None):
331-
"""Initialize a chain of transforms."""
332-
self._transforms = None
333-
if transforms is not None:
334-
self.transforms = transforms
335-
336-
def __add__(self, b):
337-
"""
338-
Compose this and other transforms.
339-
340-
Example
341-
-------
342-
>>> T1 = TransformBase()
343-
>>> added = T1 + TransformBase() + TransformBase()
344-
>>> isinstance(added, TransformChain)
345-
True
346-
347-
>>> len(added.transforms)
348-
3
349-
350-
"""
351-
self.append(b)
352-
return self
353-
354-
def __getitem__(self, i):
355-
"""
356-
Enable indexed access of transform chains.
357-
358-
Example
359-
-------
360-
>>> T1 = TransformBase()
361-
>>> chain = T1 + TransformBase()
362-
>>> chain[0] is T1
363-
True
364-
365-
"""
366-
return self.transforms[i]
367-
368-
def __len__(self):
369-
"""Enable using len()."""
370-
return len(self.transforms)
371-
372-
@property
373-
def transforms(self):
374-
"""Get the internal list of transforms."""
375-
return self._transforms
376-
377-
@transforms.setter
378-
def transforms(self, value):
379-
self._transforms = _as_chain(value)
380-
if self.transforms[0].reference:
381-
self.reference = self.transforms[0].reference
382-
383-
def append(self, x):
384-
"""
385-
Concatenate one element to the chain.
386-
387-
Example
388-
-------
389-
>>> chain = TransformChain(transforms=TransformBase())
390-
>>> chain.append((TransformBase(), TransformBase()))
391-
>>> len(chain)
392-
3
393-
394-
"""
395-
self.transforms += _as_chain(x)
396-
397-
def insert(self, i, x):
398-
"""
399-
Insert an item at a given position.
400-
401-
Example
402-
-------
403-
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
404-
>>> chain.insert(1, TransformBase())
405-
>>> len(chain)
406-
3
407-
408-
>>> chain.insert(1, TransformChain(chain))
409-
>>> len(chain)
410-
6
411-
412-
"""
413-
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]
414-
415-
def map(self, x, inverse=False):
416-
"""
417-
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
418-
419-
Example
420-
-------
421-
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
422-
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
423-
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
424-
425-
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
426-
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
427-
428-
>>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL
429-
Traceback (most recent call last):
430-
TransformError:
431-
432-
"""
433-
if not self.transforms:
434-
raise TransformError('Cannot apply an empty transforms chain.')
435-
436-
transforms = self.transforms
437-
if inverse:
438-
transforms = reversed(self.transforms)
439-
440-
for xfm in transforms:
441-
x = xfm(x, inverse=inverse)
442-
443-
return x
444-
445-
446322
def _as_homogeneous(xyz, dtype='float32', dim=3):
447323
"""
448324
Convert 2D and 3D coordinates into homogeneous coordinates.
@@ -473,12 +349,3 @@ def _as_homogeneous(xyz, dtype='float32', dim=3):
473349
def _apply_affine(x, affine, dim):
474350
"""Get the image array's indexes corresponding to coordinates."""
475351
return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T
476-
477-
478-
def _as_chain(x):
479-
"""Convert a value into a transform chain."""
480-
if isinstance(x, TransformChain):
481-
return x.transforms
482-
if isinstance(x, Iterable):
483-
return list(x)
484-
return [x]

nitransforms/io/itk.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
import numpy as np
44
from scipy.io import savemat as _save_mat
5+
from nibabel import Nifti1Header, Nifti1Image
56
from nibabel.affines import from_matvec
67
from .base import (
78
BaseLinearTransformList,
@@ -29,7 +30,9 @@ def __init__(self, parameters=None, offset=None):
2930
"""Initialize with default offset and index."""
3031
super().__init__()
3132
self.structarr['index'] = 0
32-
self.structarr['offset'] = offset or [0, 0, 0]
33+
if offset is None:
34+
offset = np.zeros((3,), dtype='float')
35+
self.structarr['offset'] = offset
3336
self.structarr['parameters'] = np.eye(4)
3437
if parameters is not None:
3538
self.structarr['parameters'] = parameters
@@ -280,3 +283,65 @@ def from_image(cls, imgobj):
280283
field[..., (0, 1)] *= -1.0
281284

282285
return imgobj.__class__(field, imgobj.affine, hdr)
286+
287+
288+
class ITKCompositeH5:
289+
"""A data structure for ITK's HDF5 files."""
290+
291+
@classmethod
292+
def from_filename(cls, filename):
293+
"""Read the struct from a file given its path."""
294+
from h5py import File as H5File
295+
if not str(filename).endswith('.h5'):
296+
raise RuntimeError("Extension is not .h5")
297+
298+
with H5File(str(filename)) as f:
299+
return cls.from_h5obj(f)
300+
301+
@classmethod
302+
def from_h5obj(cls, fileobj, check=True):
303+
"""Read the struct from a file object."""
304+
xfm_list = []
305+
h5group = fileobj["TransformGroup"]
306+
typo_fallback = "Transform"
307+
try:
308+
h5group['1'][f"{typo_fallback}Parameters"]
309+
except KeyError:
310+
typo_fallback = "Tranform"
311+
312+
for xfm in list(h5group.values())[1:]:
313+
if xfm["TransformType"][0].startswith(b"AffineTransform"):
314+
_params = np.asanyarray(xfm[f"{typo_fallback}Parameters"])
315+
xfm_list.append(
316+
ITKLinearTransform(
317+
parameters=from_matvec(_params[:-3].reshape(3, 3), _params[-3:]),
318+
offset=np.asanyarray(xfm[f"{typo_fallback}FixedParameters"])
319+
)
320+
)
321+
continue
322+
if xfm["TransformType"][0].startswith(b"DisplacementFieldTransform"):
323+
_fixed = np.asanyarray(xfm[f"{typo_fallback}FixedParameters"])
324+
shape = _fixed[:3].astype('uint16').tolist()
325+
offset = _fixed[3:6].astype('uint16')
326+
zooms = _fixed[6:9].astype('float')
327+
directions = _fixed[9:].astype('float').reshape((3, 3))
328+
affine = from_matvec(directions * zooms, offset)
329+
field = np.asanyarray(xfm[f"{typo_fallback}Parameters"]).reshape(
330+
tuple(shape + [1, -1])
331+
)
332+
hdr = Nifti1Header()
333+
hdr.set_intent("vector")
334+
hdr.set_data_dtype("float")
335+
336+
xfm_list.append(
337+
ITKDisplacementsField.from_image(
338+
Nifti1Image(field.astype("float"), affine, hdr)
339+
)
340+
)
341+
continue
342+
343+
raise NotImplementedError(
344+
f"Unsupported transform type {xfm['TransformType'][0]}"
345+
)
346+
347+
return xfm_list

0 commit comments

Comments
 (0)