Skip to content

Commit ccbd585

Browse files
committed
introduce reduction arg and move utils around
1 parent 55d4688 commit ccbd585

File tree

3 files changed

+27
-16
lines changed

3 files changed

+27
-16
lines changed

stochman/manifold.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env python3
21
from abc import ABC, abstractmethod
32
from typing import Optional, Tuple, Union
43

@@ -8,7 +7,7 @@
87

98
from stochman.curves import BasicCurve, CubicSpline
109
from stochman.geodesic import geodesic_minimizing_energy, shooting_geodesic
11-
from stochman.utilities import squared_manifold_distance
10+
from stochman.utils import squared_manifold_distance, tensor_reduction
1211

1312

1413
class Manifold(ABC):
@@ -17,13 +16,14 @@ class Manifold(ABC):
1716
from this abstract base class abstraction.
1817
"""
1918

20-
def curve_energy(self, curve: BasicCurve) -> torch.Tensor:
19+
def curve_energy(self, curve: BasicCurve, reduction: Optional[str] = "sum") -> torch.Tensor:
2120
"""
2221
Compute the discrete energy of a given curve.
2322
24-
Input:
25-
curve: a Nx(d) torch Tensor representing a curve or
26-
a BxNx(d) torch Tensor representing B curves.
23+
Args:
24+
curve: a Nx(d) torch Tensor representing a curve or a BxNx(d) torch Tensor representing B curves.
25+
reduction: how to reduce the curve energy over the batch dimension. Choose between
26+
`'sum'`, `'mean'`, `'none'` or `None` (where the last two will return the individual scores)
2727
2828
Output:
2929
energy: a scalar corresponding to the energy of
@@ -44,7 +44,7 @@ def curve_energy(self, curve: BasicCurve) -> torch.Tensor:
4444
delta = curve[:, 1:] - curve[:, :-1] # Bx(N-1)x(d)
4545
flat_delta = delta.view(-1, d) # (B*(N-1))x(d)
4646
energy = self.inner(curve[:, :-1].reshape(-1, d), flat_delta, flat_delta) # B*(N-1)
47-
return energy.sum() # scalar
47+
return tensor_reduction(energy, reduction)
4848

4949
def curve_length(self, curve: BasicCurve) -> torch.Tensor:
5050
"""
@@ -417,13 +417,14 @@ class EmbeddedManifold(Manifold, ABC):
417417
should inherit from this abstract base class abstraction.
418418
"""
419419

420-
def curve_energy(self, curve: BasicCurve, dt=None):
420+
def curve_energy(self, curve: BasicCurve, reduction: Optional[str] = "sum", dt=None):
421421
"""
422422
Compute the discrete energy of a given curve.
423423
424-
Input:
425-
curve: a Nx(d) torch Tensor representing a curve or
426-
a BxNx(d) torch Tensor representing B curves.
424+
Args:
425+
curve: a Nx(d) torch Tensor representing a curve or a BxNx(d) torch Tensor representing B curves.
426+
reduction: how to reduce the curve energy over the batch dimension. Choose between
427+
`'sum'`, `'mean'`, `'none'` or `None` (where the last two will return the individual scores)
427428
428429
Output:
429430
energy: a scalar corresponding to the energy of
@@ -444,8 +445,7 @@ def curve_energy(self, curve: BasicCurve, dt=None):
444445
B, N, D = emb_curve.shape
445446
delta = emb_curve[:, 1:, :] - emb_curve[:, :-1, :] # Bx(N-1)xD
446447
energy = (delta ** 2).sum((1, 2)) * dt # B
447-
448-
return energy
448+
return tensor_reduction(energy, reduction)
449449

450450
def curve_length(self, curve: BasicCurve, dt=None):
451451
"""

stochman/utilities/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

stochman/utilities/distance.py renamed to stochman/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#!/usr/bin/env python3
1+
from multiprocessing.sharedctypes import Value
22
import torch
33

44

@@ -51,3 +51,16 @@ def squared_manifold_distance(manifold, p0: torch.Tensor, p1: torch.Tensor):
5151
"""
5252
distance_op = __Dist2__()
5353
return distance_op.apply(manifold, p0, p1)
54+
55+
56+
def tensor_reduction(x: torch.Tensor, reduction: str):
57+
if reduction == 'sum':
58+
return x.sum()
59+
elif reduction == 'mean':
60+
return x.mean()
61+
elif reduction is None or reduction == "none":
62+
return x
63+
else:
64+
raise ValueError(
65+
f"Expected `reduction` to either be `'mean'`, `'sum'`, `'none'` or `None` but got {reduction}"
66+
)

0 commit comments

Comments
 (0)