Skip to content

Commit 6a91d6a

Browse files
committed
flake8 adaptation
1 parent 22cf106 commit 6a91d6a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

stochman/curves.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def _init_params(self, params, *args, **kwargs) -> None:
168168
)
169169
if params is None:
170170
params = self.t * self.end.unsqueeze(1) + \
171-
(1 - self.t) * self.begin.unsqueeze(1) # Bx(_num_nodes)xD
171+
(1 - self.t) * self.begin.unsqueeze(1) # Bx(_num_nodes)xD
172172
if self._requires_grad:
173173
self.register_parameter("params", nn.Parameter(params))
174174
else:
@@ -185,7 +185,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
185185
torch.ones(B, 1, D, dtype=self.t.dtype, device=self.device),
186186
),
187187
dim=1
188-
) # Bx(num_nodes)xD
188+
) # Bx(num_nodes)xD
189189
a = (end_nodes - start_nodes) / (t0[:, 1:] - t0[:, :-1]) # Bx(num_edges)xD
190190
b = start_nodes - a * t0[:, :-1] # Bx(num_edges)xD
191191

@@ -256,7 +256,7 @@ def constant_speed(
256256
new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
257257
S = CubicSpline(zero, one)
258258
_ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
259-
new_params = self(S(self.t[:, :, 0]).squeeze(-1)) # Bx(num_nodes-2)xD
259+
new_params = self(S(self.t[:, :, 0]).squeeze(-1)) # Bx(num_nodes-2)xD
260260
self.params = nn.Parameter(new_params)
261261
return new_t, Ct
262262

@@ -467,9 +467,9 @@ def todiscrete(self, num_nodes=None):
467467

468468
if num_nodes is None:
469469
num_nodes = self._num_nodes
470-
t = torch.linspace(0, 1, num_nodes)[1:-1] # (num_nodes-2)
471-
Ct = self(t) # Bx(num_nodes-2)xD
472-
470+
t = torch.linspace(0, 1, num_nodes)[1:-1] # (num_nodes-2)
471+
Ct = self(t) # Bx(num_nodes-2)xD
472+
473473
return DiscreteCurve(
474474
begin=self.begin,
475475
end=self.end,

0 commit comments

Comments
 (0)