Commit f51dead

Merge pull request #77 from kozistr/feature/adan-optimizer

[Feature, Fix] Adan optimizer

2 parents 34fd10b + 504f13c

File tree: 3 files changed, +35 −16 lines

pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "1.3.1"
+version = "1.3.2"
 description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/adan.py
Lines changed: 30 additions & 13 deletions

@@ -1,3 +1,5 @@
+import math
+
 import torch
 from torch.optim.optimizer import Optimizer
 

@@ -8,7 +10,7 @@
 
 class Adan(Optimizer, BaseOptimizer):
     """
-    Reference : x
+    Reference : https://github.com/sail-sg/Adan/blob/main/adan.py
     Example :
         from pytorch_optimizer import Adan
         ...

@@ -27,21 +29,24 @@ def __init__(
         params: PARAMETERS,
         lr: float = 1e-3,
         betas: BETAS = (0.98, 0.92, 0.99),
-        weight_decay: float = 0.02,
+        weight_decay: float = 0.0,
+        weight_decouple: bool = False,
         use_gc: bool = False,
-        eps: float = 1e-16,
+        eps: float = 1e-8,
     ):
         """Adan optimizer
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
+        :param weight_decouple: bool. decoupled weight decay
         :param use_gc: bool. use gradient centralization
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
+        self.weight_decouple = weight_decouple
         self.use_gc = use_gc
         self.eps = eps
 

@@ -52,6 +57,7 @@ def __init__(
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            weight_decouple=weight_decouple,
         )
         super().__init__(params, defaults)
 

@@ -69,7 +75,7 @@ def reset(self):
 
                 state['step'] = 0
                 state['exp_avg'] = torch.zeros_like(p)
-                state['exp_avg_var'] = torch.zeros_like(p)
+                state['exp_avg_diff'] = torch.zeros_like(p)
                 state['exp_avg_nest'] = torch.zeros_like(p)
                 state['previous_grad'] = torch.zeros_like(p)
 

@@ -93,29 +99,40 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if len(state) == 0:
                     state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p)
-                    state['exp_avg_var'] = torch.zeros_like(p)
+                    state['exp_avg_diff'] = torch.zeros_like(p)
                     state['exp_avg_nest'] = torch.zeros_like(p)
                     state['previous_grad'] = torch.zeros_like(p)
 
-                exp_avg, exp_avg_var, exp_avg_nest = state['exp_avg'], state['exp_avg_var'], state['exp_avg_nest']
+                exp_avg, exp_avg_diff, exp_avg_nest = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_nest']
                 prev_grad = state['previous_grad']
 
                 state['step'] += 1
                 beta1, beta2, beta3 = group['betas']
 
+                bias_correction1 = 1.0 - beta1 ** state['step']
+                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction3 = 1.0 - beta3 ** state['step']
+
                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)
 
                 grad_diff = grad - prev_grad
                 state['previous_grad'] = grad.clone()
 
-                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
-                exp_avg_var.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
-                exp_avg_nest.mul_(beta3).add_((grad + beta2 * grad_diff) ** 2, alpha=1.0 - beta3)
+                update = grad + beta2 * grad_diff
 
-                step_size = group['lr'] / exp_avg_nest.add_(self.eps).sqrt_()
-
-                p.sub_(step_size * (exp_avg + beta2 * exp_avg_var))
-                p.div_(1.0 + group['weight_decay'])
+                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
+                exp_avg_nest.mul_(beta3).addcmul_(update, update, value=1.0 - beta3)
+
+                de_nom = (exp_avg_nest.sqrt_() / math.sqrt(bias_correction3)).add_(self.eps)
+                perturb = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(de_nom)
+
+                if group['weight_decouple']:
+                    p.mul_(1.0 - group['lr'] * group['weight_decay'])
+                    p.add_(perturb, alpha=-group['lr'])
+                else:
+                    p.add_(perturb, alpha=-group['lr'])
+                    p.div_(1.0 + group['lr'] * group['weight_decay'])
 
         return loss
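
To make the new step() logic easier to follow outside the diff, here is a minimal, self-contained sketch of the update it implements: bias-corrected moving averages of the gradient and of the gradient difference, a Nesterov-style second moment built from grad + beta2 * grad_diff, and the two weight-decay branches selected by weight_decouple. The helper name adan_single_tensor_step and the toy driver are illustrative only, not part of the library; the real optimizer also handles closures, gradient centralization, and parameter groups.

import math

import torch


def adan_single_tensor_step(p, grad, state, lr=1e-3, betas=(0.98, 0.92, 0.99),
                            weight_decay=0.0, weight_decouple=False, eps=1e-8):
    """Illustrative single-tensor version of the Adan update in this commit."""
    beta1, beta2, beta3 = betas
    state['step'] += 1

    bias_correction1 = 1.0 - beta1 ** state['step']
    bias_correction2 = 1.0 - beta2 ** state['step']
    bias_correction3 = 1.0 - beta3 ** state['step']

    grad_diff = grad - state['previous_grad']
    state['previous_grad'] = grad.clone()
    update = grad + beta2 * grad_diff

    # moving averages: gradient, gradient difference, Nesterov-style second moment
    state['exp_avg'].mul_(beta1).add_(grad, alpha=1.0 - beta1)
    state['exp_avg_diff'].mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
    state['exp_avg_nest'].mul_(beta3).addcmul_(update, update, value=1.0 - beta3)

    # the library code calls sqrt_() in place on exp_avg_nest; a plain sqrt() is
    # used here so the accumulator is left untouched between calls
    de_nom = (state['exp_avg_nest'].sqrt() / math.sqrt(bias_correction3)).add_(eps)
    perturb = (state['exp_avg'] / bias_correction1
               + beta2 * state['exp_avg_diff'] / bias_correction2).div_(de_nom)

    if weight_decouple:
        # AdamW-style decoupled decay: shrink the weights, then apply the update
        p.mul_(1.0 - lr * weight_decay)
        p.add_(perturb, alpha=-lr)
    else:
        # coupled variant: apply the update, then rescale by 1 / (1 + lr * weight_decay)
        p.add_(perturb, alpha=-lr)
        p.div_(1.0 + lr * weight_decay)


# toy driver: one parameter tensor and the state the optimizer would normally hold
p = torch.zeros(3)
state = {
    'step': 0,
    'exp_avg': torch.zeros_like(p),
    'exp_avg_diff': torch.zeros_like(p),
    'exp_avg_nest': torch.zeros_like(p),
    'previous_grad': torch.zeros_like(p),
}
adan_single_tensor_step(p, torch.ones_like(p), state)

The decoupled branch follows the AdamW convention of decaying the weights independently of the adaptive update, while the coupled branch matches the new else path in the diff, rescaling the updated weights by 1 / (1 + lr * weight_decay).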

tests/test_optimizers.py
Lines changed: 4 additions & 2 deletions

@@ -80,8 +80,9 @@
     (AdaPNM, {'lr': 3e-1, 'weight_decay': 1e-3, 'amsgrad': False}, 500),
     (Nero, {'lr': 5e-1}, 200),
     (Nero, {'lr': 5e-1, 'constraints': False}, 200),
-    (Adan, {'lr': 2e-1}, 200),
-    (Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True}, 500),
+    (Adan, {'lr': 5e-1}, 300),
+    (Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True}, 300),
+    (Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True, 'weight_decouple': True}, 300),
 ]
 
 ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [

@@ -163,6 +164,7 @@ def test_safe_f16_optimizers(optimizer_fp16_config):
         or (optimizer_name == 'RaLamb' and 'pre_norm' in config)
         or (optimizer_name == 'PNM')
         or (optimizer_name == 'Nero')
+        or (optimizer_name == 'Adan' and 'weight_decay' not in config)
     ):
         return True
 
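
Each test entry above is a tuple of optimizer class, constructor kwargs, and what appears to be an iteration budget. As a rough usage sketch only, the new decoupled-weight-decay configuration could be exercised on a throwaway model as below; the model, data, and training loop are illustrative, and the very large learning rate copied from the test config is not a sensible default for real training.

import torch
from torch import nn
from pytorch_optimizer import Adan

# throwaway regression problem, purely for illustration
model = nn.Linear(2, 1)
x, y = torch.randn(64, 2), torch.randn(64, 1)

# mirrors (Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True, 'weight_decouple': True}, 300)
optimizer = Adan(model.parameters(), lr=1e-0, weight_decay=1e-3, use_gc=True, weight_decouple=True)

for _ in range(300):  # assuming the third tuple element is the number of iterations
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()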
