
Commit 47f7261

Merge pull request #69 from kozistr/feature/adan
[Feature] Implement Adan optimizer
2 parents f24b3fe + 740945b commit 47f7261

8 files changed: +158 −8 lines

README.rst

Lines changed: 22 additions & 0 deletions
@@ -87,6 +87,8 @@ Supported Optimizers
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Nero | *Learning by Turning: Neural Architecture Aware Optimisation* | `github <https://github.com/jxbz/nero>`__ | `https://arxiv.org/abs/2102.07227 <https://arxiv.org/abs/2102.07227>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| Adan | *Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models* | `github <https://github.com/sail-sg/Adan>`__ | `https://arxiv.org/abs/2208.06677 <https://arxiv.org/abs/2208.06677>`__ |
++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources
 ----------------
@@ -504,6 +506,26 @@ Nero: Learning by Turning: Neural Architecture Aware Optimisation
       eprint={arXiv:2102.07227}
     }
 
+Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models
+
+::
+
+    @ARTICLE{2022arXiv220806677X,
+        author = {{Xie}, Xingyu and {Zhou}, Pan and {Li}, Huan and {Lin}, Zhouchen and {Yan}, Shuicheng},
+        title = "{Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models}",
+        journal = {arXiv e-prints},
+        keywords = {Computer Science - Machine Learning, Mathematics - Optimization and Control},
+        year = 2022,
+        month = aug,
+        eid = {arXiv:2208.06677},
+        pages = {arXiv:2208.06677},
+        archivePrefix = {arXiv},
+        eprint = {2208.06677},
+        primaryClass = {cs.LG},
+        adsurl = {https://ui.adsabs.harvard.edu/abs/2022arXiv220806677X},
+        adsnote = {Provided by the SAO/NASA Astrophysics Data System}
+    }
+
 Author
 ------
 
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "1.1.4"
+version = "1.2.0"
 description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -1,11 +1,10 @@
 # pylint: disable=unused-import
 from typing import Callable, Dict, List
 
-from torch.optim import Optimizer
-
 from pytorch_optimizer.adabelief import AdaBelief
 from pytorch_optimizer.adabound import AdaBound
 from pytorch_optimizer.adamp import AdamP
+from pytorch_optimizer.adan import Adan
 from pytorch_optimizer.adapnm import AdaPNM
 from pytorch_optimizer.agc import agc
 from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
@@ -39,6 +38,7 @@
     AdaBelief,
     AdaBound,
     AdamP,
+    Adan,
     AdaPNM,
     DiffGrad,
     DiffRGrad,
@@ -54,7 +54,7 @@
     SGDP,
     Shampoo,
 ]
-OPTIMIZERS: Dict[str, Optimizer] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
+OPTIMIZERS: Dict[str, Callable] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
 
 
 def load_optimizer(optimizer: str) -> Callable:
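With this change, OPTIMIZERS maps each class's lowercased name to the class itself, so the new optimizer can also be resolved by name through load_optimizer. A minimal sketch, assuming only what the diff above shows (the linear model and learning rate are illustrative):

import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(4, 1)

# 'adan' is looked up in the OPTIMIZERS mapping and returns the Adan class
adan_class = load_optimizer('adan')
optimizer = adan_class(model.parameters(), lr=1e-3)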

pytorch_optimizer/adan.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base_optimizer import BaseOptimizer
from pytorch_optimizer.gc import centralize_gradient
from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Adan(Optimizer, BaseOptimizer):
    """
    Reference : https://github.com/sail-sg/Adan
    Example :
        from pytorch_optimizer import Adan
        ...
        model = YourModel()
        optimizer = Adan(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.98, 0.92, 0.99),
        weight_decay: float = 0.02,
        use_gc: bool = False,
        eps: float = 1e-16,
    ):
        """Adan optimizer
        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of the gradient,
            the gradient difference, and the squared gradient
        :param weight_decay: float. weight decay (L2 penalty)
        :param use_gc: bool. use gradient centralization
        :param eps: float. term added to the denominator to improve numerical stability
        """
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.use_gc = use_gc
        self.eps = eps

        self.validate_parameters()

        defaults: DEFAULTS = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_betas(self.betas)
        self.validate_weight_decay(self.weight_decay)
        self.validate_epsilon(self.eps)

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_var'] = torch.zeros_like(p)
                state['exp_avg_nest'] = torch.zeros_like(p)
                state['previous_grad'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Adan does not support sparse gradients')

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_var'] = torch.zeros_like(p)
                    state['exp_avg_nest'] = torch.zeros_like(p)
                    state['previous_grad'] = torch.zeros_like(p)

                exp_avg, exp_avg_var, exp_avg_nest = state['exp_avg'], state['exp_avg_var'], state['exp_avg_nest']
                prev_grad = state['previous_grad']

                state['step'] += 1
                beta1, beta2, beta3 = group['betas']

                if self.use_gc:
                    grad = centralize_gradient(grad, gc_conv_only=False)

                grad_diff = grad - prev_grad
                state['previous_grad'] = grad.clone()

                # update the three moving averages: the gradient, the gradient difference,
                # and the squared combination of gradient and gradient difference
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_var.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
                exp_avg_nest.mul_(beta3).add_((grad + beta2 * grad_diff) ** 2, alpha=1.0 - beta3)

                step_size = group['lr'] / exp_avg_nest.add_(self.eps).sqrt_()

                # parameter update followed by shrinkage with the weight-decay factor
                p.sub_(step_size * (exp_avg + beta2 * exp_avg_var))
                p.div_(1.0 + group['weight_decay'])

        return loss
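For a concrete view of what step() maintains, a short runnable sketch; the toy objective and hyper-parameters below are illustrative, not part of the commit:

import torch

from pytorch_optimizer import Adan

# toy objective: pull a 10-dimensional parameter toward 3.0
w = torch.nn.Parameter(torch.zeros(10))
optimizer = Adan([w], lr=1e-1)

for _ in range(5):
    optimizer.zero_grad()
    loss = ((w - 3.0) ** 2).sum()
    loss.backward()
    optimizer.step()

# each parameter's state holds the three moving averages plus the previous gradient
print(sorted(optimizer.state[w].keys()))
# ['exp_avg', 'exp_avg_nest', 'exp_avg_var', 'previous_grad', 'step']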

tests/test_load_optimizers.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 
 VALID_OPTIMIZER_NAMES: List[str] = [
     'adamp',
+    'adan',
     'sgdp',
     'madgrad',
     'ranger',
@@ -46,4 +47,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 17
+    assert len(get_supported_optimizers()) == 18

tests/test_optimizer_parameters.py

Lines changed: 5 additions & 3 deletions
@@ -8,6 +8,7 @@
 
 OPTIMIZER_NAMES: List[str] = [
     'adamp',
+    'adan',
     'sgdp',
     'madgrad',
     'ranger',
@@ -37,6 +38,7 @@
     'ralamb',
     'pnm',
     'adapnm',
+    'adan',
 ]
 
 
@@ -122,16 +124,16 @@ def test_betas(optimizer_name):
     with pytest.raises(ValueError):
         if optimizer_name == 'ranger21':
             optimizer(None, num_iterations=100, betas=(-0.1, 0.1))
-        else:
+        elif optimizer not in ('adapnm', 'adan'):
            optimizer(None, betas=(-0.1, 0.1))
 
     with pytest.raises(ValueError):
         if optimizer_name == 'ranger21':
             optimizer(None, num_iterations=100, betas=(0.1, -0.1))
-        else:
+        elif optimizer not in ('adapnm', 'adan'):
            optimizer(None, betas=(0.1, -0.1))
 
-    if optimizer_name == 'adapnm':
+    if optimizer_name in ('adapnm', 'adan'):
         with pytest.raises(ValueError):
             optimizer(None, betas=(0.1, 0.1, -0.1))
 
tests/test_optimizers.py

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,7 @@
     AdaBelief,
     AdaBound,
     AdamP,
+    Adan,
     AdaPNM,
     DiffGrad,
     DiffRGrad,
@@ -79,6 +80,8 @@
     (AdaPNM, {'lr': 3e-1, 'weight_decay': 1e-3, 'amsgrad': False}, 500),
     (Nero, {'lr': 5e-1}, 200),
     (Nero, {'lr': 5e-1, 'constraints': False}, 200),
+    (Adan, {'lr': 2e-1}, 200),
+    (Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True}, 500),
 ]
 
 ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
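The harness that consumes these (optimizer class, config, iterations) tuples is not shown in this diff; below is a hedged sketch of how such an entry could drive a convergence check, using a hypothetical helper and a toy objective of its own:

import torch

from pytorch_optimizer import Adan


def run_convergence_check(optimizer_class, config, iterations):
    # hypothetical helper, not from the repository's test suite:
    # minimize a small quadratic and return the final loss
    w = torch.nn.Parameter(torch.ones(8))
    optimizer = optimizer_class([w], **config)
    for _ in range(iterations):
        optimizer.zero_grad()
        loss = (w ** 2).sum()
        loss.backward()
        optimizer.step()
    return loss.item()


final_loss = run_convergence_check(Adan, {'lr': 2e-1}, 200)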

tests/test_sparse_gradient.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
     'lars',
     'shampoo',
     'nero',
+    'adan',
 ]
 
 