We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dbcf3b4 commit 0dd95a1Copy full SHA for 0dd95a1
pytorch_optimizer/shampoo.py
@@ -70,10 +70,8 @@ def reset(self):
70
state = self.state[p]
71
72
state['step'] = 0
73
- if self.momentum > 0.0:
74
- state['momentum_buffer'] = p.grad.clone()
75
76
- # precondition matrices
+ # pre-condition matrices
77
for dim_id, dim in enumerate(p.grad.size()):
78
state[f'pre_cond_{dim_id}'] = group['eps'] * torch.eye(dim, out=p.grad.new(dim, dim))
79
state[f'inv_pre_cond_{dim_id}'] = p.grad.new(dim, dim).zero_()
0 commit comments