Skip to content

Commit 6736f5d

Browse files
committed
CubicSpline.fit: enable_grad always
1 parent 27c26ae commit 6736f5d

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

stochman/curves.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,11 @@ def closure():
140140
L.backward()
141141
return L
142142

143-
for _ in range(num_steps):
144-
loss = opt.step(closure=closure)
145-
if torch.max(torch.abs(self.params.grad)) < threshold:
146-
break
143+
with torch.enable_grad():
144+
for _ in range(num_steps):
145+
loss = opt.step(closure=closure)
146+
if torch.max(torch.abs(self.params.grad)) < threshold:
147+
break
147148
return loss
148149

149150

@@ -252,8 +253,7 @@ def __setitem__(self, indices, curves) -> None:
252253
# one = torch.ones(B, 1) # Bx1 -- XXX: ditto
253254
# new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
254255
# S = CubicSpline(zero, one)
255-
# with torch.enable_grad():
256-
# _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
256+
# _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
257257
# new_params = self(S(self.t[:, 0, 0]).squeeze(-1)) # B
258258

259259
# from IPython import embed; embed()
@@ -458,8 +458,7 @@ def constant_speed(
458458
) # Bx(N-1)
459459
cs = local_len.cumsum(dim=1) # Bx(N-1)
460460
new_t = torch.cat((torch.zeros(B, 1), cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
461-
with torch.enable_grad():
462-
_ = self.fit(new_t, Ct)
461+
_ = self.fit(new_t, Ct)
463462
return new_t, Ct
464463

465464
def todiscrete(self, num_nodes=None):

0 commit comments

Comments
 (0)