Skip to content

Commit 651355f

Browse files
committed
Use constant_speed to compute curve length as this is cheaper
1 parent bcb5679 commit 651355f

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

stochman/utilities/distance.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ def forward(ctx, M, p0: torch.Tensor, p1: torch.Tensor):
88
with torch.no_grad():
99
with torch.enable_grad():
1010
C, success = M.connecting_geodesic(p0, p1)
11-
C.constant_speed(M)
12-
dist = M.curve_length(C(torch.linspace(0, 1, 100))) # B
11+
_, _, dist = C.constant_speed(M) # B
1312
dist2 = dist**2
1413

1514
lm0 = C.deriv(torch.zeros(1, device=p0.device)).squeeze(1) # log(p0, p1); Bx(d)

0 commit comments

Comments
 (0)