Skip to content

Commit 43a35a9

Browse files
committed
fix: momentum_buffer
1 parent b44e563 commit 43a35a9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_optimizer/shampoo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
108108
state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()
109109

110110
if momentum > 0.0:
111-
grad.mul_(1.0 - momentum).add_(momentum, alpha=state['momentum_buffer'])
111+
grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)
112112

113113
weight_decay = group['weight_decay']
114114
if weight_decay > 0.0:

0 commit comments

Comments
 (0)