Skip to content

Commit 8c5e9c2

Browse files
committed
fix: key
1 parent 097f977 commit 8c5e9c2

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

pytorch_optimizer/optimizer/spam.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
228228
if 'mask' in state:
229229
grad = grad[state['mask']]
230230

231-
if len(state) == 0:
231+
if 'exp_avg' not in state:
232232
state['exp_avg'] = torch.zeros_like(grad)
233233
state['exp_avg_sq'] = torch.zeros_like(grad)
234234

@@ -258,11 +258,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
258258
else:
259259
p.addcdiv_(exp_avg, de_nom, value=-step_size * scale_factor)
260260

261-
if group['weight_decay'] > 0:
262-
if 'mask' in state:
263-
p[state['mask']].add_(p[state['mask']], alpha=-group['lr'] * group['weight_decay'])
264-
else:
265-
p.add_(p, alpha=-group['lr'] * group['weight_decay'])
261+
self.apply_weight_decay(
262+
p[state['mask']] if 'mask' in state else p,
263+
grad=None,
264+
lr=group['lr'],
265+
weight_decay=group['weight_decay'],
266+
weight_decouple=True,
267+
fixed_decay=False,
268+
)
266269

267270
self.state['total_step'] += 1
268271
self.state['current_step'] += 1

0 commit comments

Comments
 (0)