We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b44e563 commit 43a35a9Copy full SHA for 43a35a9
pytorch_optimizer/shampoo.py
@@ -108,7 +108,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
108
state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()
109
110
if momentum > 0.0:
111
- grad.mul_(1.0 - momentum).add_(momentum, alpha=state['momentum_buffer'])
+ grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)
112
113
weight_decay = group['weight_decay']
114
if weight_decay > 0.0:
0 commit comments