@@ -43,7 +43,7 @@ def curve_energy(self, curve: BasicCurve) -> torch.Tensor:
43
43
d = curve .shape [2 ]
44
44
delta = curve [:, 1 :] - curve [:, :- 1 ] # Bx(N-1)x(d)
45
45
flat_delta = delta .view (- 1 , d ) # (B*(N-1))x(d)
46
- energy = self .inner (curve [:, :- 1 ].view (- 1 , d ), flat_delta , flat_delta ) # B*(N-1)
46
+ energy = self .inner (curve [:, :- 1 ].reshape (- 1 , d ), flat_delta , flat_delta ) # B*(N-1)
47
47
return energy .sum () # scalar
48
48
49
49
def curve_length (self , curve : BasicCurve ) -> torch .Tensor :
@@ -70,7 +70,7 @@ def curve_length(self, curve: BasicCurve) -> torch.Tensor:
70
70
B , N , d = curve .shape
71
71
delta = curve [:, 1 :] - curve [:, :- 1 ] # Bx(N-1)x(d)
72
72
flat_delta = delta .view (- 1 , d ) # (B*(N-1))x(d)
73
- energy = self .inner (curve [:, :- 1 ].view (- 1 , d ), flat_delta , flat_delta ) # B*(N-1)
73
+ energy = self .inner (curve [:, :- 1 ].reshape (- 1 , d ), flat_delta , flat_delta ) # B*(N-1)
74
74
length = energy .view (B , N - 1 ).sqrt ().sum (dim = 1 ) # B
75
75
return length
76
76
@@ -592,44 +592,6 @@ def metric(self, c, return_deriv=False):
592
592
else :
593
593
return torch .cat (M )
594
594
595
- def curve_energy (self , c ):
596
- """
597
- Evaluate the energy of a curve represented as a discrete set of points.
598
-
599
- Input:
600
- c: A discrete set of points along a curve. This is represented
601
- as a PxD or BxPxD torch Tensor. The points are assumed to be ordered
602
- along the curve and evaluated at equidistant time points.
603
-
604
- Output:
605
- energy: The energy of the input curve.
606
- """
607
- if len (c .shape ) == 2 :
608
- c .unsqueeze_ (0 ) # add batch dimension if one isn't present
609
- energy = torch .zeros (1 )
610
- for b in range (c .shape [0 ]):
611
- M = self .metric (c [b , :- 1 ]) # (P-1)xD
612
- delta1 = (c [b , 1 :] - c [b , :- 1 ]) ** 2 # (P-1)xD
613
- energy += (M * delta1 ).sum ()
614
- return energy
615
-
616
- def curve_length (self , c ):
617
- """
618
- Evaluate the length of a curve represented as a discrete set of points.
619
-
620
- Input:
621
- c: A discrete set of points along a curve. This is represented
622
- as a PxD torch Tensor. The points are assumed to be ordered
623
- along the curve and evaluated at equidistant indices.
624
-
625
- Output:
626
- length: The length of the input curve.
627
- """
628
- M = self .metric (c [:- 1 ]) # (P-1)xD
629
- delta1 = (c [1 :] - c [:- 1 ]) ** 2 # (P-1)xD
630
- length = (M * delta1 ).sum (dim = 1 ).sqrt ().sum ()
631
- return length
632
-
633
595
def geodesic_system (self , c , dc ):
634
596
"""
635
597
Evaluate the 2nd order system of ordinary differential equations that
0 commit comments