
Commit b85065d

Merge pull request #356 from kozistr/update/muon-optimizer
[Update] Muon optimizer
2 parents 3d20627 + 2cf72e5 commit b85065d

9 files changed: +65 -43 lines changed


docs/changelogs/v3.4.3.md

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,11 @@
 ### Change Log

+### Update
+
+* Update Muon optimizer. (#355, #356)
+    * support decoupled weight decay.
+    * adjust the default hyperparameters to match the original implementation.
+    * support the adjusted lr from Moonlight; enable it by setting `use_adjusted_lr=True`.

 ### Fix
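For reference, a minimal usage sketch of the options described above, assuming `Muon` is importable from the package root (as in the library's tests); the toy model, data, and chosen values are illustrative only:

import torch
from torch import nn

from pytorch_optimizer import Muon

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

optimizer = Muon(
    model.parameters(),
    lr=2e-2,               # new default, matching the original implementation
    weight_decay=1e-2,     # decoupled (AdamW-style), since weight_decouple=True by default
    use_adjusted_lr=True,  # Moonlight-style lr adjustment, disabled by default
)

x, y = torch.randn(32, 8), torch.randn(32, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()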

docs/visualization.md

Lines changed: 0 additions & 8 deletions
@@ -234,10 +234,6 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_MSVAG.png)

-### Muon
-
-![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Muon.png)
-
 ### Nero

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_Nero.png)

@@ -604,10 +600,6 @@

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_MSVAG.png)

-### Muon
-
-![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Muon.png)
-
 ### Nero

 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_Nero.png)
Two binary files removed (-632 KB, -132 KB): not shown.

examples/visualize_optimizers.py

Lines changed: 1 addition & 5 deletions
@@ -16,7 +16,7 @@

 filterwarnings('ignore', category=UserWarning)

-OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad')
+OPTIMIZERS_IGNORE = ('lomo', 'adalomo', 'demo', 'a2grad', 'muon')
 OPTIMIZERS_MODEL_INPUT_NEEDED = ('lomo', 'adalomo', 'adammini')
 OPTIMIZERS_GRAPH_NEEDED = ('adahessian', 'sophiah')
 OPTIMIZERS_CLOSURE_NEEDED = ('alig', 'bsam')

@@ -93,10 +93,6 @@
         'lr': hp.uniform('lr', 0, 0.8),
         'momentum': hp.quniform('momentum', 0, 0.99, 0.01),
     },
-    'muon': {
-        'lr': hp.uniform('lr', 0, 0.8),
-        'momentum': hp.quniform('momentum', 0, 0.99, 0.01),
-    },
 }

pytorch_optimizer/optimizer/muon.py

Lines changed: 53 additions & 25 deletions
@@ -1,3 +1,4 @@
+import math
 import os
 from typing import List, Optional

@@ -11,45 +12,57 @@


 class Muon(BaseOptimizer):
-    r"""MomentUm Orthogonalized by Newton-schulz.
+    r"""Momentum Orthogonalized by Newton-schulz.

     Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which
     each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each
     update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

+    Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and
+    scalar or vector parameters should be optimized using AdamW.
+
     Some warnings:
     - We believe this optimizer is unlikely to work well for training with small batch size.
     - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

     :param params: PARAMETERS. the parameters to be optimized by Muon.
     :param lr: float. learning rate.
     :param momentum: float. the momentum used by the internal SGD.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param betas: The betas for the internal AdamW.
     :param nesterov: bool. whether to use nesterov momentum.
-    :param ns_steps: int. the number of Newton-Schulz iterations to run. (6 is probably always enough)
-    :param adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or
-        are detected as being the embed or lm_head will be optimized by AdamW as well.
-    :param adamw_lr: The learning rate for the internal AdamW.
-    :param adamw_wd: The weight decay for the internal AdamW.
-    :param adamw_eps: The epsilon for the internal AdamW.
+    :param ns_steps: int. the number of Newton-Schulz iterations to run. (5 is probably always enough)
+    :param use_adjusted_lr: bool. whether to use the adjusted learning rate, which comes from Moonlight.
+        reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+    :param adamw_params: Optional[PARAMETERS]. The parameters to be optimized by AdamW. Any parameters in `muon_params`
+        which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. It'd be
+        better to create a separate AdamW optimizer instead of using this.
+    :param adamw_lr: float. The learning rate for the internal AdamW.
+    :param adamw_wd: float. The weight decay for the internal AdamW.
+    :param adamw_eps: float. The epsilon for the internal AdamW.
     """

     def __init__(
         self,
         params: PARAMETERS,
         lr: float = 2e-2,
         momentum: float = 0.95,
-        betas: BETAS = (0.95, 0.95),
+        weight_decay: float = 1e-2,
+        weight_decouple: bool = True,
+        betas: BETAS = (0.9, 0.95),
         nesterov: bool = True,
-        ns_steps: int = 6,
+        ns_steps: int = 5,
+        use_adjusted_lr: bool = False,
         adamw_params: Optional[PARAMETERS] = None,
         adamw_lr: float = 3e-4,
-        adamw_wd: float = 0,
+        adamw_wd: float = 0.0,
         adamw_eps: float = 1e-8,
         **kwargs,
     ):
         self.validate_learning_rate(lr)
         self.validate_learning_rate(adamw_lr)
+        self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
         self.validate_positive(ns_steps, 'ns_steps')
         self.validate_betas(betas)
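The docstring above recommends giving Muon only the internal ≥2D weights and handling embeddings, heads, and 0/1-D parameters with a separately created AdamW rather than the `adamw_params` path. A rough sketch of such a split on a toy model; the ndim-based filter and the explicit embedding exclusion are illustrative heuristics, not library API:

from torch import nn
from torch.optim import AdamW

from pytorch_optimizer import Muon

model = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 32), nn.Linear(32, 100))

# Embedding weights are 2D but, per the docstring, still belong with AdamW.
embedding_ids = {id(p) for p in model[0].parameters()}

muon_params = [p for p in model.parameters() if p.ndim >= 2 and id(p) not in embedding_ids]
adamw_params = [p for p in model.parameters() if p.ndim < 2 or id(p) in embedding_ids]

optimizers = [Muon(muon_params), AdamW(adamw_params, lr=3e-4, weight_decay=1e-2)]
# after loss.backward(), call step() and zero_grad() on both optimizers.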
@@ -66,8 +79,11 @@ def __init__(
         defaults: DEFAULTS = {
             'lr': lr,
             'momentum': momentum,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
             'nesterov': nesterov,
             'ns_steps': ns_steps,
+            'use_adjusted_lr': use_adjusted_lr,
             'adamw_lr': adamw_lr,
             'adamw_lr_ratio': adamw_lr / lr,
             'adamw_betas': betas,
@@ -114,6 +130,11 @@ def reset(self):
                 state['moment1'] = torch.zeros_like(p)
                 state['moment2'] = torch.zeros_like(p)

+    @staticmethod
+    def adjust_lr_for_muon(lr: float, param_shape) -> float:
+        adjusted_ratio: float = 0.2 * math.sqrt(max(param_shape[0], param_shape[1]))
+        return lr * adjusted_ratio
+
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
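A quick numerical check of the `adjust_lr_for_muon` scaling added above, using the new default `lr=2e-2` and two made-up weight shapes; it simply evaluates `lr * 0.2 * sqrt(max(fan_out, fan_in))`:

import math

lr = 2e-2  # Muon's new default learning rate

for shape in [(1024, 1024), (1024, 4096)]:  # hypothetical hidden sizes
    adjusted = lr * 0.2 * math.sqrt(max(shape))
    print(shape, round(adjusted, 3))  # (1024, 1024) -> 0.128, (1024, 4096) -> 0.256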
@@ -137,7 +158,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             if len(params) == 0:
                 continue

-            lr = group['lr']
             momentum = group['momentum']

             total_params: int = sum(p.numel() for p in params)
@@ -149,34 +169,42 @@
                     curr_idx += p.numel()
                     continue

-                g = p.grad
-                if g.ndim > 2:
-                    g = g.view(g.size(0), -1)
+                grad = p.grad
+                if grad.ndim > 2:
+                    grad = grad.view(grad.size(0), -1)

                 state = self.state[p]
                 if 'momentum_buffer' not in state:
-                    state['momentum_buffer'] = torch.zeros_like(g)
+                    state['momentum_buffer'] = torch.zeros_like(grad)

                 buf = state['momentum_buffer']
-                buf.mul_(momentum).add_(g)
+                buf.lerp_(grad, weight=1.0 - momentum)

-                if group['nesterov']:
-                    g.add_(buf, alpha=momentum)
-                else:
-                    g = buf
+                grad = grad.lerp_(buf, momentum) if group['nesterov'] else buf

-                g = zero_power_via_newton_schulz_5(g, num_steps=group['ns_steps'])
-                g.mul_(max(1.0, g.size(0) / g.size(1)) ** 0.5)
+                grad = zero_power_via_newton_schulz_5(grad, num_steps=group['ns_steps']).flatten()

-                updates_flat[curr_idx:curr_idx + p.numel()] = g.flatten()  # fmt: skip
+                updates_flat[curr_idx:curr_idx + p.numel()] = grad  # fmt: skip

             if self.world_size > 1:  # pragma: no cover
                 all_reduce(updates_flat, op=ReduceOp.SUM)

             curr_idx: int = 0
             for p in params:
-                g = updates_flat[curr_idx:curr_idx + p.numel()].view_as(p).type_as(p)  # fmt: skip
-                p.add_(g, alpha=-lr)
+                g = updates_flat[curr_idx:curr_idx + p.numel()].view_as(p)  # fmt: skip
+
+                self.apply_weight_decay(
+                    p,
+                    grad=g,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=group['weight_decouple'],
+                    fixed_decay=False,
+                )
+
+                lr: float = self.adjust_lr_for_muon(group['lr'], p.size()) if group['use_adjusted_lr'] else group['lr']
+
+                p.add_(g, alpha=-lr * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5))
                 curr_idx += p.numel()

             params = [p for p in group['params'] if p.grad is not None and not self.state[p]['use_muon']]
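Taken together, the rewritten loop above amounts to the following per-parameter logic on the Muon branch. This is a single-tensor sketch for ≥2D parameters only, ignoring the flattened `updates_flat` buffer and the distributed all-reduce; `muon_update_sketch` and its defaults are illustrative, not library API:

import torch

from pytorch_optimizer.optimizer.shampoo_utils import zero_power_via_newton_schulz_5


def muon_update_sketch(p, buf, lr=2e-2, momentum=0.95, nesterov=True, weight_decay=1e-2,
                       ns_steps=5, use_adjusted_lr=False):
    grad = p.grad
    if grad.ndim > 2:  # flatten e.g. conv kernels into a 2D matrix
        grad = grad.view(grad.size(0), -1)

    buf.lerp_(grad, 1.0 - momentum)  # EMA momentum: buf = momentum * buf + (1 - momentum) * grad
    update = grad.lerp(buf, momentum) if nesterov else buf

    update = zero_power_via_newton_schulz_5(update, num_steps=ns_steps).reshape(p.shape)

    p.mul_(1.0 - lr * weight_decay)  # decoupled weight decay, applied with the base lr

    if use_adjusted_lr:  # Moonlight-style shape-dependent scaling
        lr = lr * 0.2 * (max(p.size(0), p.size(1)) ** 0.5)

    p.add_(update, alpha=-lr * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5))


p = torch.nn.Parameter(torch.randn(16, 32))
p.grad = torch.randn(16, 32)
buf = torch.zeros_like(p)

with torch.no_grad():
    muon_update_sketch(p, buf, use_adjusted_lr=True)

Relative to the previous version, the `max(1, rows / cols) ** 0.5` factor now scales the step at update time rather than being folded into the orthogonalized gradient, and the learning rate is resolved per parameter so the Moonlight adjustment can depend on each weight's shape.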

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 1 addition & 1 deletion
@@ -528,7 +528,7 @@ def merge_small_dims(shape_to_merge: Union[List[int], torch.Size], max_dim: int)


 def zero_power_via_newton_schulz_5(
-    g: torch.Tensor, num_steps: int = 10, eps: float = 1e-7, weights: Tuple[int, int, int] = (3.4445, -4.7750, 2.0315)
+    g: torch.Tensor, num_steps: int = 5, eps: float = 1e-7, weights: Tuple[int, int, int] = (3.4445, -4.7750, 2.0315)
 ) -> torch.Tensor:
     r"""Compute the zeroth power / orthogonalization of G.
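For context on the `num_steps` default dropping from 10 to 5: a sketch of the quintic Newton-Schulz iteration this helper implements, following the widely used Muon reference implementation and the coefficient defaults visible in the signature above (the library's actual function may differ in details such as bfloat16 casting):

import torch


def newton_schulz_orthogonalize(g: torch.Tensor, num_steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic coefficients, same values as the default weights above

    x = g / (g.norm() + eps)  # normalize so the spectral norm is at most 1
    transposed = x.size(0) > x.size(1)
    if transposed:  # iterate on the wide orientation, then transpose back
        x = x.T

    for _ in range(num_steps):  # x converges toward U @ V^T from the SVD of g
        s = x @ x.T
        x = a * x + (b * s + c * s @ s) @ x

    return x.T if transposed else x

The new default of 5 lines up with the "(5 is probably always enough)" note in the Muon docstring above.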

tests/constants.py

Lines changed: 2 additions & 2 deletions
@@ -518,8 +518,8 @@
     ),
     (ADOPT, {'lr': 1e0}, 5),
     (FTRL, {'lr': 1e0, 'beta': 0.0, 'lambda_1': 0.0, 'lambda_2': 0.0}, 5),
-    (Muon, {'lr': 1e0, 'ns_steps': 6, 'adam_lr': 1e0, 'adamw_wd': 1e-2}, 5),
-    (Muon, {'lr': 1e0, 'ns_steps': 6, 'adam_lr': 1e0, 'adamw_wd': 1e-2, 'nesterov': False}, 5),
+    (Muon, {'lr': 5e0, 'use_adjusted_lr': True, 'adam_lr': 1e0, 'adamw_wd': 1e-2}, 5),
+    (Muon, {'lr': 1e0, 'adam_lr': 1e0, 'adamw_wd': 1e-2, 'nesterov': False}, 5),
     (LaProp, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (LaProp, {'lr': 1e0, 'centered': True, 'weight_decay': 1e-3}, 11),
     (LaProp, {'lr': 1e0, 'ams_bound': True, 'weight_decay': 1e-3}, 5),

tests/test_optimizers.py

Lines changed: 2 additions & 2 deletions
@@ -924,15 +924,15 @@ def test_muon_rank(rank):
     model = nn.Sequential(
         nn.Conv1d(1, 1, 1),
         nn.Conv1d(1, 1, 1),
-        nn.Conv1d(1, 1, 1),
+        nn.Conv2d(1, 1, (2, 2)),
     )

     optimizer = Muon(model.parameters())
     optimizer.zero_grad()

     model[0].weight.grad = torch.randn(1, 1, 1)
     model[1].weight.grad = torch.randn(1, 1, 1)
-    model[2].weight.grad = torch.randn(1, 1, 1)
+    model[2].weight.grad = torch.randn(1, 1, 2, 2)

     optimizer.step()
