Commit 87ab0e6

Merge pull request #376 from kozistr/feature/optimizers
[Feature] Implement Fira, RACS and Alice optimizers
2 parents 84b926c + 7dd7911

File tree: 19 files changed, +804 −283 lines

Makefile

Lines changed: 7 additions & 4 deletions

@@ -1,4 +1,4 @@
-.PHONY: init format test check requirements docs
+.PHONY: init format test check requirements visualize docs

 init:
     python -m pip install -q -U poetry isort black ruff pytest pytest-cov
@@ -8,16 +8,19 @@ format:
     isort --profile black -l 119 pytorch_optimizer examples tests hubconf.py
     black -S -l 119 pytorch_optimizer examples tests hubconf.py

-test:
-    python -m pytest -p no:pastebin -p no:nose -p no:doctest -sv -vv --cov=pytorch_optimizer --cov-report=xml ./tests
-
 check:
     black -S -l 119 --check pytorch_optimizer examples tests hubconf.py
     ruff check pytorch_optimizer examples tests hubconf.py

+test:
+    python -m pytest -p no:pastebin -p no:nose -p no:doctest -sv -vv --cov=pytorch_optimizer --cov-report=xml ./tests
+
 requirements:
     poetry export -f requirements.txt --output requirements.txt --without-hashes
     poetry export -f requirements.txt --output requirements-dev.txt --without-hashes --with dev

+visualize:
+    python -m examples.visualize_optimizers
+
 docs:
     mkdocs serve

README.md

Lines changed: 109 additions & 107 deletions
Large diffs are not rendered by default.

docs/changelogs/v3.5.2.md

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+## Change Log
+
+### Feature
+
+* Implement `Fira` optimizer. (#376)
+    * [Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)
+* Implement `RACS` and `Alice` optimizers. (#376)
+    * [Towards Efficient Optimizer Design for LLM via Structured Fisher Approximation with a Low-Rank Extension](https://arxiv.org/abs/2502.07752)
+
+### Fix
+
+* Fix shape mismatch issues in the GaLore projection for the `reverse_std`, `right`, and `full` projection types. (#376)
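For context on the `Fira` feature above: the paper's core idea (as we read it) is to run an adaptive update in a GaLore-style low-rank subspace, then recover a full-rank step by adding back the out-of-subspace gradient residual, scaled by the ratio of the adaptive update's norm to the raw projected gradient's norm. The NumPy sketch below is purely illustrative; `fira_like_update`, `adaptive_fn`, and the exact scaling rule are our assumptions, not the library's actual implementation.

```python
import numpy as np

def fira_like_update(grad, p, adaptive_fn, eps=1e-8):
    """One Fira-style step for a single 2-D gradient (illustrative sketch).

    grad: full gradient, shape (m, n)
    p: orthonormal projection basis, shape (m, r) with r << m
    adaptive_fn: low-rank adaptive update rule (e.g. an Adam step)
    """
    g_low = p.T @ grad                    # project gradient into the rank-r subspace
    u_low = adaptive_fn(g_low)            # adaptive update computed in the subspace
    residual = grad - p @ g_low           # gradient component outside the subspace
    # Norm-based scaling: reuse the adaptive magnitude for the residual direction
    phi = np.linalg.norm(u_low) / (np.linalg.norm(g_low) + eps)
    return p @ u_low + phi * residual     # full-rank update from low-rank state
```

With an identity `adaptive_fn`, `phi` is approximately 1 and the update reduces to the plain gradient, which is a quick sanity check that the projection and residual bookkeeping are shape-consistent.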

docs/index.md

Lines changed: 109 additions & 107 deletions
Large diffs are not rendered by default.

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
@@ -188,6 +188,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.Fira
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.FOCUS
     :docstring:
     :members:

docs/visualization.md

Lines changed: 8 additions & 0 deletions
@@ -174,6 +174,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_FAdam.png)

+### Fira
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Fira.png)
+
 ### FOCUS

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_FOCUS.png)
@@ -556,6 +560,10 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_FAdam.png)

+### Fira
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Fira.png)
+
 ### FOCUS

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_FOCUS.png)
(Two binary image files changed: 633 KB and 140 KB; previews not rendered.)

pyproject.toml

Lines changed: 12 additions & 12 deletions
@@ -10,18 +10,18 @@ homepage = "https://github.com/kozistr/pytorch_optimizer"
 repository = "https://github.com/kozistr/pytorch_optimizer"
 documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
 keywords = [
-    "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
-    "AdaDelta", "AdaFactor", "AdaGC", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix",
-    "Simplified-AdEMAMix", "ADOPT", "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan",
-    "AggMo", "Aida", "AliG", "Amos", "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam",
-    "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "FOCUS", "Fromage", "FTRL",
-    "GaLore", "Grams", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead",
-    "MADGRAD", "MARS", "MSVAG", "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy",
-    "PSGD", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD",
-    "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SCION", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
-    "SOAP", "SopihaH", "SPAM", "StableSPAM", "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi",
-    "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky",
-    "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "Alice", "ASGD", "AccSGD", "AdaBelief",
+    "AdaBound", "AdaDelta", "AdaFactor", "AdaGC", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth",
+    "AdEMAMix", "Simplified-AdEMAMix", "ADOPT", "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP",
+    "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad",
+    "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DeMo", "DiffGrad", "EXAdam", "FAdam", "Fira", "FOCUS",
+    "Fromage", "FTRL", "GaLore", "Grams", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp", "LARS", "Lion",
+    "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID",
+    "PNM", "Prodigy", "PSGD", "QHAdam", "QHM", "RACS", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM",
+    "LookSAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SCION", "SGDP", "Shampoo",
+    "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "StableSPAM", "SRMM", "StableAdamW",
+    "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice",
+    "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,7 @@
     PID,
     PNM,
     QHM,
+    RACS,
     SAM,
     SCION,
     SGDP,
@@ -96,6 +97,7 @@
     AdEMAMix,
     AggMo,
     Aida,
+    Alice,
     AliG,
     Amos,
     ApolloDQN,
@@ -110,6 +112,7 @@
     DynamicLossScaler,
     EXAdam,
     FAdam,
+    Fira,
     Fromage,
     GaLore,
     Grams,
