@@ -78,7 +78,7 @@ def plot(self, t0: float = 0.0, t1: float = 1.0, N: int = 100, *plot_args, **plo
                 return figs
             if points.shape[-1] == 2:
                 for b in range(points.shape[0]):
-                    fig = plt.plot(points[b, :, 0], points[b, :, 1], "-", *plot_args, **plot_kwargs)
+                    fig = plt.plot(points[b, :, 0], points[b, :, 1], *plot_args, **plot_kwargs)
                     figs.append(fig)
                 return figs
 
@@ -140,10 +140,11 @@ def closure():
             L.backward()
             return L
 
-        for _ in range(num_steps):
-            loss = opt.step(closure=closure)
-            if torch.max(torch.abs(self.params.grad)) < threshold:
-                break
+        with torch.enable_grad():
+            for _ in range(num_steps):
+                loss = opt.step(closure=closure)
+                if torch.max(torch.abs(self.params.grad)) < threshold:
+                    break
 
         return loss
 
@@ -162,43 +163,48 @@ def _init_params(self, params, *args, **kwargs) -> None:
         self.register_buffer(
             "t",
             torch.linspace(0, 1, self._num_nodes, dtype=self.begin.dtype)[1:-1]
-            .reshape(-1, 1, 1)
-            .repeat(1, *self.begin.shape),  # (_num_nodes-2)xBxD
+            .view(1, -1, 1)
+            .expand(self.begin.shape[0], -1, self.begin.shape[1]),  # Bx(_num_nodes-2)xD
         )
         if params is None:
-            params = self.t * self.end.unsqueeze(0) + (1 - self.t) * self.begin.unsqueeze(
-                0
-            )  # (_num_nodes)xBxD
+            params = self.t * self.end.unsqueeze(1) + \
+                (1 - self.t) * self.begin.unsqueeze(1)  # Bx(_num_nodes)xD
         if self._requires_grad:
             self.register_parameter("params", nn.Parameter(params))
         else:
             self.register_buffer("params", params)
 
     def forward(self, t: torch.Tensor) -> torch.Tensor:
-        start_nodes = torch.cat((self.begin.unsqueeze(0), self.params))  # (num_edges)xBxD
-        end_nodes = torch.cat((self.params, self.end.unsqueeze(0)))  # (num_edges)xBxD
-        num_edges, B, D = start_nodes.shape
+        start_nodes = torch.cat((self.begin.unsqueeze(1), self.params), dim=1)  # Bx(num_edges)xD
+        end_nodes = torch.cat((self.params, self.end.unsqueeze(1)), dim=1)  # Bx(num_edges)xD
+        B, num_edges, D = start_nodes.shape
         t0 = torch.cat(
             (
-                torch.zeros(1, B, D, dtype=self.t.dtype, device=self.device),
+                torch.zeros(B, 1, D, dtype=self.t.dtype, device=self.device),
                 self.t,
-                torch.ones(1, B, D, dtype=self.t.dtype, device=self.device),
-            )
-        )
-        a = (end_nodes - start_nodes) / (t0[1:] - t0[:-1])  # (num_edges)xBxD
-        b = start_nodes - a * t0[:-1]  # (num_edges)xBxD
-
+                torch.ones(B, 1, D, dtype=self.t.dtype, device=self.device),
+            ),
+            dim=1
+        )  # Bx(num_nodes)xD
+        a = (end_nodes - start_nodes) / (t0[:, 1:] - t0[:, :-1])  # Bx(num_edges)xD
+        b = start_nodes - a * t0[:, :-1]  # Bx(num_edges)xD
+
+        if t.ndim == 1:
+            tt = t.view((1, -1)).expand(B, -1)  # Bx|t|
+        elif t.ndim == 2:
+            tt = t  # Bx|t|
+        else:
+            raise Exception('t must have at most 2 dimensions')
         idx = (
-            torch.floor(t.flatten() * num_edges).clamp(min=0, max=num_edges - 1).long()
-        )  # use this if nodes are equi-distant
-        tt = t.view((-1, 1, 1)).expand(-1, B, D)
-        result = a[idx] * tt + b[idx]  # (num_edges)xBxD
-        return result.permute(1, 0, 2).squeeze(0)  # Bx(num_edges)xD
+            torch.floor(tt * num_edges).clamp(min=0, max=num_edges - 1).long()  # Bx|t|
+        ).unsqueeze(2).repeat(1, 1, D)  # Bx|t|xD, this assumes that nodes are equi-distant
+        result = torch.gather(a, 1, idx) * tt.unsqueeze(2) + torch.gather(b, 1, idx)  # Bx|t|xD
+        return result
 
     def __getitem__(self, indices: int) -> "DiscreteCurve":
-        params = self.params[:, indices]
+        params = self.params[indices]
         if params.dim() == 2:
-            params = params.unsqueeze(1)
+            params = params.unsqueeze(0)
         C = DiscreteCurve(
             begin=self.begin[indices],
             end=self.end[indices],
@@ -209,7 +215,49 @@ def __getitem__(self, indices: int) -> "DiscreteCurve":
         return C
 
     def __setitem__(self, indices, curves) -> None:
-        self.params[:, indices] = curves.params.squeeze()
+        self.params[indices] = curves.params.squeeze()
+
+    # def constant_speed(
+    #     self, metric=None, t: Optional[torch.Tensor] = None
+    # ) -> Tuple[torch.Tensor, torch.Tensor]:
+    #     """
+    #     Reparametrize the curve to have constant speed.
+
+    #     Optional input:
+    #         metric: the Manifold under which the curve should have constant speed.
+    #             If None then the Euclidean metric is applied.
+    #             Default: None.
+
+    #     Note: It is not possible to back-propagate through this function.
+    #     """
+    #     from stochman import CubicSpline
+
+    #     with torch.no_grad():
+    #         if t is None:
+    #             t = torch.linspace(0, 1, 100)  # N
+    #         Ct = self(t)  # NxD or BxNxD
+    #         if Ct.dim() == 2:
+    #             Ct.unsqueeze_(0)  # BxNxD
+    #         B, N, D = Ct.shape
+    #         delta = Ct[:, 1:] - Ct[:, :-1]  # Bx(N-1)xD
+    #         if metric is None:
+    #             local_len = delta.norm(dim=2)  # Bx(N-1)
+    #         else:
+    #             local_len = (
+    #                 metric.inner(Ct[:, :-1].reshape(-1, D), delta.view(-1, D), delta.view(-1, D))
+    #                 .view(B, N - 1)
+    #                 .sqrt()
+    #             )  # Bx(N-1)
+    #         cs = local_len.cumsum(dim=1)  # Bx(N-1)
+    #         zero = torch.zeros(B, 1)  # Bx1 -- XXX: missing dtype and device
+    #         one = torch.ones(B, 1)  # Bx1 -- XXX: ditto
+    #         new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1)  # BxN
+    #         S = CubicSpline(zero, one)
+    #         _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
+    #         new_params = self(S(self.t[:, 0, 0]).squeeze(-1))  # B
+
+    #         from IPython import embed; embed()
+    #         return new_t, Ct
 
     def tospline(self):
         from stochman import CubicSpline
@@ -220,7 +268,7 @@ def tospline(self):
             num_nodes=self._num_nodes,
             requires_grad=self._requires_grad,
         )
-        c.fit(self.t[:, 0, 0], self.params.squeeze(1))
+        _ = c.fit(self.t[0, :, 0], self.params)
         return c
 
 
@@ -329,8 +377,6 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
         if no_batch:
             t = t.expand(coeffs.shape[0], -1)  # Bx|t|
         retval = self._eval_polynomials(t, coeffs)  # Bx|t|xD
-        # tt = t.view((-1, 1)).unsqueeze(0).expand(retval.shape[0], -1, -1)  # Bx|t|x1
-        # retval += (1-tt).bmm(self.begin.unsqueeze(1)) + tt.bmm(self.end.unsqueeze(1))  # Bx|t|xD
         retval += self._eval_straight_line(t)
         if no_batch and retval.shape[0] == 1:
             retval.squeeze_(0)  # |t|xD
@@ -393,7 +439,6 @@ def constant_speed(
                 Default: None.
 
         Note: It is not possible to back-propagate through this function.
-        Note: This function does currently not support batching.
         """
         with torch.no_grad():
             if t is None:
@@ -413,16 +458,21 @@ def constant_speed(
                 )  # Bx(N-1)
             cs = local_len.cumsum(dim=1)  # Bx(N-1)
             new_t = torch.cat((torch.zeros(B, 1), cs / cs[:, -1].unsqueeze(1)), dim=1)  # BxN
-            with torch.enable_grad():
-                _ = self.fit(new_t, Ct)
+            _ = self.fit(new_t, Ct)
         return new_t, Ct
 
-    def todiscrete(self):
+    def todiscrete(self, num_nodes=None):
         from stochman import DiscreteCurve
 
+        if num_nodes is None:
+            num_nodes = self._num_nodes
+        t = torch.linspace(0, 1, num_nodes)[1:-1]  # (num_nodes-2)
+        Ct = self(t)  # Bx(num_nodes-2)xD
+
         return DiscreteCurve(
             begin=self.begin,
             end=self.end,
-            num_nodes=self._num_nodes,
+            num_nodes=num_nodes,
             requires_grad=self._requires_grad,
+            params=Ct,
         )
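
For context, a minimal usage sketch of the batch-first (Bx...xD) layout introduced above. This is an illustration only, not part of the commit: the example shapes, the omitted constructor defaults (requires_grad, params), and the call pattern are assumptions based on the keyword arguments and methods visible in this diff.

import torch
from stochman import DiscreteCurve

begin = torch.zeros(2, 3)  # BxD batch of start points (B=2, D=3)
end = torch.ones(2, 3)     # BxD batch of end points
curve = DiscreteCurve(begin=begin, end=end, num_nodes=10)

t = torch.linspace(0, 1, 25)  # |t| evaluation times shared across the batch (t.ndim == 1 branch)
pts = curve(t)                # Bx|t|xD, batch-first as in the new forward()

t_per_curve = torch.rand(2, 25).sort(dim=1).values  # Bx|t| per-curve times (t.ndim == 2 branch)
pts_per_curve = curve(t_per_curve)                  # Bx|t|xD

spline = curve.tospline()                 # CubicSpline fitted to the discrete nodes
curve2 = spline.todiscrete(num_nodes=20)  # back to a DiscreteCurve at a chosen resolution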