
Commit a980dc0

Merge pull request #302 from kozistr/feature/muon-optimizer
[Feature] Implement Muon optimizer
2 parents c341872 + ecaf786 commit a980dc0

15 files changed: +331 −13 lines changed


README.md

Lines changed: 2 additions & 1 deletion

@@ -10,7 +10,7 @@

 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **80 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!

 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -186,6 +186,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |
 | DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
 | MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
+| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |

 ## Supported LR Scheduler
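Since the surrounding README section already demonstrates `get_supported_optimizers(['adam*', 'ranger*'])`, the new entry can presumably be looked up the same way once this PR is installed. A minimal sketch (not part of the diff), assuming the same glob-style filtering applies:

from pytorch_optimizer import get_supported_optimizers

# With Muon registered, the count rises to 81 and a 'muon' entry should match this filter.
print(get_supported_optimizers(['muon*']))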

docs/changelogs/v3.3.1.md

Lines changed: 2 additions & 0 deletions

@@ -4,3 +4,5 @@

 * Implement `DeMo` optimizer. (#300, #301)
     * [Decoupled Momentum Optimization](https://arxiv.org/abs/2411.19870)
+* Implement `Muon` optimizer. (#302)
+    * [MomentUm Orthogonalized by Newton-schulz](https://github.com/KellerJordan/Muon)

docs/index.md

Lines changed: 2 additions & 1 deletion

@@ -10,7 +10,7 @@

 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **80 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
+Currently, **81 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!

 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).

@@ -186,6 +186,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
 | Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |
 | DeMo | *Decoupled Momentum Optimization* | [github](https://github.com/bloc97/DeMo) | <https://arxiv.org/abs/2411.19870> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv241119870P/exportcitation) |
 | MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
+| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |

 ## Supported LR Scheduler

docs/optimizer.md

Lines changed: 4 additions & 0 deletions

@@ -228,6 +228,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.Muon
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.Nero
     :docstring:
     :members:

pyproject.toml

Lines changed: 6 additions & 5 deletions

@@ -15,11 +15,12 @@ keywords = [
     "AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
     "Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion",
     "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS",
-    "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy",
-    "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP",
-    "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger",
-    "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard",
-    "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
+    "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muon", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM",
+    "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
+    "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP", "SopihaH", "SRMM",
+    "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1",
+    "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge", "bitsandbytes", "WSD",
+    "QGaLore",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -109,6 +109,7 @@
     Lamb,
     Lion,
     Lookahead,
+    Muon,
     Nero,
     NovoGrad,
     PAdam,

pytorch_optimizer/optimizer/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -56,6 +56,7 @@
 from pytorch_optimizer.optimizer.lookahead import Lookahead
 from pytorch_optimizer.optimizer.madgrad import MADGRAD
 from pytorch_optimizer.optimizer.msvag import MSVAG
+from pytorch_optimizer.optimizer.muon import Muon
 from pytorch_optimizer.optimizer.nero import Nero
 from pytorch_optimizer.optimizer.novograd import NovoGrad
 from pytorch_optimizer.optimizer.padam import PAdam
@@ -272,6 +273,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     SOAP,
     ADOPT,
     FTRL,
+    DeMo,
+    Muon,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
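Because `Muon` is now appended to `OPTIMIZER_LIST`, the string-keyed registry built right below it should expose it to `load_optimizer` as well. A small sketch (not part of the diff), assuming the loader resolves lowercased class names as it does for the existing entries:

import torch

from pytorch_optimizer import load_optimizer

model = torch.nn.Linear(16, 16)

# 'muon' is the lowercased class name used as the registry key.
optimizer_cls = load_optimizer('muon')
optimizer = optimizer_cls(model.parameters(), lr=2e-2)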

pytorch_optimizer/optimizer/demo.py

Lines changed: 4 additions & 1 deletion

@@ -302,11 +302,14 @@ def __init__(
         process_group: Optional[ProcessGroup] = None,
         **kwargs,
     ):
+        self.validate_learning_rate(lr)
         self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_range(compression_decay, 'compression_decay', 0.0, 1.0, range_type='[)')
         self.validate_positive(compression_top_k, 'compression_top_k')
         self.validate_positive(compression_chunk, 'compression_chunk')

+        self.weight_decay = weight_decay
+
         self.compression_decay = compression_decay
         self.compression_top_k = compression_top_k
         self.compression_chunk = compression_chunk
@@ -406,7 +409,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     p,
                     grad,
                     lr=lr,
-                    weight_decay=group['weight_decay'],
+                    weight_decay=self.weight_decay,
                     weight_decouple=True,
                     fixed_decay=False,
                 )
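For context, this change presumably avoids a lookup failure: `weight_decay` is validated in `__init__` but, judging by the diff, was not placed into the param-group defaults, so `group['weight_decay']` had nothing to read at step time. Keeping it on the instance makes the decay value available unconditionally. A toy illustration of the difference, using a stand-in dict rather than DeMo's real internals:

# Hypothetical param group built only from DeMo's registered defaults; 'weight_decay' is absent.
group = {'lr': 1e-3, 'compression_decay': 0.999}

weight_decay = 1e-2                      # value passed to __init__, now kept as self.weight_decay
old_value = group.get('weight_decay')    # None under the old group-based lookup
new_value = weight_decay                 # always defined under the new attribute-based lookup
print(old_value, new_value)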
pytorch_optimizer/optimizer/muon.py (new file)

Lines changed: 253 additions & 0 deletions

import os
from typing import List, Optional, Tuple

import torch
from torch.distributed import ReduceOp, all_reduce

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


def zero_power_via_newton_schulz_5(
    g: torch.Tensor, num_steps: int = 10, eps: float = 1e-7, weights: Tuple[float, float, float] = (3.4445, -4.7750, 2.0315)
) -> torch.Tensor:
    r"""Compute the zeroth power / orthogonalization of G.

    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration
    whose coefficients are selected to maximize the slope at zero. For the purpose of minimizing steps, it turns out
    to be empirically effective to keep increasing the slope at zero even beyond the point where the iteration no
    longer converges all the way to one everywhere on the interval. This iteration therefore does not produce UV^T
    but rather something like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to
    hurt model performance at all relative to UV^T, where USV^T = G is the SVD.

    :param g: torch.Tensor. matrix.
    :param num_steps: int. number of iterations.
    :param eps: float. add this times I to G, to make it positive definite. For scaling, we multiply it by the
        largest eigenvalue of G.
    :param weights: Tuple[float, float, float]. weights of the quintic iteration.
    """
    if len(g.shape) != 2:
        raise ValueError('shape of g must be 2-dimensional')

    x = g.bfloat16()
    x.div_(x.norm().add_(eps))

    if g.size(0) > g.size(1):
        x = x.T

    for _ in range(num_steps):
        a = x @ x.T
        b = weights[1] * a + weights[2] * a @ a
        x = weights[0] * x + b @ x

    if g.size(0) > g.size(1):
        x = x.T

    return x
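The docstring above is explicit that the quintic iteration only pushes the singular values into roughly the [0.5, 1.5] band rather than all the way to 1. A stand-alone sanity check of that behaviour (not part of this PR; the bounds are loose expectations given the bfloat16 arithmetic) could look like this:

import torch

# Hypothetical check that reuses the function above: the singular values of the
# orthogonalized matrix should end up near 1, within the loose band the docstring describes.
g = torch.randn(128, 64)
x = zero_power_via_newton_schulz_5(g, num_steps=6).float()

singular_values = torch.linalg.svdvals(x)
print(singular_values.min().item(), singular_values.max().item())  # expected to lie roughly in [0.5, 1.5]

The `Muon` class itself, from the same new file: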
class Muon(BaseOptimizer):
    r"""MomentUm Orthogonalized by Newton-schulz.

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which
    each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each
    update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the
    GPU.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

    :param params: PARAMETERS. the parameters to be optimized by Muon.
    :param lr: float. learning rate.
    :param momentum: float. the momentum used by the internal SGD.
    :param betas: BETAS. the betas for the internal AdamW.
    :param nesterov: bool. whether to use nesterov momentum.
    :param ns_steps: int. the number of Newton-Schulz iterations to run. (6 is probably always enough)
    :param adamw_params: Optional[PARAMETERS]. the parameters to be optimized by AdamW. Any parameters in
        `muon_params` which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as
        well.
    :param adamw_lr: float. the learning rate for the internal AdamW.
    :param adamw_wd: float. the weight decay for the internal AdamW.
    :param adamw_eps: float. the epsilon for the internal AdamW.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 2e-2,
        momentum: float = 0.95,
        betas: BETAS = (0.95, 0.95),
        nesterov: bool = True,
        ns_steps: int = 6,
        adamw_params: Optional[PARAMETERS] = None,
        adamw_lr: float = 3e-4,
        adamw_wd: float = 0.0,
        adamw_eps: float = 1e-8,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(adamw_lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_positive(ns_steps, 'ns_steps')
        self.validate_betas(betas)
        self.validate_non_negative(adamw_wd, 'adamw_wd')
        self.validate_non_negative(adamw_eps, 'adamw_eps')

        params = self.get_parameters(params)
        adamw_params = self.get_parameters(adamw_params) if adamw_params is not None else []
        params.extend(adamw_params)

        self.world_size: int = int(os.environ.get('WORLD_SIZE', 1))
        self.rank: int = int(os.environ.get('RANK', 0))

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'nesterov': nesterov,
            'ns_steps': ns_steps,
            'adamw_lr': adamw_lr,
            'adamw_lr_ratio': adamw_lr / lr,
            'adamw_betas': betas,
            'adamw_wd': adamw_wd,
            'adamw_eps': adamw_eps,
        }
        super().__init__(params, defaults)

        self.set_muon_state(params, adamw_params)

    def __str__(self) -> str:
        return 'Muon'

    @staticmethod
    def get_parameters(params: PARAMETERS) -> List[torch.Tensor]:
        if isinstance(params, list) and isinstance(params[0], torch.Tensor):
            return params

        new_params = []
        for group in params:
            if isinstance(group, dict) and 'params' in group:
                new_params.extend(list(group['params']))
            else:
                new_params.append(group)

        return new_params

    def set_muon_state(self, params: PARAMETERS, adamw_params: PARAMETERS, threshold: int = 8192) -> None:
        r"""Set use_muon flag."""
        for p in params:
            self.state[p]['use_muon'] = p.ndim >= 2 and p.size(0) < threshold

        for p in adamw_params:
            self.state[p]['use_muon'] = False

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['momentum_buffer'] = torch.zeros_like(p)
                state['moment1'] = torch.zeros_like(p)
                state['moment2'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            params = []
            for p in group['params']:
                if p.grad is not None and self.state[p]['use_muon']:
                    if p.grad.is_sparse:
                        raise NoSparseGradientError(str(self))
                    params.append(p)

            if len(params) == 0:
                continue

            lr = group['lr']
            momentum = group['momentum']

            total_params: int = sum(p.numel() for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
            curr_idx: int = 0

            for i, p in enumerate(params):
                if i % self.world_size != self.rank:
                    curr_idx += p.numel()
                    continue

                g = p.grad
                if g.ndim > 2:
                    g = g.view(g.size(0), -1)

                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(g)

                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(g)

                if group['nesterov']:
                    g.add_(buf, alpha=momentum)
                else:
                    g = buf

                g = zero_power_via_newton_schulz_5(g, num_steps=group['ns_steps'])
                g.mul_(max(1.0, g.size(0) / g.size(1)) ** 0.5)

                updates_flat[curr_idx:curr_idx + p.numel()] = g.flatten()  # fmt: skip

            if self.world_size > 1:  # pragma: no cover
                all_reduce(updates_flat, op=ReduceOp.SUM)

            curr_idx: int = 0
            for p in params:
                g = updates_flat[curr_idx:curr_idx + p.numel()].view_as(p).type_as(p)  # fmt: skip
                p.add_(g, alpha=-lr)
                curr_idx += p.numel()

            params = [p for p in group['params'] if p.grad is not None and not self.state[p]['use_muon']]

            lr: float = group['adamw_lr_ratio'] * group['lr']
            beta1, beta2 = group['adamw_betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])
            scale: float = bias_correction1 / bias_correction2 ** 0.5  # fmt: skip
            step_size: float = lr / scale

            for p in params:
                grad = p.grad
                state = self.state[p]
                if 'moment1' not in state:
                    state['moment1'] = torch.zeros_like(grad)
                    state['moment2'] = torch.zeros_like(grad)

                buf1, buf2 = state['moment1'], state['moment2']
                buf1.lerp_(grad, weight=1.0 - beta1)
                buf2.lerp_(grad.square(), weight=1.0 - beta2)

                update = buf1 / buf2.sqrt().add_(group['adamw_eps'])

                self.apply_weight_decay(
                    p,
                    grad,
                    lr=lr,
                    weight_decay=group['adamw_wd'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

                p.add_(update, alpha=-step_size)

        return loss
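Putting the two parameter routes together, typical usage would presumably look like the sketch below: 2-D weight matrices go through the Muon / Newton-Schulz path, while 1-D parameters (and, per the docstring, embeddings and the output head) are handed to the internal AdamW via `adamw_params`. This is an illustrative sketch, not code from the PR; the model split and hyperparameters are made up.

import torch
from torch import nn

from pytorch_optimizer import Muon

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

# 2-D hidden weights -> Muon; 1-D biases (and, in a real LM, embed/lm_head) -> internal AdamW.
muon_params = [p for p in model.parameters() if p.ndim >= 2]
adamw_params = [p for p in model.parameters() if p.ndim < 2]

optimizer = Muon(muon_params, lr=2e-2, adamw_params=adamw_params, adamw_lr=3e-4, adamw_wd=1e-2)

x, y = torch.randn(32, 64), torch.randint(0, 10, (32,))

loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()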

tests/constants.py

Lines changed: 4 additions & 0 deletions

@@ -58,6 +58,7 @@
     Kate,
     Lamb,
     Lion,
+    Muon,
     Nero,
     NovoGrad,
     PAdam,
@@ -144,6 +145,7 @@
     'adamg',
     'ademamix',
     'soap',
+    'muon',
 ]

 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -495,6 +497,8 @@
     ),
     (ADOPT, {'lr': 1e0}, 5),
     (FTRL, {'lr': 1e0, 'beta': 0.0, 'lambda_1': 0.0, 'lambda_2': 0.0}, 5),
+    (Muon, {'lr': 1e0, 'ns_steps': 6, 'adamw_lr': 1e0, 'adamw_wd': 1e-2}, 5),
+    (Muon, {'lr': 1e0, 'ns_steps': 6, 'adamw_lr': 1e0, 'adamw_wd': 1e-2, 'nesterov': False}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
