Commit 07dd145

Merge pull request #89 from kozistr/feature/adai-optimizer
[Feature] Implement `Adai` optimizer
2 parents 1839482 + 0572462 commit 07dd145

10 files changed (+210 -22 lines)

README.rst

Lines changed: 4 additions & 0 deletions

@@ -110,6 +110,8 @@ Supported Optimizers
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Adan | *Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models* | `github <https://github.com/sail-sg/Adan>`__ | `https://arxiv.org/abs/2208.06677 <https://arxiv.org/abs/2208.06677>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| Adai | *Disentangling the Effects of Adaptive Learning Rate and Momentum* | `github <https://github.com/zeke-xie/adaptive-inertia-adai>`__ | `https://arxiv.org/abs/2006.15815 <https://arxiv.org/abs/2006.15815>`__ |
++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+

 Useful Resources
 ----------------

@@ -299,6 +301,8 @@ Citations

 `Adan <https://ui.adsabs.harvard.edu/abs/2022arXiv220806677X/exportcitation>`__

+`Adai <https://github.com/zeke-xie/adaptive-inertia-adai#citing>`__
+
 Author
 ------

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.0.1"
+version = "2.1.0"
 description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,7 @@
 from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts
 from pytorch_optimizer.optimizer.adabelief import AdaBelief
 from pytorch_optimizer.optimizer.adabound import AdaBound
+from pytorch_optimizer.optimizer.adai import Adai
 from pytorch_optimizer.optimizer.adamp import AdamP
 from pytorch_optimizer.optimizer.adan import Adan
 from pytorch_optimizer.optimizer.adapnm import AdaPNM

@@ -40,6 +41,7 @@
     AdaBelief,
     AdaBound,
     AdamP,
+    Adai,
     Adan,
     AdaPNM,
     DiffGrad,

pytorch_optimizer/optimizer/adai.py

Lines changed: 160 additions & 0 deletions (new file)

@@ -0,0 +1,160 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.base_optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Adai(Optimizer, BaseOptimizer):
    """
    Reference : https://github.com/zeke-xie/adaptive-inertia-adai
    Example :
        from pytorch_optimizer import Adai
        ...
        model = YourModel()
        optimizer = Adai(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.1, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        dampening: float = 1.0,
        eps: float = 1e-3,
    ):
        """Adai
        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
        :param weight_decay: float. weight decay (L2 penalty)
        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
        :param dampening: float. dampening for momentum. where dampening < 1,
            it will show some adaptive-moment behavior
        :param eps: float. term added to the denominator to improve numerical stability
        """
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.weight_decouple = weight_decouple
        self.dampening = dampening
        self.eps = eps

        self.validate_parameters()

        defaults: DEFAULTS = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            dampening=dampening,
            eps=eps,
        )
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_betas(self.betas)
        self.validate_weight_decay(self.weight_decay)
        self.validate_epsilon(self.eps)

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['beta1_prod'] = torch.ones_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        exp_avg_sq_hat_sum: float = 0.0

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Adai does not support sparse gradients')

                param_size += p.numel()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['beta1_prod'] = torch.ones_like(p)

                state['step'] += 1

                exp_avg_sq = state['exp_avg_sq']
                _, beta2 = group['betas']

                bias_correction2 = 1.0 - beta2 ** state['step']

                if group['weight_decay'] != 0:
                    if self.weight_decouple:
                        p.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        grad.add_(p, alpha=group['weight_decay'])

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                exp_avg_sq_hat_sum += exp_avg_sq.sum() / bias_correction2

        if param_size == 0:
            raise ValueError('[-] param_size is 0')

        exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1_prod = state['beta1_prod']
                beta0, beta2 = group['betas']

                bias_correction2 = 1 - beta2 ** state['step']

                exp_avg_sq_hat = exp_avg_sq / bias_correction2
                beta1 = (
                    1.0 - (exp_avg_sq_hat / exp_avg_sq_hat_mean).pow(1.0 / (3 - 2 * group['dampening'])).mul(beta0)
                ).clamp(0.0, 1 - group['eps'])
                beta3 = (1.0 - beta1).pow(group['dampening'])

                beta1_prod.mul_(beta1)
                bias_correction1 = 1.0 - beta1_prod

                exp_avg.mul_(beta1).addcmul_(beta3, grad)
                exp_avg_hat = exp_avg / bias_correction1 * math.pow(beta0, 1.0 - group['dampening'])

                p.add_(exp_avg_hat, alpha=-group['lr'])

        return loss
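Reading the implementation above (the notation below is mine, not from the commit), one step for a parameter theta with gradient g_t, step size eta, betas (beta_0, beta_2), dampening delta, and epsilon works out to:

\begin{aligned}
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^{2}, \qquad
\hat v_t = \frac{v_t}{1-\beta_2^{\,t}}, \qquad
\bar v_t = \operatorname{mean}\big(\hat v_t\big) \ \text{over all elements of all parameters} \\
\beta_{1,t} &= \operatorname{clip}\!\left(1-\beta_0\left(\hat v_t/\bar v_t\right)^{1/(3-2\delta)},\; 0,\; 1-\varepsilon\right) \\
m_t &= \beta_{1,t}\,m_{t-1} + (1-\beta_{1,t})^{\delta}\,g_t, \qquad
\hat m_t = \frac{m_t}{1-\prod_{k\le t}\beta_{1,k}}\;\beta_0^{\,1-\delta} \\
\theta_t &= \theta_{t-1} - \eta\,\hat m_t
\end{aligned}

Weight decay, when non-zero, is applied beforehand either to g_t (L2) or directly to theta (decoupled, AdamW-style), exactly as the first loop does; the global mean of the bias-corrected second moment is what makes beta1 adaptive per element.
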

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 1 addition & 1 deletion

@@ -236,7 +236,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

         # stable weight decay
         if param_size == 0:
-            raise ValueError('[-] size of parameter is 0')
+            raise ValueError('[-] param_size is 0')

         variance_normalized = math.sqrt(variance_ma_sum / param_size)
         if math.isnan(variance_normalized):

tests/test_load_optimizers.py

Lines changed: 2 additions & 2 deletions

@@ -23,8 +23,8 @@
     'pnm',
     'adapnm',
     'nero',
+    'adai',
 ]
-
 INVALID_OPTIMIZER_NAMES: List[str] = [
     'asam',
     'sam',

@@ -47,4 +47,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 18
+    assert len(get_supported_optimizers()) == 19
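For reference, the lowercase name added above is what the library's loader accepts; a minimal sketch, assuming load_optimizer takes the lowercase name and returns the optimizer class (which is what these tests exercise):

from torch import nn

from pytorch_optimizer import load_optimizer

optimizer_class = load_optimizer('adai')  # expected to resolve to the Adai class
optimizer = optimizer_class(nn.Linear(2, 1).parameters(), lr=1e-3)
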

tests/test_optimizer_parameters.py

Lines changed: 6 additions & 4 deletions

@@ -5,7 +5,7 @@
 from torch import nn
 from torch.nn import functional as F

-from pytorch_optimizer import SAM, AdamP, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer
+from pytorch_optimizer import SAM, Adai, AdamP, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer
 from tests.utils import Example

 OPTIMIZER_NAMES: List[str] = [

@@ -25,8 +25,8 @@
     'lars',
     'pnm',
     'adapnm',
+    'adai',
 ]
-
 BETA_OPTIMIZER_NAMES: List[str] = [
     'adabelief',
     'adabound',

@@ -41,6 +41,7 @@
     'pnm',
     'adapnm',
     'adan',
+    'adai',
 ]


@@ -213,12 +214,13 @@ def test_ranger21_warm_methods():
     assert Ranger21.build_warm_down_iterations(1000) == 280


-def test_ranger21_size_of_parameter():
+@pytest.mark.parametrize('optimizer', [Ranger21, Adai])
+def test_size_of_parameter(optimizer):
     model: nn.Module = nn.Linear(1, 1, bias=False)
     model.requires_grad_(False)

     with pytest.raises(ValueError):
-        Ranger21(model.parameters(), 100).step()
+        optimizer(model.parameters(), 100).step()


 def test_ranger21_closure():

tests/test_optimizers.py

Lines changed: 29 additions & 14 deletions

@@ -13,6 +13,7 @@
     SGDP,
     AdaBelief,
     AdaBound,
+    Adai,
     AdamP,
     Adan,
     AdaPNM,

@@ -36,6 +37,7 @@
     dummy_closure,
     ids,
     make_dataset,
+    names,
     tensor_to_numpy,
 )

@@ -50,6 +52,10 @@
     (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'fixed_decay': True}, 100),
     (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'weight_decouple': False}, 100),
     (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True}, 100),
+    (Adai, {'lr': 1e-1, 'weight_decay': 0.0}, 200),
+    (Adai, {'lr': 1e-1, 'weight_decay': 0.0, 'dampening': 0.9}, 200),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False}, 200),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True}, 200),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3}, 100),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 100),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'nesterov': True}, 100),

@@ -84,7 +90,6 @@
     (Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True}, 100),
     (Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True, 'weight_decouple': True}, 100),
 ]
-
 ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),

@@ -167,6 +172,7 @@ def test_safe_f16_optimizers(optimizer_fp16_config):
         or (optimizer_name == 'Nero')
         or (optimizer_name == 'Adan' and 'weight_decay' not in config)
         or (optimizer_name == 'RAdam')
+        or (optimizer_name == 'Adai')
     ):
         pytest.skip(f'skip {optimizer_name}')

@@ -195,8 +201,10 @@ def test_sam_optimizers(adaptive, optimizer_sam_config):
     (x_data, y_data), model, loss_fn = build_environment()

     optimizer_class, config, iterations = optimizer_sam_config
-    if optimizer_class.__name__ == 'Shampoo':
-        pytest.skip(f'skip {optimizer_class.__name__}')
+
+    optimizer_name: str = optimizer_class.__name__
+    if (optimizer_name == 'Shampoo') or (optimizer_name == 'Adai'):
+        pytest.skip(f'skip {optimizer_name}')

     optimizer = SAM(model.parameters(), optimizer_class, **config, adaptive=adaptive)

@@ -221,8 +229,10 @@ def test_sam_optimizers_with_closure(adaptive, optimizer_sam_config):
     (x_data, y_data), model, loss_fn = build_environment()

     optimizer_class, config, iterations = optimizer_sam_config
-    if optimizer_class.__name__ == 'Shampoo':
-        pytest.skip(f'skip {optimizer_class.__name__}')
+
+    optimizer_name: str = optimizer_class.__name__
+    if (optimizer_name == 'Shampoo') or (optimizer_name == 'Adai'):
+        pytest.skip(f'skip {optimizer_name}')

     optimizer = SAM(model.parameters(), optimizer_class, **config, adaptive=adaptive)

@@ -335,26 +345,31 @@ def test_no_gradients(optimizer_config):
     assert tensor_to_numpy(init_loss) >= tensor_to_numpy(loss)


-@pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)
-def test_closure(optimizer_config):
+@pytest.mark.parametrize('optimizer', set(config[0] for config in OPTIMIZERS), ids=names)
+def test_closure(optimizer):
     _, model, _ = build_environment()

-    optimizer_class, config, _ = optimizer_config
-    if optimizer_class.__name__ == 'Ranger21':
-        pytest.skip(f'skip {optimizer_class.__name__}')
-
-    optimizer = optimizer_class(model.parameters(), **config)
+    if optimizer.__name__ == 'Ranger21':
+        optimizer = optimizer(model.parameters(), num_iterations=1)
+    else:
+        optimizer = optimizer(model.parameters())

     optimizer.zero_grad()
-    optimizer.step(closure=dummy_closure)
+
+    try:
+        optimizer.step(closure=dummy_closure)
+    except ValueError:  # in case of Ranger21, Adai optimizers
+        pass


 @pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)
 def test_reset(optimizer_config):
     _, model, _ = build_environment()

     optimizer_class, config, _ = optimizer_config
-    optimizer = optimizer_class(model.parameters(), **config)
+    if optimizer_class.__name__ == 'Ranger21':
+        config.update({'num_iterations': 1})

+    optimizer = optimizer_class(model.parameters(), **config)
     optimizer.zero_grad()
     optimizer.reset()

tests/test_sparse_gradient.py

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@
     'shampoo',
     'nero',
     'adan',
+    'adai',
 ]

tests/utils.py

Lines changed: 4 additions & 0 deletions

@@ -78,6 +78,10 @@ def ids(v) -> str:
     return f'{v[0].__name__}_{v[1:]}'


+def names(v) -> str:
+    return v.__name__
+
+
 def build_environment(use_gpu: bool = False) -> Tuple[Tuple[torch.Tensor, torch.Tensor], nn.Module, nn.Module]:
     torch.manual_seed(42)
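The new names helper mirrors ids but for bare optimizer classes, and is what the reworked test_closure above passes as the pytest ids callback. A small illustrative sketch (the test name here is hypothetical, not from the commit):

import pytest

from pytorch_optimizer import Adai, AdamP


def names(v) -> str:
    return v.__name__


# pytest labels the two cases 'Adai' and 'AdamP' instead of an opaque repr
@pytest.mark.parametrize('optimizer_class', [Adai, AdamP], ids=names)
def test_optimizer_id_is_class_name(optimizer_class):
    assert optimizer_class.__name__ in ('Adai', 'AdamP')
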
