1
- #!/usr/bin/env python3
2
1
from abc import ABC , abstractmethod
3
2
from typing import Optional , Tuple , Union
4
3
8
7
9
8
from stochman .curves import BasicCurve , CubicSpline
10
9
from stochman .geodesic import geodesic_minimizing_energy , shooting_geodesic
11
- from stochman .utilities import squared_manifold_distance
10
+ from stochman .utils import squared_manifold_distance , tensor_reduction
12
11
13
12
14
13
class Manifold (ABC ):
@@ -17,13 +16,14 @@ class Manifold(ABC):
17
16
from this abstract base class abstraction.
18
17
"""
19
18
20
- def curve_energy (self , curve : BasicCurve ) -> torch .Tensor :
19
+ def curve_energy (self , curve : BasicCurve , reduction : Optional [ str ] = "sum" ) -> torch .Tensor :
21
20
"""
22
21
Compute the discrete energy of a given curve.
23
22
24
- Input:
25
- curve: a Nx(d) torch Tensor representing a curve or
26
- a BxNx(d) torch Tensor representing B curves.
23
+ Args:
24
+ curve: a Nx(d) torch Tensor representing a curve or a BxNx(d) torch Tensor representing B curves.
25
+ reduction: how to reduce the curve energy over the batch dimension. Choose between
26
+ `'sum'`, `'mean'`, `'none'` or `None` (where the last two will return the individual scores)
27
27
28
28
Output:
29
29
energy: a scalar corresponding to the energy of
@@ -44,7 +44,7 @@ def curve_energy(self, curve: BasicCurve) -> torch.Tensor:
44
44
delta = curve [:, 1 :] - curve [:, :- 1 ] # Bx(N-1)x(d)
45
45
flat_delta = delta .view (- 1 , d ) # (B*(N-1))x(d)
46
46
energy = self .inner (curve [:, :- 1 ].reshape (- 1 , d ), flat_delta , flat_delta ) # B*(N-1)
47
- return energy . sum () # scalar
47
+ return tensor_reduction ( energy , reduction )
48
48
49
49
def curve_length (self , curve : BasicCurve ) -> torch .Tensor :
50
50
"""
@@ -417,13 +417,14 @@ class EmbeddedManifold(Manifold, ABC):
417
417
should inherit from this abstract base class abstraction.
418
418
"""
419
419
420
- def curve_energy (self , curve : BasicCurve , dt = None ):
420
+ def curve_energy (self , curve : BasicCurve , reduction : Optional [ str ] = "sum" , dt = None ):
421
421
"""
422
422
Compute the discrete energy of a given curve.
423
423
424
- Input:
425
- curve: a Nx(d) torch Tensor representing a curve or
426
- a BxNx(d) torch Tensor representing B curves.
424
+ Args:
425
+ curve: a Nx(d) torch Tensor representing a curve or a BxNx(d) torch Tensor representing B curves.
426
+ reduction: how to reduce the curve energy over the batch dimension. Choose between
427
+ `'sum'`, `'mean'`, `'none'` or `None` (where the last two will return the individual scores)
427
428
428
429
Output:
429
430
energy: a scalar corresponding to the energy of
@@ -444,8 +445,7 @@ def curve_energy(self, curve: BasicCurve, dt=None):
444
445
B , N , D = emb_curve .shape
445
446
delta = emb_curve [:, 1 :, :] - emb_curve [:, :- 1 , :] # Bx(N-1)xD
446
447
energy = (delta ** 2 ).sum ((1 , 2 )) * dt # B
447
-
448
- return energy
448
+ return tensor_reduction (energy , reduction )
449
449
450
450
def curve_length (self , curve : BasicCurve , dt = None ):
451
451
"""
0 commit comments