Commit 2a4423d

[Feature] Implement AdamC optimizer (#390)
* feature: implement AdamC optimizer
* update: AdamC optimizer
* docs: AdamC optimizer
* docs: AdamC optimizer
* update: recipe
* docs: v3.6.1 changelog
* docs: README
* update: test_get_supported_optimizers
* update: recipe
1 parent 030d303 commit 2a4423d

10 files changed: +180 -33 lines changed

README.md

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **107 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **108 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -215,6 +215,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | RACS & Alice | *Towards Efficient Optimizer Design for LLM via Structured Fisher Approximation with a Low-Rank Extension* | | <https://arxiv.org/pdf/2502.07752> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250207752G/exportcitation) |
 | VSGD | *Variational Stochastic Gradient Descent for Deep Neural Networks* | [github](https://github.com/generativeai-tue/vsgd) | <https://openreview.net/forum?id=xu4ATNjcdy> | [cite](https://github.com/generativeai-tue/vsgd/tree/main?tab=readme-ov-file#cite) |
 | SNSM | *Subset-Norm and Subspace-Momentum: Faster Memory-Efficient Adaptive Optimization with Convergence Guarantees* | [github](https://github.com/timmytonga/sn-sm) | <https://arxiv.org/abs/2411.07120> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241107120N/exportcitation) |
+| AdamC | *Why Gradients Rapidly Increase Near the End of Training* | | <https://arxiv.org/abs/2506.02285> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250602285D/exportcitation) |
 
 ## Supported LR Scheduler

docs/changelogs/v3.6.1.md

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@
 * Implement more cooldown types for WSD learning rate scheduler. (#382, #386)
 * Implement `AdamWSN` optimizer. (#387, #389)
     * [Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees](https://arxiv.org/abs/2411.07120)
+* Implement `AdamC` optimizer. (#388, #390)
+    * [Why Gradients Rapidly Increase Near the End of Training](https://arxiv.org/abs/2506.02285)
 
 ### Fix

docs/index.md

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 ## The reasons why you use `pytorch-optimizer`.
 
-* Wide range of supported optimizers. Currently, **107 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+* Wide range of supported optimizers. Currently, **108 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
 * Including many variants such as `ADOPT`, `Cautious`, `AdamD`, `StableAdamW`, and `Gradient Centralization`
 * Easy to use, clean, and tested codes
 * Active maintenance
@@ -215,6 +215,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | RACS & Alice | *Towards Efficient Optimizer Design for LLM via Structured Fisher Approximation with a Low-Rank Extension* | | <https://arxiv.org/pdf/2502.07752> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250207752G/exportcitation) |
 | VSGD | *Variational Stochastic Gradient Descent for Deep Neural Networks* | [github](https://github.com/generativeai-tue/vsgd) | <https://openreview.net/forum?id=xu4ATNjcdy> | [cite](https://github.com/generativeai-tue/vsgd/tree/main?tab=readme-ov-file#cite) |
 | SNSM | *Subset-Norm and Subspace-Momentum: Faster Memory-Efficient Adaptive Optimization with Convergence Guarantees* | [github](https://github.com/timmytonga/sn-sm) | <https://arxiv.org/abs/2411.07120> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241107120N/exportcitation) |
+| AdamC | *Why Gradients Rapidly Increase Near the End of Training* | | <https://arxiv.org/abs/2506.02285> | [cite](https://ui.adsabs.harvard.edu/abs/2025arXiv250602285D/exportcitation) |
 
 ## Supported LR Scheduler

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
@@ -56,6 +56,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.AdamC
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AdamG
     :docstring:
     :members:

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@
     Adalite,
     AdaLOMO,
     AdaMax,
+    AdamC,
     AdamG,
     AdamMini,
     AdaMod,

pytorch_optimizer/optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
 from pytorch_optimizer.optimizer.adalite import Adalite
 from pytorch_optimizer.optimizer.adam_mini import AdamMini
 from pytorch_optimizer.optimizer.adamax import AdaMax
+from pytorch_optimizer.optimizer.adamc import AdamC
 from pytorch_optimizer.optimizer.adamg import AdamG
 from pytorch_optimizer.optimizer.adamod import AdaMod
 from pytorch_optimizer.optimizer.adamp import SGDP, AdamP
@@ -221,6 +222,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     AdaBelief,
     AdaBound,
     AdamWSN,
+    AdamC,
     PID,
     AdamP,
     Adai,
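
For context, a minimal sketch (not part of this commit) of how the newly registered class could be resolved by name through the `load_optimizer` helper shown in the hunk above; the lowercase lookup key and the toy parameter are assumptions for illustration:

```python
import torch

from pytorch_optimizer import load_optimizer

# Assumption: load_optimizer resolves optimizer names case-insensitively,
# so the new class can be fetched by its string name and instantiated as usual.
param = torch.nn.Parameter(torch.zeros(4, 4))
optimizer_class = load_optimizer('adamc')
optimizer = optimizer_class([param], lr=1e-3, weight_decay=1e-2)
```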
pytorch_optimizer/optimizer/adamc.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
+import math
+
+import torch
+
+from pytorch_optimizer.base.exception import NoComplexParameterError, NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS
+
+
+class AdamC(BaseOptimizer):
+    r"""Why Gradients Rapidly Increase Near the End of Training.
+
+    Set `normalized=True` for LayerNorm and BatchNorm layers.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+    :param fixed_decay: bool. fix weight decay.
+    :param ams_bound: bool. whether to use the AMSBound variant.
+    :param eps: float. term added to the denominator to improve numerical stability.
+    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        betas: BETAS = (0.9, 0.999),
+        weight_decay: float = 0.0,
+        weight_decouple: bool = True,
+        fixed_decay: bool = False,
+        ams_bound: bool = False,
+        eps: float = 1e-8,
+        maximize: bool = False,
+        **kwargs,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_betas(betas)
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_non_negative(eps, 'eps')
+
+        self.maximize = maximize
+        self.max_lr: float = lr
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'betas': betas,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+            'fixed_decay': fixed_decay,
+            'ams_bound': ams_bound,
+            'eps': eps,
+            **kwargs,
+        }
+
+        super().__init__(params, defaults)
+
+    def __str__(self) -> str:
+        return 'AdamC'
+
+    def init_group(self, group: GROUP, **kwargs) -> None:
+        for p in group['params']:
+            if p.grad is None:
+                continue
+
+            grad = p.grad
+            if grad.is_sparse:
+                raise NoSparseGradientError(str(self))
+
+            if torch.is_complex(p):
+                raise NoComplexParameterError(str(self))
+
+            state = self.state[p]
+
+            if len(state) == 0:
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
+
+                if group['ams_bound']:
+                    state['max_exp_avg_sq'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if 'step' not in group:
+                self.init_group(group)
+                group['step'] = 1
+            else:
+                group['step'] += 1
+
+            beta1, beta2 = group['betas']
+
+            bias_correction1: float = self.debias(beta1, group['step'])
+            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
+
+            wd_step_size: float = group['lr'] if not group.get('normalized') else (group['lr'] ** 2) / self.max_lr
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+
+                self.maximize_gradient(grad, maximize=self.maximize)
+
+                state = self.state[p]
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                self.apply_weight_decay(
+                    p=p,
+                    grad=grad,
+                    lr=wd_step_size,
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=group['weight_decouple'],
+                    fixed_decay=group['fixed_decay'],
+                )
+
+                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+
+                de_nom = self.apply_ams_bound(
+                    ams_bound=group['ams_bound'],
+                    exp_avg_sq=exp_avg_sq,
+                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
+                    eps=group['eps'],
+                )
+                de_nom.div_(bias_correction2_sq)
+
+                p.addcdiv_(exp_avg / bias_correction1, de_nom, value=-group['lr'])
+
+        return loss
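
For reference, a minimal usage sketch (not from this commit) of the parameter-group pattern the docstring and tests suggest: parameters of LayerNorm/BatchNorm layers go into a group with `normalized=True`, so their decoupled weight decay uses the corrected step size `lr ** 2 / max_lr` instead of `lr` (see `wd_step_size` above). The model and the way the split is computed here are illustrative assumptions.

```python
import torch
from torch import nn
from torch.nn import functional as F

from pytorch_optimizer import AdamC

# Illustrative model; the layer choice is an assumption for this sketch.
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 2))

# Put normalization-layer parameters in a group with `normalized=True`,
# which switches their weight-decay step size from lr to lr^2 / max_lr.
norm_params = [p for m in model.modules() if isinstance(m, nn.LayerNorm) for p in m.parameters()]
norm_ids = {id(p) for p in norm_params}
other_params = [p for p in model.parameters() if id(p) not in norm_ids]

optimizer = AdamC(
    [{'params': norm_params, 'normalized': True}, {'params': other_params}],
    lr=1e-3,
    weight_decay=1e-2,
)

# One standard training step.
x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```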

tests/constants.py

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@
     Adai,
     Adalite,
     AdaMax,
+    AdamC,
     AdamG,
     AdaMod,
     AdamP,
@@ -649,6 +650,8 @@
     (Alice, {'lr': 1e0, 'rank': 2, 'leading_basis': 1, 'update_interval': 2}, 5),
     (VSGD, {'lr': 1e0}, 5),
     (AdamWSN, {'lr': 1e0}, 5),
+    (AdamC, {'lr': 1e0}, 5),
+    (AdamC, {'lr': 1e0, 'ams_bound': True}, 5),
     (Ranger25, {'lr': 1e-1}, 3),
     (Ranger25, {'lr': 1e-1, 't_alpha_beta3': 5}, 3),
     (Ranger25, {'lr': 5e-2, 'stable_adamw': False, 'orthograd': False, 'eps': None, 'lookahead_merge_time': 2}, 3),

tests/test_load_modules.py

Lines changed: 3 additions & 3 deletions
@@ -34,9 +34,9 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 105
-    assert len(get_supported_optimizers('adam*')) == 9
-    assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 12
+    assert len(get_supported_optimizers()) == 106
+    assert len(get_supported_optimizers('adam*')) == 10
+    assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 13
 
 
 def test_get_supported_lr_schedulers():

tests/test_optimizers.py

Lines changed: 22 additions & 28 deletions
@@ -26,6 +26,25 @@
 )
 
 
+def build_optimizer_parameter(parameters, optimizer_name, config):
+    if optimizer_name == 'AliG':
+        config.update({'projection_fn': lambda: l2_projection(parameters, max_norm=1)})
+    if optimizer_name == 'Muon':
+        adamw_params = [p for i, p in enumerate(parameters) if i >= 2]
+        parameters = [p for i, p in enumerate(parameters) if i < 2]
+        config.update({'adamw_params': adamw_params})
+    if optimizer_name == 'AdamWSN':
+        sn_params = [p for p in parameters if p.ndim == 2]
+        regular_params = [p for p in parameters if p.ndim != 2]
+        parameters = [{'params': sn_params, 'sn': True}, {'params': regular_params, 'sn': False}]
+    if optimizer_name == 'AdamC':
+        norm_params = [p for i, p in enumerate(parameters) if i == 1]
+        regular_params = [p for i, p in enumerate(parameters) if i != 1]
+        parameters = [{'params': norm_params, 'normalized': True}, {'params': regular_params}]
+
+    return parameters, config
+
+
 @pytest.mark.parametrize('optimizer_fp32_config', OPTIMIZERS, ids=ids)
 def test_f32_optimizers(optimizer_fp32_config, environment):
     def closure(x):
@@ -42,18 +61,7 @@ def _closure() -> float:
     x_data, y_data = environment
     model, loss_fn = build_model()
 
-    parameters = list(model.parameters())
-
-    if optimizer_name == 'AliG':
-        config.update({'projection_fn': lambda: l2_projection(parameters, max_norm=1)})
-    if optimizer_name == 'Muon':
-        adamw_params = [p for i, p in enumerate(parameters) if i >= 2]
-        parameters = [p for i, p in enumerate(parameters) if i < 2]
-        config.update({'adamw_params': adamw_params})
-    if optimizer_name == 'AdamWSN':
-        sn_params = [p for p in parameters if p.ndim == 2]
-        regular_params = [p for p in parameters if p.ndim != 2]
-        parameters = [{'params': sn_params, 'sn': True}, {'params': regular_params, 'sn': False}]
+    parameters, config = build_optimizer_parameter(list(model.parameters()), optimizer_name, config)
 
     optimizer = optimizer_class(parameters, **config)
 
@@ -93,18 +101,7 @@ def _closure() -> float:
     model, loss_fn = build_model()
     model = model.bfloat16()
 
-    parameters = list(model.parameters())
-
-    if optimizer_name == 'AliG':
-        config.update({'projection_fn': lambda: l2_projection(parameters, max_norm=1)})
-    elif optimizer_name == 'Muon':
-        adamw_params = [p for i, p in enumerate(parameters) if i >= 2]
-        parameters = [p for i, p in enumerate(parameters) if i < 2]
-        config.update({'adamw_params': adamw_params})
-    if optimizer_name == 'AdamWSN':
-        sn_params = [p for p in parameters if p.ndim == 2]
-        regular_params = [p for p in parameters if p.ndim != 2]
-        parameters = [{'params': sn_params, 'sn': True}, {'params': regular_params, 'sn': False}]
+    parameters, config = build_optimizer_parameter(list(model.parameters()), optimizer_name, config)
 
     optimizer = optimizer_class(parameters, **config)
 
@@ -150,10 +147,7 @@ def _closure() -> float:
 
     x_data = x_data.to(torch.complex64)
 
-    parameters = list(model.parameters())
-
-    if optimizer_name == 'alig':
-        config.update({'projection_fn': lambda: l2_projection(parameters, max_norm=1)})
+    parameters, config = build_optimizer_parameter(list(model.parameters()), optimizer_name, config)
 
     optimizer = optimizer_class(parameters, **config)
