We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4f4d359 commit e94290fCopy full SHA for e94290f
pytorch_optimizer/optimizer/trac.py
@@ -138,7 +138,7 @@ def state(self):
138
def reset(self):
139
device = self.param_groups[0]['params'][0].device
140
141
- self.state = {
+ self.state['trac'] = {
142
'betas': torch.tensor(self.betas, device=device),
143
's': torch.zeros(len(self.betas), device=device),
144
'variance': torch.zeros(len(self.betas), device=device),
@@ -148,7 +148,7 @@ def reset(self):
148
149
for group in self.param_groups:
150
for p in group['params']:
151
- self.state[p] = p.clone()
+ self.state['trac'][p] = p.clone()
152
153
@torch.no_grad()
154
def zero_grad(self) -> None:
0 commit comments