@@ -7,31 +7,36 @@ class __Dist2__(torch.autograd.Function):
7
7
def forward(ctx, M, p0: torch.Tensor, p1: torch.Tensor):
    """Squared geodesic distance between batched points ``p0`` and ``p1`` on ``M``.

    Args:
        ctx: autograd context; the (rescaled) metric-weighted log maps are
            stashed here for the analytic ``backward``.
        M: manifold object — assumed to provide ``connecting_geodesic``,
            ``curve_length`` and ``metric``, and curves supporting
            ``constant_speed``/``deriv``/``__call__`` — TODO confirm interface.
        p0, p1: batches of points; presumably Bx(d) — verify against caller.

    Returns:
        B tensor of squared lengths of the connecting geodesics.
    """
    with torch.no_grad():
        # The geodesic solver needs autograd internally, even though the
        # distance itself is differentiated analytically in backward().
        with torch.enable_grad():
            C, _success = M.connecting_geodesic(p0, p1)  # success flag unused here
        C.constant_speed(M)
        # Numerical curve length over a fixed 100-point discretization.
        # FIX: create linspace on the inputs' device — the zeros/ones below
        # already do this, and a CPU tensor breaks for CUDA inputs.
        t = torch.linspace(0.0, 1.0, 100, device=p0.device)
        dist = M.curve_length(C(t))  # B
        dist2 = dist ** 2

        # Logarithm maps at both endpoints.
        lm0 = C.deriv(torch.zeros(1, device=p0.device)).squeeze(1)  # log(p0, p1); Bx(d)
        lm1 = -C.deriv(torch.ones(1, device=p0.device)).squeeze(1)  # log(p1, p0); Bx(d)
        G0 = M.metric(p0)  # Bx(d)x(d) or Bx(d)
        G1 = M.metric(p1)  # Bx(d)x(d) or Bx(d)
        if G0.ndim == 3:  # metric is square
            Glm0 = G0.bmm(lm0.unsqueeze(-1)).squeeze(-1)  # Bx(d)
            Glm1 = G1.bmm(lm1.unsqueeze(-1)).squeeze(-1)  # Bx(d)
        else:  # metric is diagonal
            Glm0 = G0 * lm0  # Bx(d)
            Glm1 = G1 * lm1  # Bx(d)

        # Rescale the metric-weighted log maps so the squared length they
        # imply matches the numerically integrated dist2 (the solved
        # geodesic is approximate, so the two estimates can disagree).
        # NOTE(review): length_from_log is 0 when p0 == p1, making alpha
        # NaN — confirm callers never pass identical points, or clamp.
        length_from_log = (lm0 * Glm0).sum(dim=-1)  # B
        alpha = (dist2 / length_from_log).sqrt().unsqueeze(1)  # Bx1
        Glm0 *= alpha
        Glm1 *= alpha

        ctx.save_for_backward(Glm0, Glm1)
        return dist2
28
33
29
34
@staticmethod
30
35
def backward (ctx , grad_output ):
31
- Mlm0 , Mlm1 = ctx .saved_tensors
36
+ Glm0 , Glm1 = ctx .saved_tensors
32
37
return (None ,
33
- 2.0 * grad_output .view (- 1 , 1 ) * Mlm0 ,
34
- 2.0 * grad_output .view (- 1 , 1 ) * Mlm1 )
38
+ 2.0 * grad_output .view (- 1 , 1 ) * Glm0 ,
39
+ 2.0 * grad_output .view (- 1 , 1 ) * Glm1 )
35
40
36
41
37
42
def squared_manifold_distance (manifold , p0 : torch .Tensor , p1 : torch .Tensor ):
0 commit comments