@@ -219,47 +219,46 @@ def __getitem__(self, indices: int) -> "DiscreteCurve":
219
219
def __setitem__ (self , indices , curves ) -> None :
220
220
self .params [indices ] = curves .params .squeeze ()
221
221
222
- # def constant_speed(
223
- # self, metric=None, t: Optional[torch.Tensor] = None
224
- # ) -> Tuple[torch.Tensor, torch.Tensor]:
225
- # """
226
- # Reparametrize the curve to have constant speed.
227
-
228
- # Optional input:
229
- # metric: the Manifold under which the curve should have constant speed.
230
- # If None then the Euclidean metric is applied.
231
- # Default: None.
232
-
233
- # Note: It is not possible to back-propagate through this function.
234
- # """
235
- # from stochman import CubicSpline
236
-
237
- # with torch.no_grad():
238
- # if t is None:
239
- # t = torch.linspace(0, 1, 100) # N
240
- # Ct = self(t) # NxD or BxNxD
241
- # if Ct.dim() == 2:
242
- # Ct.unsqueeze_(0) # BxNxD
243
- # B, N, D = Ct.shape
244
- # delta = Ct[:, 1:] - Ct[:, :-1] # Bx(N-1)xD
245
- # if metric is None:
246
- # local_len = delta.norm(dim=2) # Bx(N-1)
247
- # else:
248
- # local_len = (
249
- # metric.inner(Ct[:, :-1].reshape(-1, D), delta.view(-1, D), delta.view(-1, D))
250
- # .view(B, N - 1)
251
- # .sqrt()
252
- # ) # Bx(N-1)
253
- # cs = local_len.cumsum(dim=1) # Bx(N-1)
254
- # zero = torch.zeros(B, 1) # Bx1 -- XXX: missing dtype and device
255
- # one = torch.ones(B, 1) # Bx1 -- XXX: ditto
256
- # new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
257
- # S = CubicSpline(zero, one)
258
- # _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
259
- # new_params = self(S(self.t[:, 0, 0]).squeeze(-1)) # B
260
-
261
- # from IPython import embed; embed()
262
- # return new_t, Ct
222
+ def constant_speed (
223
+ self , metric = None , t : Optional [torch .Tensor ] = None
224
+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
225
+ """
226
+ Reparametrize the curve to have constant speed.
227
+
228
+ Optional input:
229
+ metric: the Manifold under which the curve should have constant speed.
230
+ If None then the Euclidean metric is applied.
231
+ Default: None.
232
+
233
+ Note: It is not possible to back-propagate through this function.
234
+ """
235
+ from stochman import CubicSpline
236
+
237
+ with torch .no_grad ():
238
+ if t is None :
239
+ t = torch .linspace (0 , 1 , 100 ) # N
240
+ Ct = self (t ) # NxD or BxNxD
241
+ if Ct .ndim == 2 :
242
+ Ct .unsqueeze_ (0 ) # BxNxD
243
+ B , N , D = Ct .shape
244
+ delta = Ct [:, 1 :] - Ct [:, :- 1 ] # Bx(N-1)xD
245
+ if metric is None :
246
+ local_len = delta .norm (dim = 2 ) # Bx(N-1)
247
+ else :
248
+ local_len = (
249
+ metric .inner (Ct [:, :- 1 ].reshape (- 1 , D ), delta .view (- 1 , D ), delta .view (- 1 , D ))
250
+ .view (B , N - 1 )
251
+ .sqrt ()
252
+ ) # Bx(N-1)
253
+ cs = local_len .cumsum (dim = 1 ) # Bx(N-1)
254
+ zero = torch .zeros (B , 1 , dtype = cs .dtype , device = cs .device ) # Bx1
255
+ one = torch .ones (B , 1 , dtype = cs .dtype , device = cs .device ) # Bx1
256
+ new_t = torch .cat ((zero , cs / cs [:, - 1 ].unsqueeze (1 )), dim = 1 ) # BxN
257
+ S = CubicSpline (zero , one )
258
+ _ = S .fit (new_t , t .unsqueeze (0 ).expand (B , - 1 ).unsqueeze (2 ))
259
+ new_params = self (S (self .t [:, :, 0 ]).squeeze (- 1 )) # Bx(num_nodes-2)xD
260
+ self .params = nn .Parameter (new_params )
261
+ return new_t , Ct
263
262
264
263
def tospline (self ):
265
264
from stochman import CubicSpline
0 commit comments