Skip to content

Commit d35f658

Browse files
committed
ENH: Support for transforms mappings (e.g., head-motion correction)
Implements two types of transforms mappings: 1. a general one for any internal transform, and 2. an optimized mapping for lineart transforms (head-motion correction). Since scipy's interpn only accepts linear and nearest interpolations, these mappings will not support for simultaneous slice-timing correction (STC) for the time being. The addition of 4D tensor B-Spline basis for interpolation would allow for simultaneous STC, HMC and SDC of functional time-series. Not sure we should keep supporting Lanczos interpolation, very much less if we want to do all three corrections at the same time. Closes #46
1 parent f8b725f commit d35f658

File tree

4 files changed

+396
-50
lines changed

4 files changed

+396
-50
lines changed

nitransforms/base.py

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@
1212
import numpy as np
1313
import h5py
1414
import warnings
15-
from nibabel.loadsave import load
15+
from nibabel.loadsave import load as _nbload
16+
from nibabel import funcs as _nbfuncs
1617
from nibabel.nifti1 import intent_codes as INTENT_CODES
1718
from nibabel.cifti2 import Cifti2Image
1819
from scipy import ndimage as ndi
1920

2021
EQUALITY_TOL = 1e-5
2122

2223

23-
class TransformError(ValueError):
24+
class TransformError(TypeError):
2425
"""A custom exception for transforms."""
2526

2627

@@ -51,7 +52,7 @@ def __init__(self, dataset):
5152
return
5253

5354
if isinstance(dataset, (str, Path)):
54-
dataset = load(str(dataset))
55+
dataset = _nbload(str(dataset))
5556

5657
if hasattr(dataset, 'numDA'): # Looks like a Gifti file
5758
_das = dataset.get_arrays_from_intent(INTENT_CODES['pointset'])
@@ -96,11 +97,15 @@ class ImageGrid(SampledSpatialData):
9697
def __init__(self, image):
9798
"""Create a gridded sampling reference."""
9899
if isinstance(image, (str, Path)):
99-
image = load(str(image))
100+
image = _nbfuncs.squeeze_image(_nbload(str(image)))
100101

101102
self._affine = image.affine
102103
self._shape = image.shape
104+
103105
self._ndim = getattr(image, 'ndim', len(image.shape))
106+
if self._ndim == 4:
107+
self._shape = image.shape[:3]
108+
self._ndim = 3
104109

105110
self._npoints = getattr(image, 'npoints',
106111
np.prod(image.shape))
@@ -172,9 +177,9 @@ def __init__(self):
172177
"""Instantiate a transform."""
173178
self._reference = None
174179

175-
def __call__(self, x, inverse=False, index=0):
180+
def __call__(self, x, inverse=False):
176181
"""Apply y = f(x)."""
177-
return self.map(x, inverse=inverse, index=index)
182+
return self.map(x, inverse=inverse)
178183

179184
def __add__(self, b):
180185
"""
@@ -246,13 +251,13 @@ def apply(self, spatialimage, reference=None,
246251
247252
"""
248253
if reference is not None and isinstance(reference, (str, Path)):
249-
reference = load(str(reference))
254+
reference = _nbload(str(reference))
250255

251256
_ref = self.reference if reference is None \
252257
else SpatialReference.factory(reference)
253258

254259
if isinstance(spatialimage, (str, Path)):
255-
spatialimage = load(str(spatialimage))
260+
spatialimage = _nbload(str(spatialimage))
256261

257262
data = np.asanyarray(spatialimage.dataobj)
258263
output_dtype = output_dtype or data.dtype
@@ -279,7 +284,7 @@ def apply(self, spatialimage, reference=None,
279284

280285
return resampled
281286

282-
def map(self, x, inverse=False, index=0):
287+
def map(self, x, inverse=False):
283288
r"""
284289
Apply :math:`y = f(x)`.
285290
@@ -291,8 +296,6 @@ def map(self, x, inverse=False, index=0):
291296
Input RAS+ coordinates (i.e., physical coordinates).
292297
inverse : bool
293298
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
294-
index : int, optional
295-
Transformation index
296299
297300
Returns
298301
-------
@@ -407,7 +410,7 @@ def insert(self, i, x):
407410
"""
408411
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]
409412

410-
def map(self, x, inverse=False, index=0):
413+
def map(self, x, inverse=False):
411414
"""
412415
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
413416
@@ -438,6 +441,80 @@ def map(self, x, inverse=False, index=0):
438441
return x
439442

440443

444+
class TransformMapping(TransformBase):
445+
"""Implements a four-dimensional series of transforms."""
446+
447+
__slots__ = ('_transforms', )
448+
449+
def __init__(self, transforms=None):
450+
"""Initialize a chain of transforms."""
451+
self._transforms = None
452+
if transforms is not None:
453+
self.transforms = transforms
454+
455+
def __getitem__(self, i):
456+
"""
457+
Enable indexed access of transform chains.
458+
459+
Example
460+
-------
461+
>>> T1 = TransformBase()
462+
>>> xfm4d = TransformMapping([T1, TransformBase(), TransformBase()])
463+
>>> xfm4d[0] is T1
464+
True
465+
466+
"""
467+
return self.transforms[i]
468+
469+
def __len__(self):
470+
"""Enable using len()."""
471+
return len(self.transforms)
472+
473+
@property
474+
def transforms(self):
475+
"""Get the internal list of transforms."""
476+
return self._transforms
477+
478+
@transforms.setter
479+
def transforms(self, value):
480+
self._transforms = value
481+
482+
def append(self, x):
483+
"""
484+
Concatenate one element to the chain.
485+
486+
Example
487+
-------
488+
>>> xfm4d = TransformMapping([TransformBase(), TransformBase()])
489+
>>> xfm4d.append(TransformBase())
490+
>>> len(xfm4d)
491+
3
492+
493+
"""
494+
self.transforms.append(x)
495+
496+
def insert(self, i, x):
497+
"""
498+
Insert an item at a given position.
499+
500+
Example
501+
-------
502+
>>> xfm4d = TransformMapping([TransformBase(), TransformBase()])
503+
>>> xfm4d.insert(1, TransformBase())
504+
>>> len(xfm4d)
505+
3
506+
507+
"""
508+
self.transforms.insert(i, x)
509+
510+
def map(self, x, inverse=False):
511+
"""Apply a map of transforms, e.g., :math:`y_t = f_t(x_t)`."""
512+
if not self.transforms:
513+
raise TransformError('Cannot apply an empty transforms mapping.')
514+
515+
return [xfm(x, inverse=inverse) for xfm in self.transforms]
516+
517+
441518
def _as_homogeneous(xyz, dtype='float32', dim=3):
442519
"""
443520
Convert 2D and 3D coordinates into homogeneous coordinates.

0 commit comments

Comments
 (0)