@@ -140,10 +140,11 @@ def closure():
140
140
L .backward ()
141
141
return L
142
142
143
- for _ in range (num_steps ):
144
- loss = opt .step (closure = closure )
145
- if torch .max (torch .abs (self .params .grad )) < threshold :
146
- break
143
+ with torch .enable_grad ():
144
+ for _ in range (num_steps ):
145
+ loss = opt .step (closure = closure )
146
+ if torch .max (torch .abs (self .params .grad )) < threshold :
147
+ break
147
148
return loss
148
149
149
150
@@ -252,8 +253,7 @@ def __setitem__(self, indices, curves) -> None:
252
253
# one = torch.ones(B, 1) # Bx1 -- XXX: ditto
253
254
# new_t = torch.cat((zero, cs / cs[:, -1].unsqueeze(1)), dim=1) # BxN
254
255
# S = CubicSpline(zero, one)
255
- # with torch.enable_grad():
256
- # _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
256
+ # _ = S.fit(new_t, t.unsqueeze(0).expand(B, -1).unsqueeze(2))
257
257
# new_params = self(S(self.t[:, 0, 0]).squeeze(-1)) # B
258
258
259
259
# from IPython import embed; embed()
@@ -458,8 +458,7 @@ def constant_speed(
458
458
) # Bx(N-1)
459
459
cs = local_len .cumsum (dim = 1 ) # Bx(N-1)
460
460
new_t = torch .cat ((torch .zeros (B , 1 ), cs / cs [:, - 1 ].unsqueeze (1 )), dim = 1 ) # BxN
461
- with torch .enable_grad ():
462
- _ = self .fit (new_t , Ct )
461
+ _ = self .fit (new_t , Ct )
463
462
return new_t , Ct
464
463
465
464
def todiscrete (self , num_nodes = None ):
0 commit comments