
Commit 48030b5

Merge pull request #229 from kozistr/feature/adalite-optimizer
[Feature] Implement Adalite optimizer
2 parents b1b5ed4 + 6abce12 commit 48030b5

10 files changed (+196, -9 lines)

README.md

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **63 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
+Currently, **64 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
 
@@ -161,6 +161,7 @@ supported_optimizers = get_supported_optimizers()
 | WSAM | *Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term* | [github](https://github.com/intelligent-machine-learning/dlrover/blob/master/atorch/atorch/optimizers/wsam.py) | <https://arxiv.org/abs/2305.15817> | [cite](https://github.com/intelligent-machine-learning/dlrover) |
 | Aida | *A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range* | [github](https://github.com/guoqiang-zhang-x/Aida-Optimizer) | <https://arxiv.org/abs/2203.13273> | [cite](https://github.com/guoqiang-zhang-x/Aida-Optimizer?tab=readme-ov-file#1-brief-description-of-aida) |
 | GaLore | *Memory-Efficient LLM Training by Gradient Low-Rank Projection* | [github](https://github.com/jiaweizzhao/GaLore) | <https://arxiv.org/abs/2403.03507> | [cite](https://github.com/jiaweizzhao/GaLore/tree/master?tab=readme-ov-file#citation) |
+| Adalite | *Adalite optimizer* | [github](https://github.com/VatsaDev/adalite) | <https://github.com/VatsaDev/adalite> | [cite](https://github.com/VatsaDev/adalite) |
 
 ## Supported LR Scheduler
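With the table entry above, Adalite is available through the library's public API like any other optimizer in the collection. Below is a minimal usage sketch; the toy model and random data are placeholders for illustration only and are not part of this commit.

import torch
from torch import nn

from pytorch_optimizer import Adalite  # exported at the package top level by this commit

# placeholder model and data, assumed only for illustration
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = Adalite(model.parameters(), lr=1e-3, weight_decay=1e-2)

x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()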

docs/changelogs/v3.0.0.md

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
   * [Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term](https://arxiv.org/abs/2305.15817)
 * Implement `GaLore` optimizer. (#224, #228)
   * [Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
+* Implement `Adalite` optimizer. (#225, #229)
 
 ### Fix
 
docs/optimizer.md

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.Adalite
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AdaMax
     :docstring:
     :members:

pyproject.toml

Lines changed: 8 additions & 7 deletions
@@ -11,13 +11,14 @@ repository = "https://github.com/kozistr/pytorch_optimizer"
 documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
 keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
-    "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "AdamP",
-    "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad", "DAdaptAdam",
-    "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS", "Lamb", "Lion",
-    "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam",
-    "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
-    "SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
-    "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
+    "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
+    "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad",
+    "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS",
+    "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM",
+    "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo",
+    "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
+    "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
+    "bitsandbytes",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@
 from pytorch_optimizer.optimizer.adafactor import AdaFactor
 from pytorch_optimizer.optimizer.adahessian import AdaHessian
 from pytorch_optimizer.optimizer.adai import Adai
+from pytorch_optimizer.optimizer.adalite import Adalite
 from pytorch_optimizer.optimizer.adamax import AdaMax
 from pytorch_optimizer.optimizer.adamod import AdaMod
 from pytorch_optimizer.optimizer.adamp import AdamP
@@ -184,6 +185,7 @@
     DAdaptLion,
     Aida,
     GaLore,
+    Adalite,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
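Since `OPTIMIZERS` maps the lowercased class name to the class itself, adding `Adalite` to `OPTIMIZER_LIST` is all that is required for name-based loading. A small sketch of that lookup, assuming the package's existing `load_optimizer` helper (the same one used in the tests further below):

from pytorch_optimizer import load_optimizer

# the registry key is the lowercased class name, so 'adalite' resolves to the new class
adalite_cls = load_optimizer('adalite')
assert adalite_cls.__name__ == 'Adalite'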

pytorch_optimizer/optimizer/adalite.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
import torch
from torch.nn.functional import softmax
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class Adalite(Optimizer, BaseOptimizer):
    r"""Adalite optimizer.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param g_norm_min: float.
    :param ratio_min: float.
    :param tau: float.
    :param eps1: float. term added to the denominator to improve numerical stability.
    :param eps2: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 1e-2,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        g_norm_min: float = 1e-10,
        ratio_min: float = 1e-4,
        tau: float = 1.0,
        eps1: float = 1e-6,
        eps2: float = 1e-10,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'g_norm_min': g_norm_min,
            'ratio_min': ratio_min,
            'tau': tau,
            'eps1': eps1,
            'eps2': eps2,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adalite'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                if len(p.shape) < 2:
                    # 1-D (or scalar) parameters keep dense first/second-moment buffers
                    state['m_avg'] = torch.zeros_like(p)
                    state['v_avg'] = torch.zeros_like(p)
                else:
                    # >=2-D parameters keep factored row/column statistics to save memory
                    state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
                    state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))

                    state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
                    state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
                    state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    if len(p.shape) < 2:
                        state['m_avg'] = torch.zeros_like(p)
                        state['v_avg'] = torch.zeros_like(p)
                    else:
                        state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
                        state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))

                        state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
                        state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
                        state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))

                if sum(grad.shape) > 1:
                    # layer-wise trust ratio: rescale the gradient by ||p|| / ||grad||, clipped from below
                    trust_ratio = (p.norm() / grad.norm().clip(min=group['g_norm_min'])).clip(min=group['ratio_min'])
                    grad.mul_(trust_ratio)

                if len(grad.shape) < 2:
                    m = state['m_avg']
                    v = state['v_avg']
                else:
                    # rebuild full first/second-moment estimates from the factored statistics
                    r, c = state['v_avg_0'][:, None], state['v_avg_1'][None, :]
                    v = (r * c) / r.sum().clamp(min=group['eps2'])
                    m = state['m_avg_c'] @ state['m_avg_u'] @ state['m_avg_r']

                # exponential moving averages of the gradient and of the squared residual
                m.lerp_(grad, 1.0 - beta1)
                v.lerp_((grad - m).square(), 1.0 - beta2)

                v_avg = v / (1.0 - beta2 ** group['step'])  # bias-corrected second moment

                if len(grad.shape) == 2:
                    # blend m toward the gradient, modulated by row/column importance weights
                    imp_c = softmax(v.mean(dim=1), dim=0)[:, None]
                    imp_r = softmax(v.mean(dim=0), dim=0)[None, :]
                    m.lerp_(grad, 1.0 - imp_c * imp_r)

                u = m.lerp(grad, 1.0 - beta1)

                if len(grad.shape) < 2:
                    state['m_avg'] = m
                    state['v_avg'] = v
                else:
                    # store factored row/column statistics instead of the full matrices
                    state['v_avg_0'] = v.sum(dim=1)
                    state['v_avg_1'] = v.sum(dim=0) / v.sum().clamp(min=group['eps2'])

                    imp_c = softmax(v.mean(dim=1) / group['tau'], dim=-1)[:, None]
                    imp_r = softmax(v.mean(dim=0) / group['tau'], dim=-1)[None, :]

                    c = ((m * imp_r).sum(dim=1))[:, None]
                    r = ((m * imp_c).sum(dim=0))[None, :]

                    # compress m into a rank-1 factorization (column, scale, row)
                    s = (c.T @ m @ r.T) / (c.T @ c @ r @ r.T).clamp(min=group['eps2'])

                    state['m_avg_c'] = c
                    state['m_avg_r'] = r
                    state['m_avg_u'] = s

                # normalize by the bias-corrected second moment, add weight decay, and apply the update
                u.div_((v_avg + group['eps1']).sqrt())

                u = u.reshape(p.shape)
                u.add_(p, alpha=group['weight_decay'])

                p.add_(u, alpha=-group['lr'])

        return loss
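
`step()` keeps dense `m_avg`/`v_avg` buffers only for parameters with fewer than two dimensions and factored row/column statistics otherwise, so a single `nn.Linear` layer exercises both paths (2-D weight, 1-D bias). A rough sketch, not part of the commit; the layer, data, and printed key lists are assumed for illustration only.

import torch
from torch import nn

from pytorch_optimizer import Adalite

torch.manual_seed(42)

model = nn.Linear(8, 4)  # weight is 2-D (factored state), bias is 1-D (dense state)
optimizer = Adalite(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 8), torch.randn(32, 4)
(model(x) - y).pow(2).mean().backward()
optimizer.step()

print(sorted(optimizer.state[model.weight]))  # ['m_avg_c', 'm_avg_r', 'm_avg_u', 'v_avg_0', 'v_avg_1']
print(sorted(optimizer.state[model.bias]))    # ['m_avg', 'v_avg']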

tests/constants.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,7 @@
     AdaFactor,
     AdaHessian,
     Adai,
+    Adalite,
     AdaMax,
     AdaMod,
     AdamP,
@@ -120,6 +121,8 @@
     'padam',
     'came',
     'aida',
+    'galore',
+    'adalite',
 ]
 
 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -434,6 +437,7 @@
         {'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'full'},
         5,
     ),
+    (Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

tests/test_general_optimizer_parameters.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ def test_epsilon(optimizer_name):
         'lomo',
         'tiger',
         'came',
+        'adalite',
     ):
         pytest.skip(f'skip {optimizer_name} optimizer')
 
tests/test_load_modules.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 62
+    assert len(get_supported_optimizers()) == 63
 
 
 def test_get_supported_lr_schedulers():

tests/test_optimizers.py

Lines changed: 5 additions & 0 deletions
@@ -464,6 +464,11 @@ def test_prodigy_reset():
     assert str(optimizer) == 'Prodigy'
 
 
+def test_adalite_reset():
+    optimizer = load_optimizer('adalite')([simple_zero_rank_parameter(True)])
+    optimizer.reset()
+
+
 @pytest.mark.parametrize('pre_conditioner_type', [0, 1, 2])
 def test_scalable_shampoo_pre_conditioner_with_svd(pre_conditioner_type, environment):
     (x_data, y_data), _, loss_fn = environment
