Commit 8c7fd8e

Merge pull request #133 from kozistr/feature/adamnorm-optimizer
[Feature] Implement AdaNorm optimizer
2 parents be0351d + b14aa03 commit 8c7fd8e

34 files changed (+984 / -785 lines)

README.rst
Lines changed: 4 additions & 0 deletions

@@ -140,6 +140,8 @@ You can check the supported optimizers & lr schedulers.
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | SM3 | *Memory-Efficient Adaptive Optimization* | `github <https://github.com/google-research/google-research/tree/master/sm3>`__ | `https://arxiv.org/abs/1901.11150 <https://arxiv.org/abs/1901.11150>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| AdaNorm | *Adaptive Gradient Norm Correction based Optimizer for CNNs* | `github <https://github.com/shivram1987/AdaNorm>`__ | `https://arxiv.org/abs/2210.06364 <https://arxiv.org/abs/2210.06364>`__ |
++--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources
 ----------------
@@ -347,6 +349,8 @@ Citations
 
 `SM3 <https://ui.adsabs.harvard.edu/abs/2019arXiv190111150A/exportcitation>`__
 
+`AdaNorm <https://github.com/shivram1987/AdaNorm/tree/main#citation>`__
+
 Citation
 --------
 
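The new table row and citation entry above point at the AdaNorm optimizer introduced by this PR. Below is a minimal, hedged usage sketch; the constructor arguments are assumed to follow the package's usual Adam-style interface (the actual defaults live in pytorch_optimizer/optimizer/adanorm.py, which is not part of this commit excerpt):

```python
# Minimal sketch of using the newly listed AdaNorm optimizer.
# The lr value and the Adam-style signature are assumptions; check
# pytorch_optimizer/optimizer/adanorm.py for the real constructor.
import torch
from pytorch_optimizer import AdaNorm

model = torch.nn.Linear(10, 2)
optimizer = AdaNorm(model.parameters(), lr=1e-3)

criterion = torch.nn.CrossEntropyLoss()
x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
```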

docs/optimizer_api.rst
Lines changed: 8 additions & 16 deletions

@@ -65,14 +65,6 @@ diffGrad
 .. autoclass:: pytorch_optimizer.DiffGrad
     :members:
 
-.. _diffRGrad:
-
-diffRGrad
----------
-
-.. autoclass:: pytorch_optimizer.DiffRGrad
-    :members:
-
 .. _GC:
 
 GC
@@ -145,14 +137,6 @@ RAdam
 .. autoclass:: pytorch_optimizer.RAdam
     :members:
 
-.. _RaLamb:
-
-RaLamb
-------
-
-.. autoclass:: pytorch_optimizer.RaLamb
-    :members:
-
 .. _Ranger:
 
 Ranger
@@ -288,3 +272,11 @@ SM3
 
 .. autoclass:: pytorch_optimizer.SM3
     :members:
+
+.. _AdaNorm:
+
+AdaNorm
+-------
+
+.. autoclass:: pytorch_optimizer.AdaNorm
+    :members:

docs/util_api.rst
Lines changed: 8 additions & 1 deletion

@@ -169,4 +169,11 @@ merge_small_dims
 
 .. autoclass:: pytorch_optimizer.merge_small_dims
     :members:
-re
+
+.. _reduce_max_except_dim:
+
+reduce_max_except_dim
+---------------------
+
+.. autoclass:: pytorch_optimizer.reduce_max_except_dim
+    :members:
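For context on the newly documented helper: judging by its use in SM3-style optimizers, reduce_max_except_dim appears to take the maximum over every dimension except one, keeping dimensions so the result broadcasts back against the input. The snippet below is an illustrative re-implementation of that assumed behaviour, deliberately named differently; it is not the library code itself:

```python
# Illustrative re-implementation of the behaviour the reduce_max_except_dim
# entry suggests (max over all dims except `dim`, with keepdim=True).
# The real helper lives in pytorch_optimizer.optimizer.utils.
import torch


def max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    result = x
    for d in range(x.dim()):
        if d != dim:
            result = result.max(dim=d, keepdim=True).values
    return result


x = torch.arange(24.0).reshape(2, 3, 4)
print(max_except_dim(x, dim=1).shape)  # torch.Size([1, 3, 1])
```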

pytorch_optimizer/__init__.py
Lines changed: 3 additions & 4 deletions

@@ -23,13 +23,13 @@
 from pytorch_optimizer.optimizer.adamp import AdamP
 from pytorch_optimizer.optimizer.adams import AdamS
 from pytorch_optimizer.optimizer.adan import Adan
+from pytorch_optimizer.optimizer.adanorm import AdaNorm
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
 from pytorch_optimizer.optimizer.agc import agc
 from pytorch_optimizer.optimizer.alig import AliG
 from pytorch_optimizer.optimizer.apollo import Apollo
 from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptSGD
 from pytorch_optimizer.optimizer.diffgrad import DiffGrad
-from pytorch_optimizer.optimizer.diffrgrad import DiffRGrad
 from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
 from pytorch_optimizer.optimizer.gc import centralize_gradient
 from pytorch_optimizer.optimizer.gsam import GSAM
@@ -43,7 +43,6 @@
 from pytorch_optimizer.optimizer.pcgrad import PCGrad
 from pytorch_optimizer.optimizer.pnm import PNM
 from pytorch_optimizer.optimizer.radam import RAdam
-from pytorch_optimizer.optimizer.ralamb import RaLamb
 from pytorch_optimizer.optimizer.ranger import Ranger
 from pytorch_optimizer.optimizer.ranger21 import Ranger21
 from pytorch_optimizer.optimizer.sam import SAM
@@ -71,6 +70,7 @@
     enable_running_stats,
     get_optimizer_parameters,
     normalize_gradient,
+    reduce_max_except_dim,
     unit_norm,
 )
 
@@ -82,14 +82,12 @@
     Adan,
     AdaPNM,
     DiffGrad,
-    DiffRGrad,
     Lamb,
     LARS,
     MADGRAD,
     Nero,
     PNM,
     RAdam,
-    RaLamb,
     Ranger,
     Ranger21,
     SGDP,
@@ -105,6 +103,7 @@
     Lion,
     AliG,
     SM3,
+    AdaNorm,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
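Because AdaNorm now sits in OPTIMIZER_LIST, the OPTIMIZERS registry built on the last line of this diff exposes it under its lowercased class name. A small lookup sketch (the lr value is illustrative):

```python
# Resolve the new optimizer through the registry defined in __init__.py.
import torch
from pytorch_optimizer import OPTIMIZERS

model = torch.nn.Linear(4, 1)

opt_cls = OPTIMIZERS['adanorm']  # key comes from str(AdaNorm.__name__).lower()
optimizer = opt_cls(model.parameters(), lr=1e-3)  # lr is an illustrative value
print(optimizer)
```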

pytorch_optimizer/optimizer/adabelief.py
Lines changed: 69 additions & 53 deletions

@@ -21,7 +21,9 @@ class AdaBelief(Optimizer, BaseOptimizer):
     :param rectify: bool. perform the rectified update similar to RAdam.
     :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high.
     :param amsgrad: bool. whether to use the AMSBound variant.
-    :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training.
+    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
+    :param adanorm: bool. whether to use the AdaNorm variant.
+    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
 
@@ -37,31 +39,35 @@ def __init__(
         rectify: bool = True,
         degenerated_to_sgd: bool = True,
         amsgrad: bool = False,
-        adamd_debias_term: bool = False,
+        r: float = 0.95,
+        adanorm: bool = False,
+        adam_debias: bool = False,
         eps: float = 1e-16,
     ):
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
         self.n_sma_threshold = n_sma_threshold
-        self.weight_decouple = weight_decouple
-        self.fixed_decay = fixed_decay
-        self.rectify = rectify
         self.degenerated_to_sgd = degenerated_to_sgd
-        self.adamd_debias_term = adamd_debias_term
         self.eps = eps
 
         self.validate_parameters()
 
         defaults: DEFAULTS = {
             'lr': lr,
             'betas': betas,
-            'eps': eps,
             'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+            'fixed_decay': fixed_decay,
+            'rectify': rectify,
             'amsgrad': amsgrad,
-            'adamd_debias_term': adamd_debias_term,
-            'buffer': [[None, None, None] for _ in range(10)],
+            'adanorm': adanorm,
+            'adam_debias': adam_debias,
+            'eps': eps,
         }
+        if adanorm:
+            defaults.update({'r': r})
+
         super().__init__(params, defaults)
 
     def validate_parameters(self):
@@ -76,12 +82,14 @@ def __str__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 state = self.state[p]
 
-                state['step'] = 0
                 state['exp_avg'] = torch.zeros_like(p)
                 state['exp_avg_var'] = torch.zeros_like(p)
+                if group['adanorm']:
+                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                 if group['amsgrad']:
                     state['max_exp_avg_var'] = torch.zeros_like(p)
 
@@ -93,11 +101,21 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
             beta1, beta2 = group['betas']
-            weight_decay: float = group['weight_decay']
+            weight_decay = group['weight_decay']
+
+            bias_correction1 = 1.0 - beta1 ** group['step']
+            bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])
 
-            if self.rectify:
+            if group['rectify']:
                 n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
+                beta2_t: float = beta2 ** group['step']
+                n_sma: float = n_sma_max - 2 * group['step'] * beta2_t / (1.0 - beta2_t)
 
             for p in group['params']:
                 if p.grad is None:
@@ -109,70 +127,68 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p)
                     state['exp_avg_var'] = torch.zeros_like(p)
+                    if group['adanorm']:
+                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                     if group['amsgrad']:
                         state['max_exp_avg_var'] = torch.zeros_like(p)
 
-                if self.weight_decouple:
-                    p.mul_(1.0 - (group['lr'] * weight_decay if not self.fixed_decay else weight_decay))
+                if group['weight_decouple']:
+                    p.mul_(1.0 - group['weight_decay'] * (1.0 if group['fixed_decay'] else group['lr']))
                 elif weight_decay > 0.0:
                     grad.add_(p, alpha=weight_decay)
 
-                state['step'] += 1
                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
 
-                bias_correction1 = 1.0 - beta1 ** state['step']
-                bias_correction2_sq = math.sqrt(1.0 - beta2 ** state['step'])
+                s_grad = grad
+                if group['adanorm']:
+                    grad_norm = torch.linalg.norm(grad)
 
-                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
-                grad_residual = grad - exp_avg
-                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(group['eps'])
+                    exp_grad_norm = state['exp_grad_norm']
+                    exp_grad_norm.mul_(group['r']).add_(grad_norm, alpha=1.0 - group['r'])
+
+                    if exp_grad_norm > grad_norm:
+                        s_grad *= exp_grad_norm / grad_norm
+
+                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
+                grad_residual = s_grad - exp_avg
+                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(self.eps)
 
                 if group['amsgrad']:
                     max_exp_avg_var = state['max_exp_avg_var']
                     torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)
-                    de_nom = max_exp_avg_var.sqrt()
+                    de_nom = max_exp_avg_var.add(self.eps).sqrt()
                 else:
-                    de_nom = exp_avg_var.sqrt()
-                de_nom.div_(bias_correction2_sq).add_(group['eps'])
+                    de_nom = exp_avg_var.add(self.eps).sqrt()
+
+                de_nom.div_(bias_correction2_sq).add_(self.eps)
 
-                if not self.rectify:
-                    step_size: float = group['lr'] if group['adamd_debias_term'] else group['lr'] / bias_correction1
+                if not group['rectify']:
+                    step_size: float = group['lr'] if group['adam_debias'] else group['lr'] / bias_correction1
                     p.addcdiv_(exp_avg, de_nom, value=-step_size)
                     continue
 
-                buffered = group['buffer'][state['step'] % 10]
-                if state['step'] == buffered[0]:
-                    n_sma, step_size = buffered[1], buffered[2]
+                if n_sma >= self.n_sma_threshold:
+                    step_size = math.sqrt(
+                        (1 - beta2_t)
+                        * (n_sma - 4)
+                        / (n_sma_max - 4)
+                        * (n_sma - 2)
+                        / n_sma
+                        * n_sma_max
+                        / (n_sma_max - 2)
+                    )
+                elif self.degenerated_to_sgd:
+                    step_size = 1.0
                 else:
-                    buffered[0] = state['step']
-                    beta2_t = beta2 ** state['step']
-                    n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
-                    buffered[1] = n_sma
-
-                    if n_sma >= self.n_sma_threshold:
-                        step_size = math.sqrt(
-                            (1 - beta2_t)
-                            * (n_sma - 4)
-                            / (n_sma_max - 4)
-                            * (n_sma - 2)
-                            / n_sma
-                            * n_sma_max
-                            / (n_sma_max - 2)
-                        )
-                        if not group['adamd_debias_term']:
-                            step_size /= bias_correction1
-                    elif self.degenerated_to_sgd:
-                        step_size = 1.0 / bias_correction1
-                    else:
-                        step_size = -1
-
-                    buffered[2] = step_size
+                    step_size = -1
+
+                if not group['adam_debias']:
+                    step_size /= bias_correction1
 
                 if n_sma >= self.n_sma_threshold:
-                    de_nom = exp_avg_var.sqrt().add_(group['eps'])
+                    de_nom = exp_avg_var.sqrt().add_(self.eps)
                     p.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
                 elif step_size > 0:
                     p.add_(exp_avg, alpha=-step_size * group['lr'])
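Two things change in AdaBelief here: the per-parameter step counter and the buffered RAdam lookup table move to per-group bookkeeping, and an optional AdaNorm path rescales unusually small gradients by an exponential moving average of the gradient norm before the first and second moments are updated. The standalone sketch below isolates that AdaNorm correction, using the same variable names as the diff; it is an illustration of the hunk's logic, not the packaged implementation:

```python
# Sketch of the AdaNorm gradient-norm correction wired into AdaBelief above:
# track an EMA of the gradient norm and, when the current gradient is smaller
# than the EMA, scale it up before it feeds the moment estimates.
import torch


def adanorm_correct(grad: torch.Tensor, exp_grad_norm: torch.Tensor, r: float = 0.95) -> torch.Tensor:
    grad_norm = torch.linalg.norm(grad)
    exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)  # in-place EMA, mirroring state['exp_grad_norm']
    if exp_grad_norm > grad_norm:
        return grad * (exp_grad_norm / grad_norm)  # amplify an unusually small gradient
    return grad  # otherwise pass the gradient through unchanged


# The EMA tensor persists across steps, like state['exp_grad_norm'] in the diff.
exp_grad_norm = torch.zeros(1)
s_grad = adanorm_correct(torch.randn(3, 3), exp_grad_norm)
```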
