12
12
import numpy as np
13
13
import h5py
14
14
import warnings
15
- from nibabel .loadsave import load
15
+ from nibabel .loadsave import load as _nbload
16
+ from nibabel import funcs as _nbfuncs
16
17
from nibabel .nifti1 import intent_codes as INTENT_CODES
17
18
from nibabel .cifti2 import Cifti2Image
18
19
from scipy import ndimage as ndi
19
20
20
21
EQUALITY_TOL = 1e-5
21
22
22
23
23
- class TransformError (ValueError ):
24
+ class TransformError (TypeError ):
24
25
"""A custom exception for transforms."""
25
26
26
27
@@ -51,7 +52,7 @@ def __init__(self, dataset):
51
52
return
52
53
53
54
if isinstance (dataset , (str , Path )):
54
- dataset = load (str (dataset ))
55
+ dataset = _nbload (str (dataset ))
55
56
56
57
if hasattr (dataset , 'numDA' ): # Looks like a Gifti file
57
58
_das = dataset .get_arrays_from_intent (INTENT_CODES ['pointset' ])
@@ -96,11 +97,15 @@ class ImageGrid(SampledSpatialData):
96
97
def __init__ (self , image ):
97
98
"""Create a gridded sampling reference."""
98
99
if isinstance (image , (str , Path )):
99
- image = load ( str (image ))
100
+ image = _nbfuncs . squeeze_image ( _nbload ( str (image ) ))
100
101
101
102
self ._affine = image .affine
102
103
self ._shape = image .shape
104
+
103
105
self ._ndim = getattr (image , 'ndim' , len (image .shape ))
106
+ if self ._ndim == 4 :
107
+ self ._shape = image .shape [:3 ]
108
+ self ._ndim = 3
104
109
105
110
self ._npoints = getattr (image , 'npoints' ,
106
111
np .prod (image .shape ))
@@ -172,9 +177,9 @@ def __init__(self):
172
177
"""Instantiate a transform."""
173
178
self ._reference = None
174
179
175
- def __call__ (self , x , inverse = False , index = 0 ):
180
+ def __call__ (self , x , inverse = False ):
176
181
"""Apply y = f(x)."""
177
- return self .map (x , inverse = inverse , index = index )
182
+ return self .map (x , inverse = inverse )
178
183
179
184
def __add__ (self , b ):
180
185
"""
@@ -246,13 +251,13 @@ def apply(self, spatialimage, reference=None,
246
251
247
252
"""
248
253
if reference is not None and isinstance (reference , (str , Path )):
249
- reference = load (str (reference ))
254
+ reference = _nbload (str (reference ))
250
255
251
256
_ref = self .reference if reference is None \
252
257
else SpatialReference .factory (reference )
253
258
254
259
if isinstance (spatialimage , (str , Path )):
255
- spatialimage = load (str (spatialimage ))
260
+ spatialimage = _nbload (str (spatialimage ))
256
261
257
262
data = np .asanyarray (spatialimage .dataobj )
258
263
output_dtype = output_dtype or data .dtype
@@ -279,7 +284,7 @@ def apply(self, spatialimage, reference=None,
279
284
280
285
return resampled
281
286
282
- def map (self , x , inverse = False , index = 0 ):
287
+ def map (self , x , inverse = False ):
283
288
r"""
284
289
Apply :math:`y = f(x)`.
285
290
@@ -291,8 +296,6 @@ def map(self, x, inverse=False, index=0):
291
296
Input RAS+ coordinates (i.e., physical coordinates).
292
297
inverse : bool
293
298
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
294
- index : int, optional
295
- Transformation index
296
299
297
300
Returns
298
301
-------
@@ -407,7 +410,7 @@ def insert(self, i, x):
407
410
"""
408
411
self .transforms = self .transforms [:i ] + _as_chain (x ) + self .transforms [i :]
409
412
410
- def map (self , x , inverse = False , index = 0 ):
413
+ def map (self , x , inverse = False ):
411
414
"""
412
415
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
413
416
@@ -438,6 +441,80 @@ def map(self, x, inverse=False, index=0):
438
441
return x
439
442
440
443
444
+ class TransformMapping (TransformBase ):
445
+ """Implements a four-dimensional series of transforms."""
446
+
447
+ __slots__ = ('_transforms' , )
448
+
449
+ def __init__ (self , transforms = None ):
450
+ """Initialize a chain of transforms."""
451
+ self ._transforms = None
452
+ if transforms is not None :
453
+ self .transforms = transforms
454
+
455
+ def __getitem__ (self , i ):
456
+ """
457
+ Enable indexed access of transform chains.
458
+
459
+ Example
460
+ -------
461
+ >>> T1 = TransformBase()
462
+ >>> xfm4d = TransformMapping([T1, TransformBase(), TransformBase()])
463
+ >>> xfm4d[0] is T1
464
+ True
465
+
466
+ """
467
+ return self .transforms [i ]
468
+
469
+ def __len__ (self ):
470
+ """Enable using len()."""
471
+ return len (self .transforms )
472
+
473
+ @property
474
+ def transforms (self ):
475
+ """Get the internal list of transforms."""
476
+ return self ._transforms
477
+
478
+ @transforms .setter
479
+ def transforms (self , value ):
480
+ self ._transforms = value
481
+
482
+ def append (self , x ):
483
+ """
484
+ Concatenate one element to the chain.
485
+
486
+ Example
487
+ -------
488
+ >>> xfm4d = TransformMapping([TransformBase(), TransformBase()])
489
+ >>> xfm4d.append(TransformBase())
490
+ >>> len(xfm4d)
491
+ 3
492
+
493
+ """
494
+ self .transforms .append (x )
495
+
496
+ def insert (self , i , x ):
497
+ """
498
+ Insert an item at a given position.
499
+
500
+ Example
501
+ -------
502
+ >>> xfm4d = TransformMapping([TransformBase(), TransformBase()])
503
+ >>> xfm4d.insert(1, TransformBase())
504
+ >>> len(xfm4d)
505
+ 3
506
+
507
+ """
508
+ self .transforms .insert (i , x )
509
+
510
+ def map (self , x , inverse = False ):
511
+ """Apply a map of transforms, e.g., :math:`y_t = f_t(x_t)`."""
512
+ if not self .transforms :
513
+ raise TransformError ('Cannot apply an empty transforms mapping.' )
514
+
515
+ return [xfm (x , inverse = inverse ) for xfm in self .transforms ]
516
+
517
+
441
518
def _as_homogeneous (xyz , dtype = 'float32' , dim = 3 ):
442
519
"""
443
520
Convert 2D and 3D coordinates into homogeneous coordinates.
0 commit comments