
Commit de06f63

Merge pull request #102 from kozistr/feature/stable-weight-decay
[Feature] Stable Weight Decay
2 parents 75a023a + 115906a commit de06f63

File tree: 7 files changed, +182 −8 lines

  docs/optimizer_api.rst
  pytorch_optimizer/__init__.py
  pytorch_optimizer/optimizer/adai.py
  pytorch_optimizer/optimizer/adams.py
  tests/constants.py
  tests/test_load_optimizers.py
  tests/test_optimizers.py

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -224,3 +224,11 @@ DAdaptSGD

 .. autoclass:: pytorch_optimizer.DAdaptSGD
     :members:
+
+.. _AdamS:
+
+AdamS
+-----
+
+.. autoclass:: pytorch_optimizer.AdamS
+    :members:

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 from pytorch_optimizer.optimizer.adabound import AdaBound
 from pytorch_optimizer.optimizer.adai import Adai
 from pytorch_optimizer.optimizer.adamp import AdamP
+from pytorch_optimizer.optimizer.adams import AdamS
 from pytorch_optimizer.optimizer.adan import Adan
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
 from pytorch_optimizer.optimizer.agc import agc
@@ -88,6 +89,7 @@
     DAdaptAdaGrad,
     DAdaptAdam,
     DAdaptSGD,
+    AdamS,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
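Since OPTIMIZERS keys each class by its lowercased name, adding AdamS to OPTIMIZER_LIST also makes it resolvable by the string 'adams'. A minimal sketch of what that registration enables, using only names defined in this __init__.py (the model and hyperparameters below are illustrative):

import torch

from pytorch_optimizer import OPTIMIZERS  # the dict built from OPTIMIZER_LIST above

model = torch.nn.Linear(4, 2)

# 'adams' is str(AdamS.__name__).lower(), per the dict comprehension above
optimizer_cls = OPTIMIZERS['adams']
optimizer = optimizer_cls(model.parameters(), lr=1e-3, weight_decay=1e-4)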

pytorch_optimizer/optimizer/adai.py

Lines changed: 14 additions & 5 deletions
@@ -17,8 +17,8 @@ class Adai(Optimizer, BaseOptimizer):
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param dampening: float. dampening for momentum. where dampening < 1,
-        it will show some adaptive-moment behavior.
+    :param use_stable_weight_decay: bool. perform stable weight decay.
+    :param dampening: float. dampening for momentum. where dampening < 1, it will show some adaptive-moment behavior.
     :param use_gc: bool. use gradient centralization.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
@@ -30,6 +30,7 @@ def __init__(
         betas: BETAS = (0.1, 0.99),
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
+        use_stable_weight_decay: bool = False,
         dampening: float = 1.0,
         use_gc: bool = False,
         eps: float = 1e-3,
@@ -38,6 +39,7 @@
         self.betas = betas
         self.weight_decay = weight_decay
         self.weight_decouple = weight_decouple
+        self.use_stable_weight_decay = use_stable_weight_decay
         self.dampening = dampening
         self.use_gc = use_gc
         self.eps = eps
@@ -111,7 +113,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 bias_correction2 = 1.0 - beta2 ** state['step']

-                if group['weight_decay'] > 0.0:
+                if not self.use_stable_weight_decay and group['weight_decay'] > 0.0:
                     if self.weight_decouple:
                         p.mul_(1.0 - group['lr'] * group['weight_decay'])
                     else:
@@ -137,8 +139,13 @@
                 grad = p.grad
                 state = self.state[p]

+                if self.use_stable_weight_decay and group['weight_decay'] > 0.0:
+                    if self.weight_decouple:
+                        p.mul_(1.0 - group['lr'] * group['weight_decay'])
+                    else:
+                        grad.add_(p, alpha=group['weight_decay'])
+
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                beta1_prod = state['beta1_prod']

                 bias_correction2 = 1.0 - beta2 ** state['step']

@@ -148,11 +155,13 @@
                 ).clamp(0.0, 1.0 - group['eps'])
                 beta3 = (1.0 - beta1).pow(group['dampening'])

+                beta1_prod = state['beta1_prod']
                 beta1_prod.mul_(beta1)
+
                 bias_correction1 = 1.0 - beta1_prod

                 exp_avg.mul_(beta1).addcmul_(beta3, grad)
-                exp_avg_hat = exp_avg / bias_correction1 * beta0_dp
+                exp_avg_hat = exp_avg.div(bias_correction1).mul(beta0_dp)

                 p.add_(exp_avg_hat, alpha=-group['lr'])
pytorch_optimizer/optimizer/adams.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdamS(Optimizer, BaseOptimizer):
    r"""Adam with stable weight decay.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param amsgrad: bool. whether to use the AMSGrad variant of this algorithm from the paper.
    :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 1e-4,
        amsgrad: bool = False,
        adamd_debias_term: bool = False,
        eps: float = 1e-8,
    ):
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.amsgrad = amsgrad
        self.adamd_debias_term = adamd_debias_term
        self.eps = eps

        self.validate_parameters()

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_betas(self.betas)
        self.validate_weight_decay(self.weight_decay)
        self.validate_epsilon(self.eps)

    @property
    def __str__(self) -> str:
        return 'AdamS'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        exp_avg_sq_hat_sum: float = 0.0

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(self.__str__)

                param_size += p.numel()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if self.amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                bias_correction2 = 1.0 - beta2 ** state['step']

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if self.amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    exp_avg_sq_hat = max_exp_avg_sq
                else:
                    exp_avg_sq_hat = exp_avg_sq

                exp_avg_sq_hat_sum += exp_avg_sq_hat.sum() / bias_correction2

        if param_size == 0:
            raise ZeroParameterSizeError()

        exp_avg_sq_hat_mean = math.sqrt(exp_avg_sq_hat_sum / param_size) + self.eps

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                if group['weight_decay'] > 0.0:
                    p.mul_(1.0 - group['lr'] * group['weight_decay'] / exp_avg_sq_hat_mean)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                bias_correction1 = 1.0 - beta1 ** state['step']
                bias_correction2 = 1.0 - beta2 ** state['step']

                exp_avg_sq_hat = state['max_exp_avg_sq'] if self.amsgrad else exp_avg_sq
                exp_avg_sq_hat.div_(bias_correction2)

                de_nom = exp_avg_sq_hat.sqrt().add(group['eps'])

                step_size = group['lr'] if self.adamd_debias_term else group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss
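The step above is deliberately two-pass: the first loop accumulates the bias-corrected second moment over every parameter, and the second loop shrinks each parameter by 1 - lr * weight_decay / exp_avg_sq_hat_mean before the usual Adam update, so the decay actually applied scales with the same statistic that scales the step size. A short standalone sketch of that normalization (the tensors and values below are illustrative, not taken from the diff):

import math

import torch

lr, weight_decay, eps = 1e-3, 1e-4, 1e-8

# pretend these are the bias-corrected second moments of two parameter tensors
exp_avg_sq_hat = [torch.full((3,), 0.04), torch.full((2,), 0.09)]
param_size = sum(v.numel() for v in exp_avg_sq_hat)

# same normalization as AdamS.step: square root of the mean second moment
# over all parameters, plus eps
exp_avg_sq_hat_mean = math.sqrt(sum(v.sum().item() for v in exp_avg_sq_hat) / param_size) + eps

# decoupled decay factor; larger second moments weaken the per-step decay,
# matching the correspondingly smaller effective Adam step
decay_factor = 1.0 - lr * weight_decay / exp_avg_sq_hat_mean
print(decay_factor)

This is what the PR title refers to as stable weight decay: the effective decay rate is tied to the adaptive step size rather than to the raw weight_decay value alone.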

tests/constants.py

Lines changed: 9 additions & 1 deletion
@@ -9,6 +9,7 @@
     AdaBound,
     Adai,
     AdamP,
+    AdamS,
     Adan,
     AdaPNM,
     DAdaptAdaGrad,
@@ -55,6 +56,7 @@
     'pnm',
     'dadaptadam',
     'dadaptsgd',
+    'adams',
 ]
 VALID_OPTIMIZER_NAMES: List[str] = [
     'adamp',
@@ -79,6 +81,7 @@
     'dadaptadagrad',
     'dadaptadam',
     'dadaptsgd',
+    'adams',
 ]
 INVALID_OPTIMIZER_NAMES: List[str] = [
     'asam',
@@ -105,6 +108,7 @@
     'adai',
     'shampoo',
     'dadaptadam',
+    'adams',
 ]

 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -135,6 +139,8 @@
     (Adai, {'lr': 1e-1, 'weight_decay': 0.0, 'dampening': 0.9}, 150),
     (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False}, 150),
     (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True}, 150),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False, 'use_stable_weight_decay': True}, 150),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True, 'use_stable_weight_decay': True}, 150),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 10),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'nesterov': True}, 10),
@@ -188,12 +194,13 @@
     (DAdaptAdam, {'lr': 1.0, 'weight_decay': 1e-2, 'weight_decouple': True}, 50),
     (DAdaptSGD, {'lr': 1.0, 'weight_decay': 1e-2}, 50),
     (DAdaptSGD, {'lr': 1.0, 'momentum': 0.9, 'weight_decay': 1e-3}, 50),
+    (AdamS, {'lr': 1.0, 'weight_decay': 1e-3}, 50),
+    (AdamS, {'lr': 1.0, 'weight_decay': 1e-3, 'amsgrad': True}, 50),
 ]
 ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 50),
     (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),
-    (AdaBound, {'lr': 1e-2, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True, 'adamd_debias_term': True}, 100),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),
     (DiffGrad, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),
     (DiffRGrad, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),
@@ -202,4 +209,5 @@
     (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),
     (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True, 'num_iterations': 200}, 200),
     (AdaPNM, {'lr': 3e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 50),
+    (AdamS, {'lr': 2e1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 50),
 ]
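The OPTIMIZERS entries above exercise AdamS with and without amsgrad, and the ADAMD_SUPPORTED_OPTIMIZERS entry covers adamd_debias_term. Outside the test harness the same configurations reduce to constructing the optimizer directly; a minimal training-loop sketch (the model, data, and learning rate below are illustrative, not taken from the tests):

import torch

from pytorch_optimizer import AdamS

model = torch.nn.Linear(4, 2)

# roughly mirrors (AdamS, {'lr': 1.0, 'weight_decay': 1e-3, 'amsgrad': True}, 50) above,
# with a smaller learning rate than the toy test uses
optimizer = AdamS(model.parameters(), lr=1e-3, weight_decay=1e-3, amsgrad=True)

for _ in range(50):
    optimizer.zero_grad()
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()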

tests/test_load_optimizers.py

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 22
+    assert len(get_supported_optimizers()) == 23

tests/test_optimizers.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def test_closure(optimizer):
     optimizer = optimizer([param], num_iterations=1) if optimizer_name == 'Ranger21' else optimizer([param])
     optimizer.zero_grad()

-    if optimizer_name in ('Ranger21', 'Adai'):
+    if optimizer_name in ('Ranger21', 'Adai', 'AdamS'):
         with pytest.raises(ZeroParameterSizeError):
             optimizer.step(closure=dummy_closure)
     else:
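AdamS joins Ranger21 and Adai in this branch because its step() raises ZeroParameterSizeError when no parameter has a gradient: every parameter is skipped in the first pass, param_size stays 0, and the guard fires before any update. A standalone sketch of that behaviour, independent of the test fixtures above (the closure below is illustrative and deliberately never calls backward()):

import pytest
import torch

from pytorch_optimizer import AdamS
from pytorch_optimizer.base.exception import ZeroParameterSizeError

param = torch.nn.Parameter(torch.zeros(1))
optimizer = AdamS([param])

# the closure returns a loss but never calls backward(), so param.grad stays None
with pytest.raises(ZeroParameterSizeError):
    optimizer.step(closure=lambda: torch.tensor(0.0))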
