Skip to content

Commit 0dd95a1

Browse files
committed
update: reset
1 parent dbcf3b4 commit 0dd95a1

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

pytorch_optimizer/shampoo.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,8 @@ def reset(self):
7070
state = self.state[p]
7171

7272
state['step'] = 0
73-
if self.momentum > 0.0:
74-
state['momentum_buffer'] = p.grad.clone()
7573

76-
# precondition matrices
74+
# pre-condition matrices
7775
for dim_id, dim in enumerate(p.grad.size()):
7876
state[f'pre_cond_{dim_id}'] = group['eps'] * torch.eye(dim, out=p.grad.new(dim, dim))
7977
state[f'inv_pre_cond_{dim_id}'] = p.grad.new(dim, dim).zero_()

0 commit comments

Comments
 (0)