
Commit aca2ef2

Merge pull request #366 from kozistr/feature/optimizers
[Feature] AdaGC and SimplifiedAdEMAMix optimizers
2 parents: 487200d + 771541a

File tree

16 files changed: +537 −223 lines changed


README.md

Lines changed: 107 additions & 105 deletions
Large diffs are not rendered by default.

docs/changelogs/v3.4.3.md renamed to docs/changelogs/v3.5.0.md

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,10 @@
 * Support `StableSPAM` optimizer. (#358, #359)
   * [How to Train in 4-Bit More Stably than 16-Bit Adam](https://arxiv.org/abs/2502.17055?)
 * Support `ScheduleFreeWrapper`. (#334, #360)
+* Implement `AdaGC` optimizer. (#364, #366)
+  * [Improving Training Stability for Large Language Model Pretraining](https://arxiv.org/abs/2502.11034)
+* Implement `Simplified-Ademamix` optimizer. (#364, #366)
+  * [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants](https://arxiv.org/abs/2502.02431)

 ### Update
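For quick reference, here is a minimal usage sketch for the two optimizers introduced in this PR. Only the common pytorch_optimizer constructor pattern (`params` plus `lr`) is assumed; the full keyword arguments are documented in the docstrings registered in `docs/optimizer.md` below.

```python
# Minimal sketch of using the two optimizers added in this PR.
# Only the common pytorch_optimizer pattern (params, lr=...) is assumed here;
# additional keyword arguments are defined in each class's docstring.
import torch
from pytorch_optimizer import AdaGC, SimplifiedAdEMAMix

model = torch.nn.Linear(16, 2)

optimizer = AdaGC(model.parameters(), lr=1e-3)
# optimizer = SimplifiedAdEMAMix(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 16), torch.randn(8, 2)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```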

docs/index.md

Lines changed: 107 additions & 105 deletions
Large diffs are not rendered by default.

docs/optimizer.md

Lines changed: 8 additions & 0 deletions
@@ -28,6 +28,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.AdaGC
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AdaHessian
     :docstring:
     :members:
@@ -92,6 +96,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.SimplifiedAdEMAMix
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.ADOPT
     :docstring:
     :members:

docs/visualization.md

Lines changed: 16 additions & 0 deletions
@@ -22,6 +22,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaFactor.png)

+### AdaGC
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaGC.png)
+
 ### AdaHessian

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_AdaHessian.png)
@@ -326,6 +330,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SignSGD.png)

+### SimplifiedAdEMAMix
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SimplifiedAdEMAMix.png)
+
 ### SM3

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SM3.png)
@@ -392,6 +400,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaFactor.png)

+### AdaGC
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaGC.png)
+
 ### AdaHessian

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_AdaHessian.png)
@@ -696,6 +708,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SignSGD.png)

+### SimplifiedAdEMAMix
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SimplifiedAdEMAMix.png)
+
 ### SM3

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SM3.png)
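The images referenced above are trajectory plots on the Rastrigin and Rosenbrock test functions. As an illustration only, the following is a hypothetical, simplified sketch of how such a trajectory could be traced for one of the new optimizers; it is not the repository's actual visualization script, and the learning rate and step count are arbitrary.

```python
# Hypothetical sketch (not the repository's visualization script): trace an
# optimizer's path on the 2-D Rastrigin function, as in the plots above.
import math
import torch
from pytorch_optimizer import AdaGC

def rastrigin(p: torch.Tensor) -> torch.Tensor:
    # f(x, y) = 20 + sum_i (x_i^2 - 10 * cos(2 * pi * x_i))
    return 20.0 + (p ** 2 - 10.0 * torch.cos(2.0 * math.pi * p)).sum()

p = torch.tensor([-2.0, 3.5], requires_grad=True)
optimizer = AdaGC([p], lr=1e-1)  # lr chosen for illustration only

trajectory = []
for _ in range(500):
    optimizer.zero_grad()
    rastrigin(p).backward()
    optimizer.step()
    trajectory.append(p.detach().clone())

# `trajectory` can then be drawn over a contour plot of `rastrigin`
# (e.g. with matplotlib) to produce figures like the ones referenced above.
```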
Binary files added (docs visualization PNGs for AdaGC and SimplifiedAdEMAMix): 634 KB, 633 KB, 141 KB, 151 KB; not rendered.

pyproject.toml

Lines changed: 11 additions & 11 deletions
@@ -11,17 +11,17 @@ repository = "https://github.com/kozistr/pytorch_optimizer"
 documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
 keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
-    "AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
-    "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
-    "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD",
-    "DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "FOCUS", "Fromage", "FTRL", "GaLore", "Grams", "Gravity",
-    "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG",
-    "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM",
-    "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD", "ScheduleFreeAdamW",
-    "ScheduleFreeRAdam", "SCION", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH",
-    "SPAM", "StableSPAM", "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal",
-    "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky",
-    "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "AdaDelta", "AdaFactor", "AdaGC", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix",
+    "Simplified-AdEMAMix", "ADOPT", "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan",
+    "AggMo", "Aida", "AliG", "Amos", "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam",
+    "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "FOCUS", "Fromage", "FTRL",
+    "GaLore", "Grams", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead",
+    "MADGRAD", "MARS", "MSVAG", "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy",
+    "PSGD", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD",
+    "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SCION", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
+    "SOAP", "SopihaH", "SPAM", "StableSPAM", "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi",
+    "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky",
+    "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
