@@ -168,7 +168,7 @@ def _init_params(self, params, *args, **kwargs) -> None:
168
168
)
169
169
if params is None :
170
170
params = self .t * self .end .unsqueeze (1 ) + \
171
- (1 - self .t ) * self .begin .unsqueeze (1 ) # Bx(_num_nodes)xD
171
+ (1 - self .t ) * self .begin .unsqueeze (1 ) # Bx(_num_nodes)xD
172
172
if self ._requires_grad :
173
173
self .register_parameter ("params" , nn .Parameter (params ))
174
174
else :
@@ -185,7 +185,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
185
185
torch .ones (B , 1 , D , dtype = self .t .dtype , device = self .device ),
186
186
),
187
187
dim = 1
188
- ) # Bx(num_nodes)xD
188
+ ) # Bx(num_nodes)xD
189
189
a = (end_nodes - start_nodes ) / (t0 [:, 1 :] - t0 [:, :- 1 ]) # Bx(num_edges)xD
190
190
b = start_nodes - a * t0 [:, :- 1 ] # Bx(num_edges)xD
191
191
@@ -199,6 +199,8 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
199
199
torch .floor (tt * num_edges ).clamp (min = 0 , max = num_edges - 1 ).long () # Bx|t|
200
200
).unsqueeze (2 ).repeat (1 , 1 , D ) # Bx|t|xD, this assumes that nodes are equi-distant
201
201
result = torch .gather (a , 1 , idx ) * tt .unsqueeze (2 ) + torch .gather (b , 1 , idx ) # Bx|t|xD
202
+ if B == 1 :
203
+ result = result .squeeze (0 ) # |t|xD
202
204
return result
203
205
204
206
def __getitem__ (self , indices : int ) -> "DiscreteCurve" :
@@ -217,47 +219,46 @@ def __getitem__(self, indices: int) -> "DiscreteCurve":
217
219
def __setitem__ (self , indices , curves ) -> None :
218
220
self .params [indices ] = curves .params .squeeze ()
219
221
220
- # def constant_speed(
221
- # self, metric=None, t: Optional[torch.Tensor] = None
222
- # ) -> Tuple[torch.Tensor, torch.Tensor]:
223
- # """
224
- # Reparametrize the curve to have constant speed.
225
-
226
- # Optional input:
227
- # metric: the Manifold under which the curve should have constant speed.
228
- # If None then the Euclidean metric is applied.
229
- # Default: None.
230
-
231
- # Note: It is not possible to back-propagate through this function.
232
- # """
233
- # from stochman import CubicSpline
234
-
235
- # with torch.no_grad():
236
- # if t is None:
237
- # t = torch.linspace(0, 1, 100) # N
238
- # Ct = self(t) # NxD or BxNxD
239
- # if Ct.dim() == 2:
240
- # Ct.unsqueeze_(0) # BxNxD
241
- # B, N, D = Ct.shape
242
- # delta = Ct[:, 1:] - Ct[:, :-1] # Bx(N-1)xD
243
- # if metric is None:
244
- # local_len = delta.norm(dim=2) # Bx(N-1)
245
- # else:
246
- # local_len = (
247
- # metric.inner(Ct[:, :-1].reshape(-1, D), delta.view(-1, D), delta.view(-1, D))
248
- # .view(B, N - 1)
249
- # .sqrt()
250
- # ) # Bx(N-1)
251
- # cs = local_len.cumsum(dim=1) # Bx(N-1)
252
- # zero = torch.zeros(B, 1) # Bx1 -- XXX: missing dtype and device
253
- # one = torch.ones(B, 1) # Bx1 -- XXX: ditto
254
- # new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
255
- # S = CubicSpline(zero, one)
256
- # _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
257
- # new_params = self(S(self.t[:, 0, 0]).squeeze(-1)) # B
258
-
259
- # from IPython import embed; embed()
260
- # return new_t, Ct
222
+ def constant_speed (
223
+ self , metric = None , t : Optional [torch .Tensor ] = None
224
+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
225
+ """
226
+ Reparametrize the curve to have constant speed.
227
+
228
+ Optional input:
229
+ metric: the Manifold under which the curve should have constant speed.
230
+ If None then the Euclidean metric is applied.
231
+ Default: None.
232
+
233
+ Note: It is not possible to back-propagate through this function.
234
+ """
235
+ from stochman import CubicSpline
236
+
237
+ with torch .no_grad ():
238
+ if t is None :
239
+ t = torch .linspace (0 , 1 , 100 ) # N
240
+ Ct = self (t ) # NxD or BxNxD
241
+ if Ct .ndim == 2 :
242
+ Ct .unsqueeze_ (0 ) # BxNxD
243
+ B , N , D = Ct .shape
244
+ delta = Ct [:, 1 :] - Ct [:, :- 1 ] # Bx(N-1)xD
245
+ if metric is None :
246
+ local_len = delta .norm (dim = 2 ) # Bx(N-1)
247
+ else :
248
+ local_len = (
249
+ metric .inner (Ct [:, :- 1 ].reshape (- 1 , D ), delta .view (- 1 , D ), delta .view (- 1 , D ))
250
+ .view (B , N - 1 )
251
+ .sqrt ()
252
+ ) # Bx(N-1)
253
+ cs = local_len .cumsum (dim = 1 ) # Bx(N-1)
254
+ zero = torch .zeros (B , 1 , dtype = cs .dtype , device = cs .device ) # Bx1
255
+ one = torch .ones (B , 1 , dtype = cs .dtype , device = cs .device ) # Bx1
256
+ new_t = torch .cat ((zero , cs / cs [:, - 1 ].unsqueeze (1 )), dim = 1 ) # BxN
257
+ S = CubicSpline (zero , one )
258
+ _ = S .fit (new_t , t .unsqueeze (0 ).expand (B , - 1 ).unsqueeze (2 ))
259
+ new_params = self (S (self .t [:, :, 0 ]).squeeze (- 1 )) # Bx(num_nodes-2)xD
260
+ self .params = nn .Parameter (new_params )
261
+ return new_t , Ct
261
262
262
263
def tospline (self ):
263
264
from stochman import CubicSpline
@@ -466,9 +467,9 @@ def todiscrete(self, num_nodes=None):
466
467
467
468
if num_nodes is None :
468
469
num_nodes = self ._num_nodes
469
- t = torch .linspace (0 , 1 , num_nodes )[1 :- 1 ] # (num_nodes-2)
470
- Ct = self (t ) # Bx(num_nodes-2)xD
471
-
470
+ t = torch .linspace (0 , 1 , num_nodes )[1 :- 1 ] # (num_nodes-2)
471
+ Ct = self (t ) # Bx(num_nodes-2)xD
472
+
472
473
return DiscreteCurve (
473
474
begin = self .begin ,
474
475
end = self .end ,
0 commit comments