Skip to content

Commit 4e5a011

Browse files
committed
update: weight decouple to True
1 parent c7c3cfb commit 4e5a011

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

pytorch_optimizer/optimizer/adams.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,13 @@ def __init__(
2828
lr: float = 1e-3,
2929
betas: BETAS = (0.1, 0.99),
3030
weight_decay: float = 1e-4,
31-
weight_decouple: bool = False,
3231
amsgrad: bool = False,
3332
adamd_debias_term: bool = False,
3433
eps: float = 1e-8,
3534
):
3635
self.lr = lr
3736
self.betas = betas
3837
self.weight_decay = weight_decay
39-
self.weight_decouple = weight_decouple
4038
self.amsgrad = amsgrad
4139
self.adamd_debias_term = adamd_debias_term
4240
self.eps = eps
@@ -130,14 +128,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
130128
if p.grad is None:
131129
continue
132130

133-
grad = p.grad
134131
state = self.state[p]
135132

136133
if group['weight_decay'] > 0.0:
137-
if self.weight_decouple:
138-
p.mul_(1.0 - group['lr'] * group['weight_decay'])
139-
else:
140-
grad.add_(p, alpha=group['weight_decay'])
134+
p.mul_(1.0 - group['lr'] * group['weight_decay'] / exp_avg_sq_hat_mean)
141135

142136
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
143137

0 commit comments

Comments
 (0)