Skip to content

Commit ec1f5f3

Browse files
committed
Remove outdated code for length and energy on LocalVarMetric (instead fall back to the default implementation which is better)
1 parent 44b045e commit ec1f5f3

File tree

1 file changed

+2
-40
lines changed

1 file changed

+2
-40
lines changed

stochman/manifold.py

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def curve_energy(self, curve: BasicCurve) -> torch.Tensor:
4343
d = curve.shape[2]
4444
delta = curve[:, 1:] - curve[:, :-1] # Bx(N-1)x(d)
4545
flat_delta = delta.view(-1, d) # (B*(N-1))x(d)
46-
energy = self.inner(curve[:, :-1].view(-1, d), flat_delta, flat_delta) # B*(N-1)
46+
energy = self.inner(curve[:, :-1].reshape(-1, d), flat_delta, flat_delta) # B*(N-1)
4747
return energy.sum() # scalar
4848

4949
def curve_length(self, curve: BasicCurve) -> torch.Tensor:
@@ -70,7 +70,7 @@ def curve_length(self, curve: BasicCurve) -> torch.Tensor:
7070
B, N, d = curve.shape
7171
delta = curve[:, 1:] - curve[:, :-1] # Bx(N-1)x(d)
7272
flat_delta = delta.view(-1, d) # (B*(N-1))x(d)
73-
energy = self.inner(curve[:, :-1].view(-1, d), flat_delta, flat_delta) # B*(N-1)
73+
energy = self.inner(curve[:, :-1].reshape(-1, d), flat_delta, flat_delta) # B*(N-1)
7474
length = energy.view(B, N - 1).sqrt().sum(dim=1) # B
7575
return length
7676

@@ -592,44 +592,6 @@ def metric(self, c, return_deriv=False):
592592
else:
593593
return torch.cat(M)
594594

595-
def curve_energy(self, c):
596-
"""
597-
Evaluate the energy of a curve represented as a discrete set of points.
598-
599-
Input:
600-
c: A discrete set of points along a curve. This is represented
601-
as a PxD or BxPxD torch Tensor. The points are assumed to be ordered
602-
along the curve and evaluated at equidistant time points.
603-
604-
Output:
605-
energy: The energy of the input curve.
606-
"""
607-
if len(c.shape) == 2:
608-
c.unsqueeze_(0) # add batch dimension if one isn't present
609-
energy = torch.zeros(1)
610-
for b in range(c.shape[0]):
611-
M = self.metric(c[b, :-1]) # (P-1)xD
612-
delta1 = (c[b, 1:] - c[b, :-1]) ** 2 # (P-1)xD
613-
energy += (M * delta1).sum()
614-
return energy
615-
616-
def curve_length(self, c):
617-
"""
618-
Evaluate the length of a curve represented as a discrete set of points.
619-
620-
Input:
621-
c: A discrete set of points along a curve. This is represented
622-
as a PxD torch Tensor. The points are assumed to be ordered
623-
along the curve and evaluated at equidistant indices.
624-
625-
Output:
626-
length: The length of the input curve.
627-
"""
628-
M = self.metric(c[:-1]) # (P-1)xD
629-
delta1 = (c[1:] - c[:-1]) ** 2 # (P-1)xD
630-
length = (M * delta1).sum(dim=1).sqrt().sum()
631-
return length
632-
633595
def geodesic_system(self, c, dc):
634596
"""
635597
Evaluate the 2nd order system of ordinary differential equations that

0 commit comments

Comments
 (0)