Skip to content

Commit e94290f

Browse files
committed
fix: reset
1 parent 4f4d359 commit e94290f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_optimizer/optimizer/trac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def state(self):
138138
def reset(self):
139139
device = self.param_groups[0]['params'][0].device
140140

141-
self.state = {
141+
self.state['trac'] = {
142142
'betas': torch.tensor(self.betas, device=device),
143143
's': torch.zeros(len(self.betas), device=device),
144144
'variance': torch.zeros(len(self.betas), device=device),
@@ -148,7 +148,7 @@ def reset(self):
148148

149149
for group in self.param_groups:
150150
for p in group['params']:
151-
self.state[p] = p.clone()
151+
self.state['trac'][p] = p.clone()
152152

153153
@torch.no_grad()
154154
def zero_grad(self) -> None:

0 commit comments

Comments
 (0)