
Commit fc473bb

Merge pull request #233 from kozistr/feature/lots-of-stuffs
[Feature] Lots of stuffs
2 parents 48030b5 + 8c8b821 · commit fc473bb

19 files changed: +953 / -305 lines changed

README.md

Lines changed: 72 additions & 70 deletions
Large diffs are not rendered by default.

docs/changelogs/v3.0.0.md

Lines changed: 8 additions & 1 deletion
@@ -13,6 +13,12 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
 * Implement `GaLore` optimizer. (#224, #228)
     * [Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
 * Implement `Adalite` optimizer. (#225, #229)
+* Implement `bSAM` optimizer. (#212, #233)
+    * [SAM as an Optimal Relaxation of Bayes](https://arxiv.org/abs/2210.01620)
+* Implement `Schedule-Free` optimizer. (#230, #233)
+    * [Schedule-Free optimizers](https://github.com/facebookresearch/schedule_free)
+* Implement `EMCMC`. (#231, #233)
+    * [Entropy-MCMC: Sampling from flat basins with ease](https://www.semanticscholar.org/paper/Entropy-MCMC%3A-Sampling-from-Flat-Basins-with-Ease-Li-Zhang/fd95de3f24fc4f955a6fe5719d38d1d06136e0cd)
 
 ### Fix
 
@@ -35,4 +41,5 @@ thanks to @sdbds, @i404788
 
 ## Diff
 
-[2.12.0...3.0.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.12.0...v3.0.0)
+* from the previous major version: [2.0.0...3.0.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.0.0...v3.0.0)
+* from the previous version: [2.12.0...3.0.0](https://github.com/kozistr/pytorch_optimizer/compare/v2.12.0...v3.0.0)
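The Schedule-Free optimizers added above are exported from the package root (see the `pytorch_optimizer/__init__.py` diff below), so they can be dropped into a standard `torch.optim`-style loop. A minimal sketch, assuming the constructor accepts the usual `params` and `lr` arguments and that a bare loop is enough; the upstream Schedule-Free reference implementation also exposes `train()`/`eval()` mode switches, so check this port's docstring before relying on the sketch as-is:

```python
import torch
from torch import nn

from pytorch_optimizer import ScheduleFreeAdamW  # exported at the package root in v3.0.0

# toy model and data, purely for illustration
model = nn.Linear(16, 1)
optimizer = ScheduleFreeAdamW(model.parameters(), lr=1e-3)  # other hyperparameters left at their defaults (assumed)

x, y = torch.randn(32, 16), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```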

docs/index.md

Lines changed: 80 additions & 71 deletions
Large diffs are not rendered by default.

docs/optimizer.md

Lines changed: 12 additions & 0 deletions
@@ -96,6 +96,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.BSAM
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.CAME
     :docstring:
     :members:
@@ -236,6 +240,14 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.ScheduleFreeSGD
+    :docstring:
+    :members:
+
+::: pytorch_optimizer.ScheduleFreeAdamW
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AccSGD
     :docstring:
     :members:

docs/util.md

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.optimizer.utils.reg_noise
+    :docstring:
+    :members:
+
 ## Newton methods
 
 ::: pytorch_optimizer.optimizer.shampoo_utils.power_iteration
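`reg_noise` is the utility behind the `EMCMC` entry in the changelog above. Its exact signature is only given by the docstring referenced here; as a rough, hypothetical illustration of the idea rather than the library API, an Entropy-MCMC style regularizer couples the model parameters to an auxiliary copy and perturbs that copy with Gaussian noise:

```python
import torch

def emcmc_style_regularizer(params, aux_params, eta: float = 1e-2, noise_scale: float = 1e-4):
    """Hypothetical sketch (not the library's reg_noise API): quadratic coupling to an
    auxiliary parameter copy plus Gaussian exploration noise, per the Entropy-MCMC idea."""
    reg = 0.0
    for p, a in zip(params, aux_params):
        # pull the two copies together: ||p - a||^2 / (2 * eta)
        reg = reg + (p - a).pow(2).sum() / (2.0 * eta)
        # perturb the auxiliary copy outside the autograd graph
        with torch.no_grad():
            a.add_(noise_scale * torch.randn_like(a))
    return reg
```

In use, such a term would be added to the task loss before `backward()`, with `aux_params` taken from a second copy of the model.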

poetry.lock

Lines changed: 174 additions & 127 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 7 additions & 7 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.12.0"
+version = "3.0.0"
 description = "optimizer & lr scheduler & objective function collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
@@ -12,13 +12,13 @@ documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
 keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
     "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
-    "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "CAME", "DAdaptAdaGrad",
+    "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad",
     "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "Fromage", "GaLore", "Gravity", "GSAM", "LARS",
     "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM",
-    "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "SGDP", "Shampoo", "ScalableShampoo",
-    "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
-    "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
-    "bitsandbytes",
+    "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
+    "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS",
+    "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
+    "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
@@ -50,7 +50,7 @@ bitsandbytes = { version = "^0.43", optional = true }
 
 [tool.poetry.dev-dependencies]
 isort = { version = "^5", python = ">=3.8" }
-black = { version = "^24", python = ">=3.8"}
+black = { version = "^24", python = ">=3.8" }
 ruff = "*"
 pytest = "*"
 pytest-cov = "*"

pytorch_optimizer/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -79,7 +79,8 @@
 from pytorch_optimizer.optimizer.ranger import Ranger
 from pytorch_optimizer.optimizer.ranger21 import Ranger21
 from pytorch_optimizer.optimizer.rotograd import RotoGrad
-from pytorch_optimizer.optimizer.sam import GSAM, SAM, WSAM
+from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM
+from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeSGD
 from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD
 from pytorch_optimizer.optimizer.sgdp import SGDP
 from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
@@ -186,6 +187,9 @@
     Aida,
     GaLore,
     Adalite,
+    BSAM,
+    ScheduleFreeSGD,
+    ScheduleFreeAdamW,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
 
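Because the `OPTIMIZERS` registry in the hunk above lowercases each class name, the new classes should also be reachable by string key. A small sketch derived directly from that comprehension (the expected keys are an inference, not separately documented):

```python
from pytorch_optimizer import OPTIMIZERS, ScheduleFreeAdamW

# keys are str(cls.__name__).lower(), per the dict comprehension in the diff above
assert OPTIMIZERS['schedulefreeadamw'] is ScheduleFreeAdamW
print(sorted(k for k in OPTIMIZERS if 'schedulefree' in k or k == 'bsam'))
# expected (inferred): ['bsam', 'schedulefreeadamw', 'schedulefreesgd']
```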
