@@ -30,6 +30,11 @@ class DenseFieldTransform(TransformBase):
30
30
31
31
__slots__ = ("_field" , "_deltas" )
32
32
33
+ @property
34
+ def ndim (self ):
35
+ """Access the dimensions of this Desne Field Transform."""
36
+ return self ._field .ndim - 1
37
+
33
38
def __init__ (self , field = None , is_deltas = True , reference = None ):
34
39
"""
35
40
Create a dense field transform.
@@ -82,11 +87,10 @@ def __init__(self, field=None, is_deltas=True, reference=None):
82
87
"Reference is not a spatial image"
83
88
)
84
89
85
- ndim = self ._field .ndim - 1
86
- if self ._field .shape [- 1 ] != ndim :
90
+ if self ._field .shape [- 1 ] != self .ndim :
87
91
raise TransformError (
88
92
"The number of components of the field (%d) does not match "
89
- "the number of dimensions (%d)" % (self ._field .shape [- 1 ], ndim )
93
+ "the number of dimensions (%d)" % (self ._field .shape [- 1 ], self . ndim )
90
94
)
91
95
92
96
if is_deltas :
@@ -245,6 +249,12 @@ class BSplineFieldTransform(TransformBase):
245
249
246
250
__slots__ = ['_coeffs' , '_knots' , '_weights' , '_order' , '_moving' ]
247
251
252
+ @property
253
+ def ndim (self ):
254
+ """Access the dimensions of this BSpline."""
255
+ #return ndim = self._coeffs.shape[-1]
256
+ return self ._coeffs .ndim - 1
257
+
248
258
def __init__ (self , coefficients , reference = None , order = 3 ):
249
259
"""Create a smooth deformation field using B-Spline basis."""
250
260
super ().__init__ ()
@@ -277,66 +287,19 @@ def to_field(self, reference=None, dtype="float32"):
277
287
if _ref is None :
278
288
raise TransformError ("A reference must be defined" )
279
289
280
- ndim = self ._coeffs .shape [- 1 ]
281
-
282
290
if self ._weights is None :
283
291
self ._weights = grid_bspline_weights (_ref , self ._knots )
284
292
285
- field = np .zeros ((_ref .npoints , ndim ))
293
+ field = np .zeros ((_ref .npoints , self . ndim ))
286
294
287
- for d in range (ndim ):
295
+ for d in range (self . ndim ):
288
296
# 1 x Nvox : (1 x K) @ (K x Nvox)
289
297
field [:, d ] = self ._coeffs [..., d ].reshape (- 1 ) @ self ._weights
290
298
291
299
return DenseFieldTransform (
292
300
field .astype (dtype ).reshape (* _ref .shape , - 1 ), reference = _ref
293
301
)
294
302
295
- def apply (
296
- self ,
297
- spatialimage ,
298
- reference = None ,
299
- order = 3 ,
300
- mode = "constant" ,
301
- cval = 0.0 ,
302
- prefilter = True ,
303
- output_dtype = None ,
304
- ):
305
- """Apply a B-Spline transform on input data."""
306
-
307
- _ref = (
308
- self .reference if reference is None else
309
- SpatialReference .factory (_ensure_image (reference ))
310
- )
311
- spatialimage = _ensure_image (spatialimage )
312
-
313
- # If locations to be interpolated are not on a grid, run map()
314
- #import pdb; pdb.set_trace()
315
- if not isinstance (_ref , ImageGrid ):
316
- return apply (
317
- super (),
318
- spatialimage ,
319
- reference = _ref ,
320
- output_dtype = output_dtype ,
321
- order = order ,
322
- mode = mode ,
323
- cval = cval ,
324
- prefilter = prefilter ,
325
-
326
- )
327
-
328
- # If locations to be interpolated are on a grid, generate a displacements field
329
- return apply (
330
- self .to_field (reference = reference ),
331
- spatialimage ,
332
- reference = reference ,
333
- order = order ,
334
- mode = mode ,
335
- cval = cval ,
336
- prefilter = prefilter ,
337
- output_dtype = output_dtype ,
338
- )
339
-
340
303
def map (self , x , inverse = False ):
341
304
r"""
342
305
Apply the transformation to a list of physical coordinate points.
0 commit comments