
Commit a9fb8a2

Merge pull request #324 from kozistr/feature/ranger25
[Refactor] flexible and consistent `optimizer` parameters for `Lookahead`, `TRAC`, and `OrthoGrad` optimizers
2 parents 5baa713 + 87e1a60 commit a9fb8a2

28 files changed: +496 -130 lines changed

README.md

Lines changed: 3 additions & 2 deletions

@@ -10,8 +10,8 @@

 ## The reasons why you use `pytorch-optimizer`.

-* Wide range of supported optimizers. Currently, **89 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
-    * Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
+* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+    * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
 * Somewhat a bit more optimized compared to the original implementation
@@ -198,6 +198,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | Grams | *Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
 | OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
 | Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
+| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |

 ## Supported LR Scheduler

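Since the supported-optimizers table now lists `SPAM`, the wildcard query shown in the hunk header above can be reused to check that the new entry is registered. A quick sketch, assuming the same filter style as `get_supported_optimizers(['adam*', 'ranger*'])`:

```python
from pytorch_optimizer import get_supported_optimizers

# same wildcard-filter style as the README example in the hunk header above
print(get_supported_optimizers(['spam*']))
```
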
docs/changelogs/v3.3.4.md

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+### Change Log
+
+### Feature
+
+* Support `OrthoGrad` feature for `create_optimizer()`. (#324)
+* Enhanced flexibility for the `optimizer` parameter in `Lookahead`, `TRAC`, and `OrthoGrad` optimizers. (#324)
+    * Now supports both torch.optim.Optimizer instances and classes
+    * You can now use `Lookahead` optimizer in two ways.
+        * `Lookahead(AdamW(model.parameters(), lr=1e-3), k=5, alpha=0.5)`
+        * `Lookahead(AdamW, k=5, alpha=0.5, params=model.parameters())`
+* Implement `SPAM` optimizer. (#324)
+    * [Spike-Aware Adam with Momentum Reset for Stable LLM Training](https://arxiv.org/abs/2501.06842)

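The two `Lookahead` call styles listed in the changelog can be exercised end to end. A minimal sketch, assuming `torch.optim.AdamW` as the inner optimizer and a toy `nn.Linear` model (both are placeholders, not part of the changelog):

```python
import torch
from torch import nn

from pytorch_optimizer import Lookahead

model = nn.Linear(4, 2)  # toy model for illustration

# style 1: wrap an already-constructed optimizer instance
opt_a = Lookahead(torch.optim.AdamW(model.parameters(), lr=1e-3), k=5, alpha=0.5)

# style 2: pass the optimizer class and supply `params` so Lookahead can construct it
opt_b = Lookahead(torch.optim.AdamW, k=5, alpha=0.5, params=model.parameters())

# either wrapper is then driven like a regular optimizer
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt_a.step()
opt_a.zero_grad()
```
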
docs/index.md

Lines changed: 3 additions & 2 deletions

@@ -10,8 +10,8 @@

 ## The reasons why you use `pytorch-optimizer`.

-* Wide range of supported optimizers. Currently, **89 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
-    * Including many variants such as `Cautious`, `AdamD`, `Gradient Centrailiaztion`
+* Wide range of supported optimizers. Currently, **90 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+    * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centrailiaztion`
 * Easy to use, clean, and tested codes
 * Active maintenance
 * Somewhat a bit more optimized compared to the original implementation
@@ -198,6 +198,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | Grams | *Gradient Descent with Adaptive Momentum Scaling* | | <https://arxiv.org/abs/2412.17107> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241217107C/exportcitation) |
 | OrthoGrad | *Grokking at the Edge of Numerical Stability* | [github](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | <https://arxiv.org/abs/2501.04697> | [cite](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability?tab=readme-ov-file#citation) |
 | Adam-ATAN2 | *Scaling Exponents Across Parameterizations and Optimizers* | | <https://arxiv.org/abs/2407.05872> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240705872E/exportcitation) |
+| SPAM | *Spike-Aware Adam with Momentum Reset for Stable LLM Training* | [github](https://github.com/TianjinYellow/SPAM-Optimizer) | <https://arxiv.org/abs/2501.06842> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250106842H/exportcitation) |

 ## Supported LR Scheduler

docs/optimizer.md

Lines changed: 4 additions & 0 deletions

@@ -368,6 +368,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.SPAM
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.SRMM
     :docstring:
     :members:

pyproject.toml

Lines changed: 3 additions & 3 deletions

@@ -18,9 +18,9 @@ keywords = [
     "Kate", "Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MARS", "MSVAG", "Muno", "Nero",
     "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
     "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo",
-    "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC",
-    "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered",
-    "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SPAM", "SRMM", "StableAdamW", "SWATS", "Tiger",
+    "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
+    "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -61,6 +61,7 @@
     SGDW,
     SM3,
     SOAP,
+    SPAM,
     SRMM,
     SWATS,
     TRAC,

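With `SPAM` exported from the package root, it can be constructed like the other optimizers in the library. A minimal sketch; only `params` and `lr` are passed here because the full constructor signature is not part of this diff:

```python
from torch import nn

from pytorch_optimizer import SPAM

model = nn.Linear(4, 2)  # toy model for illustration
optimizer = SPAM(model.parameters(), lr=1e-3)  # remaining hyperparameters are left at their defaults
```
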
pytorch_optimizer/base/scheduler.py

Lines changed: 4 additions & 3 deletions

@@ -1,16 +1,17 @@
 from abc import ABC, abstractmethod
 from typing import List

+from torch.optim import Optimizer
+
 from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
-from pytorch_optimizer.base.types import OPTIMIZER


 class BaseLinearWarmupScheduler(ABC):
     r"""BaseLinearWarmupScheduler class.

     The LR Scheduler class based on this class has linear warmup strategy.

-    :param optimizer: Optimizer. OPTIMIZER. It will set learning rate to all trainable parameters in optimizer.
+    :param optimizer: Optimizer. It will set learning rate to all trainable parameters in optimizer.
     :param t_max: int. total steps to train.
     :param max_lr: float. maximum lr.
     :param min_lr: float. minimum lr.
@@ -20,7 +21,7 @@ class BaseLinearWarmupScheduler(ABC):

     def __init__(
         self,
-        optimizer: OPTIMIZER,
+        optimizer: Optimizer,
         t_max: int,
         max_lr: float,
         min_lr: float = 0.0,

pytorch_optimizer/base/types.py

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@
 PARAMETERS = Optional[Union[Iterable[Dict], Iterable[torch.Tensor]]]
 STATE = Dict
 OPTIMIZER = Type[Optimizer]
+OPTIMIZER_INSTANCE_OR_CLASS = Union[OPTIMIZER, Optimizer]
 SCHEDULER = Type[LRScheduler]

 HUTCHINSON_G = Literal['gaussian', 'rademacher']

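A rough sketch of how the new alias can be used to accept either an optimizer instance or an optimizer class, in the spirit of the `Lookahead`/`TRAC`/`OrthoGrad` change; the helper name `build_optimizer` and its `params` handling are illustrative, not the library's actual implementation:

```python
from typing import Any, Type, Union

from torch.optim import Optimizer

OPTIMIZER = Type[Optimizer]
OPTIMIZER_INSTANCE_OR_CLASS = Union[OPTIMIZER, Optimizer]


def build_optimizer(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, params: Any = None, **kwargs: Any) -> Optimizer:
    """Return an Optimizer instance, constructing one if a class was given (illustrative helper only)."""
    if isinstance(optimizer, Optimizer):  # already an instance: use it as-is
        return optimizer
    if params is None:
        raise ValueError('`params` must be provided when an optimizer class is passed')
    return optimizer(params, **kwargs)  # a class: instantiate it with the given parameters
```
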
pytorch_optimizer/lr_scheduler/cosine_anealing.py

Lines changed: 3 additions & 4 deletions

@@ -1,10 +1,9 @@
 import math
 from typing import List, Optional

+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler

-from pytorch_optimizer.base.types import OPTIMIZER
-

 class CosineAnnealingWarmupRestarts(LRScheduler):
     r"""CosineAnnealingWarmupRestarts.
@@ -21,7 +20,7 @@ class CosineAnnealingWarmupRestarts(LRScheduler):

     def __init__(
         self,
-        optimizer: OPTIMIZER,
+        optimizer: Optimizer,
         first_cycle_steps: int,
         cycle_mult: float = 1.0,
         max_lr: float = 1e-4,
@@ -53,7 +52,7 @@ def __init__(

         self.init_lr()

-    def init_lr(self):
+    def init_lr(self) -> None:
         self.base_lrs = []
         for param_group in self.optimizer.param_groups:
             param_group['lr'] = self.min_lr

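With the annotation switched to `torch.optim.Optimizer`, the scheduler receives an optimizer instance directly. A usage sketch restricted to the parameters visible in this diff (`first_cycle_steps`, `cycle_mult`, `max_lr`); the model and `SGD` optimizer are placeholders:

```python
import torch
from torch import nn

from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

scheduler = CosineAnnealingWarmupRestarts(optimizer, first_cycle_steps=100, cycle_mult=1.0, max_lr=1e-4)

for _ in range(300):
    optimizer.step()
    scheduler.step()  # lr follows the warmup + cosine-restart schedule
```
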
pytorch_optimizer/lr_scheduler/rex.py

Lines changed: 3 additions & 4 deletions

@@ -1,9 +1,8 @@
 from typing import List, Optional

+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler

-from pytorch_optimizer.base.types import OPTIMIZER
-

 class REXScheduler(LRScheduler):
     r"""Revisiting Budgeted Training with an Improved Schedule.
@@ -16,7 +15,7 @@ class REXScheduler(LRScheduler):

     def __init__(
         self,
-        optimizer: OPTIMIZER,
+        optimizer: Optimizer,
         total_steps: int,
         max_lr: float = 1.0,
         min_lr: float = 0.0,
@@ -35,7 +34,7 @@ def __init__(

         self.init_lr()

-    def init_lr(self):
+    def init_lr(self) -> None:
         self.base_lrs = []
         for param_group in self.optimizer.param_groups:
             param_group['lr'] = self.min_lr

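The same pattern applies to `REXScheduler`, again using only the constructor parameters visible in this diff; the model, optimizer, and step count are placeholders:

```python
import torch
from torch import nn

from pytorch_optimizer.lr_scheduler.rex import REXScheduler

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

scheduler = REXScheduler(optimizer, total_steps=1000, max_lr=1.0, min_lr=0.0)

for _ in range(1000):
    optimizer.step()
    scheduler.step()  # decays the lr along the REX budgeted-training schedule
```
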