Commit 83a2f5e

Merge pull request #253 from kozistr/feature/adam-mini-optimizer
[Feature] Implement AdamMini optimizer
2 parents 5db0994 + a970453 commit 83a2f5e

13 files changed: +396, -13 lines changed

README.md

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@

 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **72 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!

 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -168,6 +168,7 @@ supported_optimizers = get_supported_optimizers()
 | Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
 | Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
 | StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |
+| AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | <https://arxiv.org/abs/2406.16793> | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) |

 ## Supported LR Scheduler
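
The second hunk above lands right after the README's `get_supported_optimizers()` snippet, so the newly listed optimizer should be reachable through the library's usual helpers. A minimal usage sketch, assuming the documented `load_optimizer` / `get_supported_optimizers` API; the constructor call is only a guess, since `AdamMini`'s signature (Adam-mini variants are typically built against a whole module rather than a plain parameter list) is not part of the lines shown in this commit:

```python
import torch
from pytorch_optimizer import get_supported_optimizers, load_optimizer

# AdamMini should now show up among the 72 supported optimizers
print(get_supported_optimizers())

# the registry key is the lowercased class name (see pytorch_optimizer/__init__.py below)
adam_mini_cls = load_optimizer(optimizer='adammini')

model = torch.nn.Linear(16, 2)
# assumption: the optimizer is given the module itself so it can assign one
# second-moment value per parameter block instead of one per coordinate
optimizer = adam_mini_cls(model, lr=1e-3)

optimizer.zero_grad()
model(torch.randn(4, 16)).sum().backward()
optimizer.step()
```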

docs/changelogs/v3.0.2.md

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@
 * [Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad](https://arxiv.org/abs/2403.02648)
 * Implement `StableAdamW` optimizer. (#250, #252)
 * [Stable and low-precision training for large-scale vision-language models](https://arxiv.org/abs/2304.13013)
+* Implement `AdamMini` optimizer. (#246, #253)
+* [Use Fewer Learning Rates To Gain More](https://arxiv.org/abs/2406.16793)

 ### Refactor

docs/index.md

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@

 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **71 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **72 optimizers (+ `bitsandbytes`)**, **16 lr schedulers**, and **13 loss functions** are supported!

 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -168,6 +168,7 @@ supported_optimizers = get_supported_optimizers()
 | Grokfast | *Accelerated Grokking by Amplifying Slow Gradients* | [github](https://github.com/ironjr/grokfast) | <https://arxiv.org/abs/2405.20233> | [cite](https://github.com/ironjr/grokfast?tab=readme-ov-file#citation) |
 | Kate | *Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad* | [github](https://github.com/nazya/KATE) | <https://arxiv.org/abs/2403.02648> | [cite](https://github.com/nazya/KATE?tab=readme-ov-file#remove-that-square-root-a-new-efficient-scale-invariant-version-of-adagrad) |
 | StableAdamW | *Stable and low-precision training for large-scale vision-language models* | | <https://arxiv.org/abs/2304.13013> | [cite](https://ui.adsabs.harvard.edu/abs/2023arXiv230413013W/exportcitation) |
+| AdamMini | *Use Fewer Learning Rates To Gain More* | [github](https://github.com/zyushun/Adam-mini) | <https://arxiv.org/abs/2406.16793> | [cite](https://github.com/zyushun/Adam-mini?tab=readme-ov-file#citation) |

 ## Supported LR Scheduler

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.AdamMini
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AdaMax
     :docstring:
     :members:

pyproject.toml

Lines changed: 8 additions & 7 deletions
@@ -12,13 +12,13 @@ documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
 keywords = [
     "pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
     "AdaDelta", "AdaFactor", "AdaMax", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdaHessian", "Adai", "Adalite",
-    "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad",
-    "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore", "Gravity",
-    "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad",
-    "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
-    "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SopihaH",
-    "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
-    "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
+    "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos", "Apollo", "AvaGrad", "bSAM", "CAME",
+    "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion", "DiffGrad", "FAdam", "Fromage", "GaLore",
+    "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero",
+    "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad",
+    "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3",
+    "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine",
+    "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

@@ -122,6 +122,7 @@ testpaths = "tests"
 [tool.coverage.run]
 omit = [
     "./pytorch_optimizer/optimizer/rotograd.py",
+    "./pytorch_optimizer/optimizer/adam_mini.py",
 ]

 [build-system]

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@
 from pytorch_optimizer.optimizer.adahessian import AdaHessian
 from pytorch_optimizer.optimizer.adai import Adai
 from pytorch_optimizer.optimizer.adalite import Adalite
+from pytorch_optimizer.optimizer.adam_mini import AdamMini
 from pytorch_optimizer.optimizer.adamax import AdaMax
 from pytorch_optimizer.optimizer.adamod import AdaMod
 from pytorch_optimizer.optimizer.adamp import AdamP

@@ -203,6 +204,7 @@
     GrokFastAdamW,
     Kate,
     StableAdamW,
+    AdamMini,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
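
The second hunk is what makes the new class addressable by name: `OPTIMIZERS` is rebuilt from `OPTIMIZER_LIST` by lowercasing each class's `__name__`, so appending `AdamMini` to the list is all the registration needed. A standalone sketch of that pattern, with stand-in classes instead of the real imports:

```python
from typing import Dict, Type

# stand-ins for the real optimizer classes imported at the top of __init__.py
class StableAdamW: ...
class AdamMini: ...

OPTIMIZER_LIST = [StableAdamW, AdamMini]

# same construction as in the diff: the registry key is the lowercased class name
OPTIMIZERS: Dict[str, Type] = {
    str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST
}

assert 'adammini' in OPTIMIZERS  # load_optimizer('adammini') resolves through this dict
```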

pytorch_optimizer/base/optimizer.py

Lines changed: 5 additions & 0 deletions
@@ -258,6 +258,11 @@ def validate_learning_rate(learning_rate: Optional[float]) -> None:
         if learning_rate is not None and learning_rate < 0.0:
             raise NegativeLRError(learning_rate)

+    @staticmethod
+    def validate_mod(x: int, y: int) -> None:
+        if x % y != 0:
+            raise ValueError(f'[-] {x} must be divisible by {y}')
+
     def validate_betas(self, betas: BETAS) -> None:
         if betas[0] is not None:
             self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')