
Commit 3e68698

refactor: weight_decay
1 parent dd11575 commit 3e68698

File tree

1 file changed (+3, -4)

pytorch_optimizer/radam.py

Lines changed: 3 additions & 4 deletions
@@ -152,14 +152,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         step_size = -1
                     buffered[2] = step_size
 
+                if (n_sma >= self.n_sma_threshold or step_size > 0) and group['weight_decay'] != 0:
+                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
+
                 if n_sma >= self.n_sma_threshold:
-                    if group['weight_decay'] != 0:
-                        p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
                     de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                     p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
                 elif step_size > 0:
-                    if group['weight_decay'] != 0:
-                        p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
                     p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
 
                 if p.dtype in (torch.float16, torch.bfloat16):
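
In effect, the commit hoists the weight-decay update that was duplicated inside both branches and applies it once, guarded by the union of the two branch conditions; behaviour is unchanged because both branches applied weight decay before their parameter update. Below is a minimal sketch of the resulting control flow, assuming the diff context above. The helper name apply_radam_update and its signature are illustrative, not part of the library's API; tensor names (p_fp32, exp_avg, exp_avg_sq) and the param-group keys follow the diff.

# Minimal sketch of the branch structure after this commit (illustrative helper,
# not the library's API; surrounding state/buffer handling omitted).
import torch


def apply_radam_update(p_fp32, exp_avg, exp_avg_sq, n_sma, step_size, group, n_sma_threshold=5):
    # Weight decay is applied once, under the union of the two branch
    # conditions, instead of being duplicated inside each branch.
    if (n_sma >= n_sma_threshold or step_size > 0) and group['weight_decay'] != 0:
        p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])

    if n_sma >= n_sma_threshold:
        # rectified, variance-adapted update
        de_nom = exp_avg_sq.sqrt().add_(group['eps'])
        p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
    elif step_size > 0:
        # plain momentum (SGD-like) update while the variance estimate is unreliable
        p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])


# Illustrative call with dummy tensors:
group = {'lr': 1e-3, 'weight_decay': 1e-2, 'eps': 1e-8}
p = torch.randn(4)
apply_radam_update(p, torch.zeros(4), torch.ones(4), n_sma=6.0, step_size=1.0, group=group)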
