
Commit 030d303

[Feature] Implement AdamWSN optimizer (#389)
* docs: v3.6.1 changelog
* feature: AdamWSN optimizer
* update: AdamWSN optimizer
* update: recipe
* docs: README
* docs: AdamWSN optimizer
* update: AdamWSN optimizer
* update: test_csd
* update: no cover
* style: skip
* fix: size to ndim
* update: AdamWSN optimizer
* fix: test_get_supported_optimizers
* docs: README
1 parent 02fc0af commit 030d303

File tree

11 files changed: +435, -223 lines changed


README.md

Lines changed: 111 additions & 110 deletions (large diff not rendered)

docs/changelogs/v3.6.1.md

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@
 ## Feature
 
 * Implement more cooldown types for WSD learning rate scheduler. (#382, #386)
+* Implement `AdamWSN` optimizer. (#387, #389)
+  * [Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees](https://arxiv.org/abs/2411.07120)
 
 ### Fix

docs/index.md

Lines changed: 111 additions & 110 deletions (large diff not rendered)

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
@@ -72,6 +72,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.AdamWSN
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.Adan
     :docstring:
     :members:

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@
     AdamP,
     AdamS,
     AdamW,
+    AdamWSN,
     Adan,
     AdaNorm,
     AdaPNM,

pytorch_optimizer/optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -93,6 +93,7 @@
 from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, VSGD, AccSGD, SGDSaI, SignSGD
 from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
 from pytorch_optimizer.optimizer.sm3 import SM3
+from pytorch_optimizer.optimizer.snsm import AdamWSN
 from pytorch_optimizer.optimizer.soap import SOAP
 from pytorch_optimizer.optimizer.sophia import SophiaH
 from pytorch_optimizer.optimizer.spam import SPAM, StableSPAM
@@ -219,6 +220,7 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     SGD,
     AdaBelief,
     AdaBound,
+    AdamWSN,
     PID,
     AdamP,
     Adai,
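With the registration above, the new optimizer can also be constructed by name. A minimal sketch, assuming the string lookup behaves for 'adamwsn' as it does for the other registered optimizers; the toy model is illustrative, not from the repository:

from torch import nn

from pytorch_optimizer import load_optimizer

model = nn.Linear(10, 2)  # illustrative model

# resolve the class from the registry updated in this commit, then construct it
optimizer_class = load_optimizer(optimizer='adamwsn')
optimizer = optimizer_class(model.parameters(), lr=1e-3)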
pytorch_optimizer/optimizer/snsm.py

Lines changed: 182 additions & 0 deletions (new file)

import math

import torch

from pytorch_optimizer.base.exception import NoComplexParameterError, NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS


def closest_smaller_divisor_of_n_to_k(n: int, k: int) -> int:
    r"""Get the closest divisor of n that is smaller than or equal to k."""
    if n % k == 0:
        return k

    if n <= 1 or k <= 1:
        raise ValueError

    for i in range(k, 0, -1):
        if n % i == 0:
            return i
    return -1  # pragma: no cover


class AdamWSN(BaseOptimizer):
    r"""Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees.

    .. code-block:: python

        sn_params = [module.weight for module in model.modules() if isinstance(module, nn.Linear)]
        sn_param_ids = [id(p) for p in sn_params]
        regular_params = [p for p in model.parameters() if id(p) not in sn_param_ids]
        param_groups = [{'params': regular_params, 'sn': False}, {'params': sn_params, 'sn': True}]
        optimizer = AdamWSN(param_groups, lr=args.lr, weight_decay=args.weight_decay, subset_size=args.subset_size)

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param subset_size: int. If you do not know what subset_size to set, a good rule of thumb is d / 2, where d is
        the hidden dimension of your transformer model. For example, the hidden dimension of Llama 7B is 4096, so a
        good subset_size would be 2048. You can leave subset_size at its default value of -1 to use the recommended
        subset size as stated above.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        subset_size: int = -1,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'subset_size': subset_size,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamWSN'

    def init_group(self, group: GROUP, **kwargs) -> None:
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(grad)

                if group.get('sn'):
                    size: int = grad.numel()

                    if 'subset_size' not in state:
                        state['subset_size'] = closest_smaller_divisor_of_n_to_k(
                            size,
                            (
                                group['subset_size']
                                if group['subset_size'] > 0
                                else int(math.sqrt(size) / abs(int(group['subset_size'])))
                            ),
                        )

                    reshaped_grad = grad.view(size // state['subset_size'], state['subset_size'])
                    second_moment_update = torch.sum(reshaped_grad ** 2, dim=1, keepdim=True)  # fmt: skip
                    state['exp_avg_sq'] = torch.zeros_like(second_moment_update)
                else:
                    state['exp_avg_sq'] = torch.zeros_like(grad)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                size = grad.numel()

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                if group.get('sn'):
                    reshaped_grad = grad.view(size // state['subset_size'], state['subset_size'])
                    second_moment_update = torch.sum(reshaped_grad ** 2, dim=1, keepdim=True)  # fmt: skip
                else:
                    second_moment_update = grad.pow(2)

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).add_(second_moment_update, alpha=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                if group.get('sn'):
                    numerator = exp_avg.view(size // state['subset_size'], state['subset_size'])
                    norm_grad = (numerator / de_nom).reshape(p.shape)
                    p.add_(norm_grad, alpha=-step_size)
                else:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

        return loss
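For intuition, here is a small worked sketch (illustrative, not part of the commit) of the subset-norm bookkeeping above, using a toy 4x6 gradient; the helper is copied from snsm.py so the example is self-contained:

import math

import torch


def closest_smaller_divisor_of_n_to_k(n: int, k: int) -> int:
    # same helper as in snsm.py above
    if n % k == 0:
        return k
    if n <= 1 or k <= 1:
        raise ValueError
    for i in range(k, 0, -1):
        if n % i == 0:
            return i
    return -1


grad = torch.randn(4, 6)  # toy gradient with 24 elements
size = grad.numel()

# default subset_size=-1 targets int(sqrt(24) / 1) == 4, which divides 24 evenly
subset_size = closest_smaller_divisor_of_n_to_k(size, int(math.sqrt(size)))
assert subset_size == 4

reshaped = grad.view(size // subset_size, subset_size)    # shape (6, 4)
second_moment = reshaped.pow(2).sum(dim=1, keepdim=True)  # shape (6, 1)

# exp_avg_sq stores 6 accumulators instead of 24: one per subset of 4
# coordinates, which is where the memory saving of the subset-norm comes from.
assert second_moment.shape == (6, 1)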

tests/constants.py

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,7 @@
     AdaMod,
     AdamP,
     AdamS,
+    AdamWSN,
     Adan,
     AdaNorm,
     AdaPNM,
@@ -647,6 +648,7 @@
     (RACS, {'lr': 1e0}, 5),
     (Alice, {'lr': 1e0, 'rank': 2, 'leading_basis': 1, 'update_interval': 2}, 5),
     (VSGD, {'lr': 1e0}, 5),
+    (AdamWSN, {'lr': 1e0}, 5),
     (Ranger25, {'lr': 1e-1}, 3),
     (Ranger25, {'lr': 1e-1, 't_alpha_beta3': 5}, 3),
     (Ranger25, {'lr': 5e-2, 'stable_adamw': False, 'orthograd': False, 'eps': None, 'lookahead_merge_time': 2}, 3),

tests/test_load_modules.py

Lines changed: 3 additions & 3 deletions
@@ -34,9 +34,9 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 104
-    assert len(get_supported_optimizers('adam*')) == 8
-    assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 11
+    assert len(get_supported_optimizers()) == 105
+    assert len(get_supported_optimizers('adam*')) == 9
+    assert len(get_supported_optimizers(['adam*', 'ranger*'])) == 12
 
 
 def test_get_supported_lr_schedulers():
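For reference, a small sketch of the lookups this test exercises; the expected counts come directly from the updated assertions:

from pytorch_optimizer import get_supported_optimizers

print(len(get_supported_optimizers()))                      # 105 after this commit
print(len(get_supported_optimizers('adam*')))               # 9, now including AdamWSN
print(len(get_supported_optimizers(['adam*', 'ranger*'])))  # 12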

tests/test_optimizers.py

Lines changed: 8 additions & 0 deletions
@@ -50,6 +50,10 @@ def _closure() -> float:
         adamw_params = [p for i, p in enumerate(parameters) if i >= 2]
         parameters = [p for i, p in enumerate(parameters) if i < 2]
         config.update({'adamw_params': adamw_params})
+    if optimizer_name == 'AdamWSN':
+        sn_params = [p for p in parameters if p.ndim == 2]
+        regular_params = [p for p in parameters if p.ndim != 2]
+        parameters = [{'params': sn_params, 'sn': True}, {'params': regular_params, 'sn': False}]
 
     optimizer = optimizer_class(parameters, **config)
 
@@ -97,6 +101,10 @@ def _closure() -> float:
         adamw_params = [p for i, p in enumerate(parameters) if i >= 2]
         parameters = [p for i, p in enumerate(parameters) if i < 2]
         config.update({'adamw_params': adamw_params})
+    if optimizer_name == 'AdamWSN':
+        sn_params = [p for p in parameters if p.ndim == 2]
+        regular_params = [p for p in parameters if p.ndim != 2]
+        parameters = [{'params': sn_params, 'sn': True}, {'params': regular_params, 'sn': False}]
 
     optimizer = optimizer_class(parameters, **config)
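The same grouping pattern applies outside the test suite. A minimal end-to-end sketch following the convention shown in the docstring and the tests (2-D weights go to the 'sn' group); the toy model and shapes are illustrative, not from the repository:

import torch
from torch import nn

from pytorch_optimizer import AdamWSN

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# matrices get the subset-norm second moment, biases keep the dense one
sn_params = [p for p in model.parameters() if p.ndim == 2]
regular_params = [p for p in model.parameters() if p.ndim != 2]

optimizer = AdamWSN(
    [{'params': sn_params, 'sn': True}, {'params': regular_params, 'sn': False}],
    lr=1e-3,
    weight_decay=1e-2,
)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()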
