Skip to content

Commit 71ec26f

Browse files
committed
Make distance calculations (hopefully) more stable by moving to a constant speed parametrization, and correction with the known curve length
1 parent 9ae5c8e commit 71ec26f

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

stochman/utilities/distance.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,36 @@ class __Dist2__(torch.autograd.Function):
77
def forward(ctx, M, p0: torch.Tensor, p1: torch.Tensor):
88
with torch.no_grad():
99
with torch.enable_grad():
10-
# TODO: Only perform the computations needed for backpropagation
11-
# (check if p0.requires_grad and p1.requires_grad)
1210
C, success = M.connecting_geodesic(p0, p1)
13-
14-
lm0 = C.deriv(torch.zeros(1, device=p0.device)).squeeze(1) # log(p0, p1); Bx(d)
15-
lm1 = -C.deriv(torch.ones(1, device=p0.device)).squeeze(1) # log(p1, p0); Bx(d)
16-
M0 = M.metric(p0) # Bx(d)x(d) or Bx(d)
17-
M1 = M.metric(p1) # Bx(d)x(d) or Bx(d)
18-
if M0.ndim == 3: # metric is square
19-
Mlm0 = M0.bmm(lm0.unsqueeze(-1)).squeeze(-1) # Bx(d)
20-
Mlm1 = M1.bmm(lm1.unsqueeze(-1)).squeeze(-1) # Bx(d)
21-
else:
22-
Mlm0 = M0 * lm0 # Bx(d)
23-
Mlm1 = M1 * lm1 # Bx(d)
24-
25-
ctx.save_for_backward(Mlm0, Mlm1)
26-
retval = (lm0 * Mlm0).sum(dim=-1) # B
27-
return retval
11+
C.constant_speed(M)
12+
dist = M.curve_length(C(torch.linspace(0, 1, 100))) # B
13+
dist2 = dist**2
14+
15+
lm0 = C.deriv(torch.zeros(1, device=p0.device)).squeeze(1) # log(p0, p1); Bx(d)
16+
lm1 = -C.deriv(torch.ones(1, device=p0.device)).squeeze(1) # log(p1, p0); Bx(d)
17+
G0 = M.metric(p0) # Bx(d)x(d) or Bx(d)
18+
G1 = M.metric(p1) # Bx(d)x(d) or Bx(d)
19+
if G0.ndim == 3: # metric is square
20+
Glm0 = G0.bmm(lm0.unsqueeze(-1)).squeeze(-1) # Bx(d)
21+
Glm1 = G1.bmm(lm1.unsqueeze(-1)).squeeze(-1) # Bx(d)
22+
else:
23+
Glm0 = G0 * lm0 # Bx(d)
24+
Glm1 = G1 * lm1 # Bx(d)
25+
26+
length_from_log = (lm0 * Glm0).sum(dim=-1) # B
27+
alpha = (dist2 / length_from_log).sqrt().unsqueeze(1) # Bx1
28+
Glm0 *= alpha
29+
Glm1 *= alpha
30+
31+
ctx.save_for_backward(Glm0, Glm1)
32+
return dist2
2833

2934
@staticmethod
3035
def backward(ctx, grad_output):
31-
Mlm0, Mlm1 = ctx.saved_tensors
36+
Glm0, Glm1 = ctx.saved_tensors
3237
return (None,
33-
2.0 * grad_output.view(-1, 1) * Mlm0,
34-
2.0 * grad_output.view(-1, 1) * Mlm1)
38+
2.0 * grad_output.view(-1, 1) * Glm0,
39+
2.0 * grad_output.view(-1, 1) * Glm1)
3540

3641

3742
def squared_manifold_distance(manifold, p0: torch.Tensor, p1: torch.Tensor):

0 commit comments

Comments
 (0)