@@ -168,7 +168,7 @@ def _init_params(self, params, *args, **kwargs) -> None:
168
168
)
169
169
if params is None :
170
170
params = self .t * self .end .unsqueeze (1 ) + \
171
- (1 - self .t ) * self .begin .unsqueeze (1 ) # Bx(_num_nodes)xD
171
+ (1 - self .t ) * self .begin .unsqueeze (1 ) # Bx(_num_nodes)xD
172
172
if self ._requires_grad :
173
173
self .register_parameter ("params" , nn .Parameter (params ))
174
174
else :
@@ -185,7 +185,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
185
185
torch .ones (B , 1 , D , dtype = self .t .dtype , device = self .device ),
186
186
),
187
187
dim = 1
188
- ) # Bx(num_nodes)xD
188
+ ) # Bx(num_nodes)xD
189
189
a = (end_nodes - start_nodes ) / (t0 [:, 1 :] - t0 [:, :- 1 ]) # Bx(num_edges)xD
190
190
b = start_nodes - a * t0 [:, :- 1 ] # Bx(num_edges)xD
191
191
@@ -256,7 +256,7 @@ def constant_speed(
256
256
new_t = torch .cat ((zero , cs / cs [:, - 1 ].unsqueeze (1 )), dim = 1 ) # BxN
257
257
S = CubicSpline (zero , one )
258
258
_ = S .fit (new_t , t .unsqueeze (0 ).expand (B , - 1 ).unsqueeze (2 ))
259
- new_params = self (S (self .t [:, :, 0 ]).squeeze (- 1 )) # Bx(num_nodes-2)xD
259
+ new_params = self (S (self .t [:, :, 0 ]).squeeze (- 1 )) # Bx(num_nodes-2)xD
260
260
self .params = nn .Parameter (new_params )
261
261
return new_t , Ct
262
262
@@ -467,9 +467,9 @@ def todiscrete(self, num_nodes=None):
467
467
468
468
if num_nodes is None :
469
469
num_nodes = self ._num_nodes
470
- t = torch .linspace (0 , 1 , num_nodes )[1 :- 1 ] # (num_nodes-2)
471
- Ct = self (t ) # Bx(num_nodes-2)xD
472
-
470
+ t = torch .linspace (0 , 1 , num_nodes )[1 :- 1 ] # (num_nodes-2)
471
+ Ct = self (t ) # Bx(num_nodes-2)xD
472
+
473
473
return DiscreteCurve (
474
474
begin = self .begin ,
475
475
end = self .end ,
0 commit comments