Commit 5db0994

Merge pull request #252 from kozistr/feature/stableadamw-optimizer
[Feature] Implement StableAdamW optimizer
2 parents: b316ef9 + 39a38f3

11 files changed (+174, -5 lines)

README.md

Lines changed: 2 additions & 1 deletion

@@ -10,7 +10,7 @@
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **70 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -167,6 +167,7 @@ supported_optimizers = get_supported_optimizers()
 | FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | <https://arxiv.org/abs/2405.12807> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) |
 | Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
 | Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
+| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |

 ## Supported LR Scheduler
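The table rows above sit just below the README's `get_supported_optimizers()` snippet. A minimal sketch (not part of the diff) of how the newly exported optimizer can be checked and constructed after this commit; the tiny `torch.nn.Linear` model is only illustrative:

```python
import torch

from pytorch_optimizer import StableAdamW, get_supported_optimizers

# After this commit the registry should report 70 entries
# (see the updated assertion in tests/test_load_modules.py).
print(len(get_supported_optimizers()))

# Construct the optimizer directly, using the defaults the new class ships with.
model = torch.nn.Linear(4, 1)
optimizer = StableAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
```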

docs/changelogs/v3.0.2.md

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,8 @@
 * Add more Pytorch built-in lr schedulers. (#248)
 * Implement `Kate` optimizer. (#249, #251)
     * [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648)
+* Implement `StableAdamW` optimizer. (#250, #252)
+    * [Stable and low-precision training for large-scale vision-language models](https://arxiv.org/abs/2304.13013)

 ### Refactor
docs/index.md

Lines changed: 2 additions & 1 deletion

@@ -10,7 +10,7 @@
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **70 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -167,6 +167,7 @@ supported_optimizers = get_supported_optimizers()
 | FAdam | *Adam is a natural gradient optimizer using diagonal empirical Fisher information* | [github](https://github.com/lessw2020/fadam_pytorch) | <https://arxiv.org/abs/2405.12807> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240512807H/exportcitation) |
 | Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
 | Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
+| StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |

 ## Supported LR Scheduler

docs/optimizer.md

Lines changed: 4 additions & 0 deletions

@@ -268,6 +268,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.StableAdamW
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AccSGD
     :docstring:
     :members:

pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -17,8 +17,8 @@ keywords = [
     "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad",
     "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
     "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH",
-    "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM",
-    "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
+    "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
+    "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -44,6 +44,7 @@
 from pytorch_optimizer.optimizer.adamod import AdaMod
 from pytorch_optimizer.optimizer.adamp import AdamP
 from pytorch_optimizer.optimizer.adams import AdamS
+from pytorch_optimizer.optimizer.adamw import StableAdamW
 from pytorch_optimizer.optimizer.adan import Adan
 from pytorch_optimizer.optimizer.adanorm import AdaNorm
 from pytorch_optimizer.optimizer.adapnm import AdaPNM

@@ -201,6 +202,7 @@
     FAdam,
     GrokFastAdamW,
     Kate,
+    StableAdamW,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

pytorch_optimizer/optimizer/adamw.py (new file; the path follows from the `from pytorch_optimizer.optimizer.adamw import StableAdamW` import above)

Lines changed: 135 additions & 0 deletions

import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.utils import debias_beta


class StableAdamW(Optimizer, BaseOptimizer):
    r"""Stable and low-precision training for large-scale vision-language models.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param kahan_sum: bool. Enables Kahan summation for more accurate parameter updates when training in low precision
        (float16 or bfloat16).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. decoupled weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.99),
        kahan_sum: bool = True,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'kahan_sum': kahan_sum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'StableAdamW'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                state['kahan_comp'] = (
                    torch.zeros_like(p) if group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16} else None
                )

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            beta1_comp: float = 1.0 - debias_beta(beta1, group['step'])
            beta2_hat: float = debias_beta(beta2, group['step'])

            eps_p2: float = math.pow(group['eps'], 2)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                    state['kahan_comp'] = (
                        torch.zeros_like(p)
                        if (group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16})
                        else None
                    )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.lerp_(grad, weight=beta1_comp)
                exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1.0 - beta2_hat)

                rms = grad.pow(2).div_(exp_avg_sq.clip(min=eps_p2)).mean().sqrt_()

                lr = group['lr'] / rms.clip(min=1.0)

                self.apply_weight_decay(
                    p,
                    p.grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                if group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16}:
                    kahan_comp = state['kahan_comp']
                    kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

                    grad.copy_(p.detach())
                    p.add_(kahan_comp)

                    kahan_comp.add_(grad.sub_(p))
                else:
                    p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

        return loss

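Read as an update rule, the `step()` above implements the following (a summary sketch of the code, not text from the commit; the weight-decay line assumes the usual decoupled behaviour of `apply_weight_decay` with `weight_decouple=True` and `fixed_decay=False`):

$$
\begin{aligned}
\hat\beta_{i,t} &= \frac{\beta_i^{t} - \beta_i}{\beta_i^{t} - 1} \quad (\texttt{debias\_beta}), \\
m_t &= \hat\beta_{1,t}\, m_{t-1} + \bigl(1 - \hat\beta_{1,t}\bigr)\, g_t, \\
v_t &= \hat\beta_{2,t}\, v_{t-1} + \bigl(1 - \hat\beta_{2,t}\bigr)\, g_t^{2}, \\
\mathrm{RMS}_t &= \sqrt{\operatorname{mean}\!\bigl(g_t^{2} / \max(v_t, \epsilon^{2})\bigr)}, \\
\eta_t &= \frac{\eta}{\max(\mathrm{RMS}_t, 1)}, \\
\theta_t &= \bigl(1 - \eta_t \lambda\bigr)\,\theta_{t-1} - \eta_t\, \frac{m_t}{\sqrt{v_t} + \epsilon}.
\end{aligned}
$$

Scaling the step size down by the per-tensor RMS of the squared gradient over the second-moment estimate is the paper's key modification to AdamW: when gradients are much larger than the running estimate predicts, the effective learning rate for that tensor shrinks, which is what stabilizes float16/bfloat16 training. The optional Kahan branch additionally compensates for rounding error in the low-precision parameter update itself.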
pytorch_optimizer/optimizer/utils.py

Lines changed: 11 additions & 0 deletions

@@ -12,6 +12,17 @@
 from pytorch_optimizer.base.types import PARAMETERS


+def debias_beta(beta: float, step: int) -> float:
+    r"""Apply the Adam-style debias correction into beta.
+
+    Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`
+
+    :param beta: float. beta.
+    :param step: int. number of step.
+    """
+    return (beta ** step - beta) / (beta ** step - 1.0)  # fmt: skip
+
+
 def is_valid_parameters(parameters: PARAMETERS) -> bool:
     r"""Check where the parameters are valid."""
     return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict)
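As a quick sanity check (a standalone snippet, not part of the diff): the simplified return expression is algebraically identical to the docstring's form, and at `step == 1` it returns 0, so the first optimizer step uses the raw gradient and squared gradient directly.

```python
def debias_beta(beta: float, step: int) -> float:
    # Same expression as in the diff above.
    return (beta ** step - beta) / (beta ** step - 1.0)


for beta in (0.9, 0.99):
    for step in (1, 2, 10, 100):
        # Docstring form: beta * (1 - beta**(step - 1)) / (1 - beta**step).
        reference = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)
        assert abs(debias_beta(beta, step) - reference) < 1e-12

assert debias_beta(0.9, 1) == 0.0
```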

tests/constants.py

Lines changed: 3 additions & 0 deletions

@@ -67,6 +67,7 @@
     Shampoo,
     SignSGD,
     SophiaH,
+    StableAdamW,
     Tiger,
     Yogi,
 )

@@ -132,6 +133,7 @@
     'schedulefreeadamw',
     'fadam',
     'grokfastadamw',
+    'stableadamw',
 ]

 VALID_LR_SCHEDULER_NAMES: List[str] = [

@@ -463,6 +465,7 @@
     (FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
     (Kate, {'lr': 5e-2}, 10),
+    (StableAdamW, {'lr': 1e0}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

tests/test_load_modules.py

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 69
+    assert len(get_supported_optimizers()) == 70


 def test_get_supported_lr_schedulers():
