
Commit c950609

Merge pull request #359 from kozistr/feature/optimizers
[Feature] Implement StableSPAM optimizer
2 parents b85065d + 85d3541 commit c950609

10 files changed (+185, -9 lines)


README.md

Lines changed: 3 additions & 1 deletion

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **99 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **100 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -207,6 +207,8 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
 | LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |
 | SCION | *Training Deep Learning Models with Norm-Constrained LMOs* | | <https://arxiv.org/abs/2502.07529> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250207529P/exportcitation) |
+| COSMOS | *SOAP with Muon* | [github](https://github.com/lliu606/COSMOS) | | |
+| StableSPAM | *How to Train in 4-Bit More Stably than 16-Bit Adam* | [github](https://github.com/TianjinYellow/StableSPAM) | <https://arxiv.org/abs/2502.17055> | |
 
 ## Supported LR Scheduler
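As context for the table addition above, the registry can also be queried at runtime through the `get_supported_optimizers` helper named in the hunk header. A minimal sketch (the wildcard filter style comes from the README example and the count from this commit's tests):

```python
from pytorch_optimizer import get_supported_optimizers

# 97 entries after this commit, matching the updated assertion in tests/test_load_modules.py.
assert len(get_supported_optimizers()) == 97

# Wildcard filters work as in the README example; the new optimizer shows up under 'stable*'.
names = [getattr(o, '__name__', str(o)).lower() for o in get_supported_optimizers('stable*')]
assert 'stablespam' in names
```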

docs/changelogs/v3.4.3.md

Lines changed: 5 additions & 0 deletions

@@ -1,5 +1,10 @@
 ### Change Log
 
+### Feature
+
+* Support `StableSPAM` optimizer. (#358, #359)
+    * [How to Train in 4-Bit More Stably than 16-Bit Adam](https://arxiv.org/abs/2502.17055)
+
 ### Update
 
 * Update Muon optimizer. (#355, #356)

docs/index.md

Lines changed: 3 additions & 1 deletion

@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **99 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **100 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -207,6 +207,8 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | GCSAM | *Gradient Centralized Sharpness Aware Minimization* | [github](https://github.com/mhassann22/GCSAM) | <https://arxiv.org/abs/2501.11584> | [cite](https://github.com/mhassann22/GCSAM?tab=readme-ov-file#citation) |
 | LookSAM | *Towards Efficient and Scalable Sharpness-Aware Minimization* | [github](https://github.com/rollovd/LookSAM) | <https://arxiv.org/abs/2203.02714> | [cite](https://ui.adsabs.harvard.edu/abs/2022arXiv220302714L/exportcitation) |
 | SCION | *Training Deep Learning Models with Norm-Constrained LMOs* | | <https://arxiv.org/abs/2502.07529> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250207529P/exportcitation) |
+| COSMOS | *SOAP with Muon* | [github](https://github.com/lliu606/COSMOS) | | |
+| StableSPAM | *How to Train in 4-Bit More Stably than 16-Bit Adam* | [github](https://github.com/TianjinYellow/StableSPAM) | <https://arxiv.org/abs/2502.17055> | |
 
 ## Supported LR Scheduler

docs/optimizer.md

Lines changed: 4 additions & 0 deletions

@@ -392,6 +392,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.StableSPAM
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.SRMM
     :docstring:
     :members:

pyproject.toml

Lines changed: 3 additions & 3 deletions

@@ -19,9 +19,9 @@ keywords = [
     "Muno", "Nero", "NovoGrad", "OrthoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "PSGD", "QHAdam", "QHM",
     "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "GCSAM", "LookSAM", "ScheduleFreeSGD", "ScheduleFreeAdamW",
     "ScheduleFreeRAdam", "SCION", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH",
-    "SPAM", "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
-    "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
-    "bitsandbytes", "WSD", "QGaLore",
+    "SPAM", "StableSPAM", "SRMM", "StableAdamW", "SWATS", "TAM", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal",
+    "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky",
+    "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -144,6 +144,7 @@
     SignSGD,
     SophiaH,
     StableAdamW,
+    StableSPAM,
     Tiger,
     Yogi,
     agc,

pytorch_optimizer/optimizer/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -88,7 +88,7 @@
 from pytorch_optimizer.optimizer.sm3 import SM3
 from pytorch_optimizer.optimizer.soap import SOAP
 from pytorch_optimizer.optimizer.sophia import SophiaH
-from pytorch_optimizer.optimizer.spam import SPAM
+from pytorch_optimizer.optimizer.spam import SPAM, StableSPAM
 from pytorch_optimizer.optimizer.srmm import SRMM
 from pytorch_optimizer.optimizer.swats import SWATS
 from pytorch_optimizer.optimizer.tam import TAM, AdaTAM
@@ -302,6 +302,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     Kron,
     EXAdam,
     SCION,
+    StableSPAM,
     Ranger25,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
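Since `OPTIMIZERS` is keyed by the lowercased class name, the registration above also makes the new optimizer loadable by string. A minimal sketch using only the `load_optimizer` function named in the hunk header:

```python
from pytorch_optimizer import load_optimizer

optimizer_class = load_optimizer('stablespam')  # resolved via the OPTIMIZERS mapping above
print(optimizer_class.__name__)                 # 'StableSPAM'
```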

pytorch_optimizer/optimizer/spam.py

Lines changed: 161 additions & 2 deletions

@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 import torch
 from torch.nn import Parameter, ParameterList
@@ -22,7 +23,7 @@ class CosineDecay:
     def __init__(self, death_rate: float, t_max: int, eta_min: float = 0.0, last_epoch: int = -1):
         self.sgd: Optimizer = SGD(ParameterList([Parameter(torch.zeros(1))]), lr=death_rate)
         self.cosine_stepper: LRScheduler = CosineAnnealingLR(self.sgd, t_max + 1, eta_min, last_epoch)
-        self.T_max = t_max
+        self.t_max = t_max
         self.eta_min = eta_min
 
     def step(self, current_step: int) -> None:
@@ -37,7 +38,7 @@ def get_death_rate(self, current_step: int) -> float:
 
         :param current_step: int. Current step index.
         """
-        if current_step >= self.T_max:
+        if current_step >= self.t_max:
            return self.eta_min
 
        self.step(current_step)
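The renamed `t_max` attribute drives the warm-up scale that the new `StableSPAM` class (added in the hunk below) multiplies into `beta1` and `gamma1`. Because `CosineDecay` simply steps a `CosineAnnealingLR` over a dummy SGD parameter, its output follows the standard cosine-annealing closed form; the sketch below is an illustrative approximation of that behaviour, not part of the commit:

```python
import math

def cosine_death_rate(step: int, death_rate: float, t_max: int, eta_min: float = 0.0) -> float:
    # Closed form of CosineAnnealingLR(T_max=t_max + 1), which CosineDecay wraps internally.
    if step >= t_max:
        return eta_min
    return eta_min + 0.5 * (death_rate - eta_min) * (1.0 + math.cos(math.pi * step / (t_max + 1)))

# StableSPAM constructs CosineDecay(1.0, t_max, eta_min=0.5) by default, so the scale
# decays smoothly from 1.0 toward 0.5 over the course of training.
```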
@@ -267,3 +268,161 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 self.warmup = CosineDecay(0.99, self.warmup_epoch)
 
         return loss
+
+
+class StableSPAM(BaseOptimizer):
+    r"""How to Train in 4-Bit More Stably than 16-Bit Adam.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param gamma1: float.
+    :param gamma2: float.
+    :param theta: float.
+    :param t_max: Optional[int]. total number of steps.
+    :param eta_min: float. eta_min of CosineDecay.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param update_proj_gap: int. update projection gap.
+    :param eps: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        betas: BETAS = (0.9, 0.999),
+        gamma1: float = 0.7,
+        gamma2: float = 0.9,
+        theta: float = 0.999,
+        t_max: Optional[int] = None,
+        eta_min: float = 0.5,
+        weight_decay: float = 0.0,
+        update_proj_gap: int = 1000,
+        eps: float = 1e-8,
+        **kwargs,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_betas(betas)
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_positive(update_proj_gap, 'update_proj_gap')
+        self.validate_non_negative(eps, 'eps')
+
+        self.gamma1: float = betas[0] if gamma1 == -1.0 else gamma1
+        self.gamma2: float = gamma2
+        self.theta: float = theta
+        self.t_max = t_max
+        self.update_proj_gap = update_proj_gap
+        self.warmup = CosineDecay(1.0, t_max, eta_min=eta_min) if t_max is not None else None
+
+        self.total_step: int = 0
+
+        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}
+        super().__init__(params, defaults)
+
+    def __str__(self) -> str:
+        return 'StableSPAM'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                state = self.state[p]
+
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
+                state['m_norm_t'] = torch.zeros(1, device=p.device, dtype=p.dtype)
+                state['v_norm_t'] = torch.zeros(1, device=p.device, dtype=p.dtype)
+                state['m_max_t'] = torch.zeros(1, device=p.device, dtype=p.dtype)
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        self.total_step += 1
+        scale: float = self.warmup.get_death_rate(self.total_step) if self.warmup is not None else 1.0
+
+        for group in self.param_groups:
+            if 'step' not in group:
+                group['step'] = 1
+            else:
+                group['step'] += 1
+
+            beta1, beta2 = group['betas']
+            beta1 *= scale
+
+            bias_correction1: float = self.debias(beta1, group['step'])
+            bias_correction2: float = self.debias(beta2, group['step'])
+            bias_correction2_sq: float = math.sqrt(bias_correction2)
+
+            step_size: float = group['lr'] / bias_correction1
+
+            theta_t: float = 1.0 - self.theta ** group['step']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(str(self))
+
+                state = self.state[p]
+
+                if 'exp_avg' not in state:
+                    state['exp_avg'] = torch.zeros_like(grad)
+                    state['exp_avg_sq'] = torch.zeros_like(grad)
+                    state['m_norm_t'] = torch.zeros(1, device=grad.device, dtype=grad.dtype)
+                    state['v_norm_t'] = torch.zeros(1, device=grad.device, dtype=grad.dtype)
+                    state['m_max_t'] = torch.zeros(1, device=grad.device, dtype=grad.dtype)
+
+                self.apply_weight_decay(
+                    p,
+                    grad=grad,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=True,
+                    fixed_decay=False,
+                )
+
+                max_grad = torch.max(grad.abs())
+
+                exp_avg, exp_avg_sq, m_max_t = state['exp_avg'], state['exp_avg_sq'], state['m_max_t']
+
+                m_max_t.lerp_(max_grad, weight=1.0 - self.theta)
+
+                m_max_hat = m_max_t / theta_t
+
+                mask = grad.abs() > m_max_hat
+                if mask.sum() > 0:
+                    grad[mask] = grad[mask].div(max_grad).mul(m_max_hat)
+
+                grad_norm = torch.norm(grad)
+
+                m_norm_t, v_norm_t = state['m_norm_t'], state['v_norm_t']
+                m_norm_t.lerp_(grad_norm, weight=1.0 - self.gamma1 * scale)
+                v_norm_t.lerp_(grad_norm.pow(2), weight=1.0 - self.gamma2)
+
+                m_norm_hat = m_norm_t / (1.0 - (self.gamma1 * scale) ** group['step'])
+                v_norm_hat = v_norm_t / (1.0 - self.gamma2 ** group['step'])
+
+                c_norm_t = m_norm_hat.div_(v_norm_hat.sqrt_().add_(group['eps']))
+
+                grad.div_(grad_norm).mul_(c_norm_t)
+
+                if self.update_proj_gap > 0 and self.total_step % self.update_proj_gap == 0:
+                    state['exp_avg'] = torch.zeros_like(grad)
+                    state['exp_avg_sq'] = torch.zeros_like(grad)
+                    group['step'] = 1
+
+                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+
+                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
+
+                p.addcdiv_(exp_avg, de_nom, value=-step_size)
+
+        return loss
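A minimal usage sketch of the new class. Only the `StableSPAM` signature comes from the code above; the model, data, and hyper-parameter values are illustrative:

```python
# Sketch: fit a toy linear regression with StableSPAM for a few steps.
import torch
from pytorch_optimizer import StableSPAM

model = torch.nn.Linear(4, 1)
optimizer = StableSPAM(model.parameters(), lr=1e-3, gamma1=0.7, gamma2=0.9, theta=0.999, t_max=1000)

x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```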

tests/constants.py

Lines changed: 2 additions & 0 deletions

@@ -88,6 +88,7 @@
     SignSGD,
     SophiaH,
     StableAdamW,
+    StableSPAM,
     Tiger,
     Yogi,
 )
@@ -557,6 +558,7 @@
     (SGDSaI, {'lr': 1e0, 'momentum': 0.0}, 15),
     (Grams, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
     (SPAM, {'lr': 1e0, 'weight_decay': 1e-3, 'warmup_epoch': 1, 'grad_accu_steps': 1, 'update_proj_gap': 1}, 5),
+    (StableSPAM, {'lr': 1e0, 'weight_decay': 1e-3, 'update_proj_gap': 1, 't_max': 5}, 5),
     (TAM, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (AdaTAM, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
     (FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
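The new tuple follows the table's `(optimizer_class, config, num_iterations)` convention. A rough sketch of how such an entry could be exercised by hand on a toy quadratic; `run_entry` is a hypothetical helper for illustration, not the repository's actual test fixture:

```python
# Illustrative only: run the configured optimizer for the listed number of iterations.
import torch
from pytorch_optimizer import StableSPAM

def run_entry(optimizer_class, config, iterations):
    w = torch.nn.Parameter(torch.ones(2))
    optimizer = optimizer_class([w], **config)
    for _ in range(iterations):
        optimizer.zero_grad()
        loss = (w - torch.tensor([1.0, -2.0])).pow(2).sum()
        loss.backward()
        optimizer.step()
    return loss.item()

print(run_entry(StableSPAM, {'lr': 1e0, 'weight_decay': 1e-3, 'update_proj_gap': 1, 't_max': 5}, 5))
```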

tests/test_load_modules.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 96
+    assert len(get_supported_optimizers()) == 97
     assert len(get_supported_optimizers('adam*')) == 8
     assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 11
