
Commit 75c7945

[Feature] Implement DistributedMuon optimizer (#418)
* build(deps): docs requirements
* feature: implement DistributedMuon optimizer
* docs: v3.8.0 changelog
* feature: implement StochasticAccumulator
* update: DistributedMuon optimizer
* docs: DistributedMuon optimizer
* docs: v3.8.0 changelog
* update: test recipe
* refactor: remove unused var
* docs: v3.8.0 changelog
* update: test_stochastic_accumulation_hook
* update: recipes
* refactor: exception message format
* update: mse to bce
1 parent 6071805 commit 75c7945

File tree

12 files changed: +302 -19 lines changed


docs/changelogs/v3.8.0.md

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@
 * You can use this variant by setting `decoupling_c` parameter in the `ScheduleFreeAdamW` optimizer.
 * Add more built-in optimizers, `NAdam`, `RMSProp`, and `LBFGS` optimizers. (#415)
 * Support `cautious` variant for `Muon` optimizer. (#417)
+* Separate distributed functionality from `Muon` to `DistributedMuon` optimizer. (#418)
+* Implement `StochasticAccumulator`, which is a gradient hook. (#418)
+    * [stochastic optimizer](https://github.com/lodestone-rock/torchastic/)

 ### Update

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
@@ -304,6 +304,10 @@
     :docstring:
     :members:

+::: pytorch_optimizer.DistributedMuon
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AdaMuon
     :docstring:
     :members:

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@
     DAdaptSGD,
     DeMo,
     DiffGrad,
+    DistributedMuon,
     DynamicLossScaler,
     EmoFact,
     EmoLynx,
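With the re-export above (and the matching one in `pytorch_optimizer/optimizer/__init__.py` below), the new optimizer becomes importable from the package root. A minimal sketch, assuming the package is installed from this commit:

```python
# Sketch only: checks the new top-level export added in this commit.
from pytorch_optimizer import DistributedMuon, Muon

print(DistributedMuon.__name__)  # 'DistributedMuon'
# Note: instantiating DistributedMuon requires an initialized torch.distributed
# process group; see the usage sketch after the muon.py diff below.
```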

pytorch_optimizer/base/optimizer.py

Lines changed: 12 additions & 12 deletions
@@ -64,7 +64,7 @@ def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tens
             for p in group['params']:
                 if p.size() != hessian[i].size():
                     raise ValueError(
-                        f'[-] the shape of parameter and hessian does not match. {p.size()} vs {hessian[i].size()}'
+                        f'the shape of parameter and hessian does not match. {p.size()} vs {hessian[i].size()}'
                     )

                 state[p]['hessian'] = hessian[i]
@@ -312,35 +312,35 @@ def get_stable_adamw_rms(grad: torch.Tensor, exp_avg_sq: torch.Tensor, eps: floa
     @staticmethod
     def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
         if range_type == '[)' and not low <= x < high:
-            raise ValueError(f'[-] {name} must be in the range [{low}, {high})')
+            raise ValueError(f'{name} must be in the range [{low}, {high})')
         if range_type == '[]' and not low <= x <= high:
-            raise ValueError(f'[-] {name} must be in the range [{low}, {high}]')
+            raise ValueError(f'{name} must be in the range [{low}, {high}]')
         if range_type == '(]' and not low < x <= high:
-            raise ValueError(f'[-] {name} must be in the range ({low}, {high}]')
+            raise ValueError(f'{name} must be in the range ({low}, {high}]')
         if range_type == '()' and not low < x < high:
-            raise ValueError(f'[-] {name} must be in the range ({low}, {high})')
+            raise ValueError(f'{name} must be in the range ({low}, {high})')

     @staticmethod
     def validate_non_negative(x: Optional[float], name: str) -> None:
         if x is not None and x < 0.0:
-            raise ValueError(f'[-] {name} must be non-negative')
+            raise ValueError(f'{name} must be non-negative')

     @staticmethod
     def validate_non_positive(x: Optional[float], name: str) -> None:
         if x is not None and x > 0.0:
-            raise ValueError(f'[-] {name} must be non-positive')
+            raise ValueError(f'{name} must be non-positive')

     @staticmethod
     def validate_positive(x: Union[float, int], name: str) -> None:
         if x <= 0:
-            raise ValueError(f'[-] {name} must be positive')
+            raise ValueError(f'{name} must be positive')

     @staticmethod
     def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper') -> None:
         if bound_type == 'upper' and constant > boundary:
-            raise ValueError(f'[-] constant {constant} must be in a range of (-inf, {boundary}]')
+            raise ValueError(f'constant {constant} must be in a range of (-inf, {boundary}]')
         if bound_type == 'lower' and constant < boundary:
-            raise ValueError(f'[-] constant {constant} must be in a range of [{boundary}, inf)')
+            raise ValueError(f'constant {constant} must be in a range of [{boundary}, inf)')

     @staticmethod
     def validate_step(step: int, step_type: str) -> None:
@@ -351,7 +351,7 @@ def validate_step(step: int, step_type: str) -> None:
     def validate_options(x: str, name: str, options: List[str]) -> None:
         if x not in options:
             opts: str = ' or '.join([f"'{option}'" for option in options]).strip()
-            raise ValueError(f'[-] {name} {x} must be one of ({opts})')
+            raise ValueError(f'{name} {x} must be one of ({opts})')

     @staticmethod
     def validate_learning_rate(learning_rate: Optional[float]) -> None:
@@ -361,7 +361,7 @@ def validate_learning_rate(learning_rate: Optional[float]) -> None:
     @staticmethod
     def validate_mod(x: int, y: int) -> None:
         if x % y != 0:
-            raise ValueError(f'[-] {x} must be divisible by {y}')
+            raise ValueError(f'{x} must be divisible by {y}')

     def validate_betas(self, betas: BETAS, beta_range_type: str = '[)', beta3_range_type: str = '[]') -> None:
         if betas[0] is not None:
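This hunk only drops the `[-] ` prefix from the validation messages; behaviour is unchanged. A quick hedged sketch of how the cleaned-up text surfaces, assuming `BaseOptimizer` is imported directly from the module patched above:

```python
from pytorch_optimizer.base.optimizer import BaseOptimizer

# momentum = 1.5 falls outside the half-open range [0.0, 1.0), so this raises.
try:
    BaseOptimizer.validate_range(1.5, 'momentum', 0.0, 1.0, range_type='[)')
except ValueError as exc:
    # After this commit the message no longer carries the '[-] ' prefix.
    print(exc)  # momentum must be in the range [0.0, 1.0)
```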

pytorch_optimizer/optimizer/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -67,7 +67,7 @@
 from pytorch_optimizer.optimizer.madgrad import MADGRAD
 from pytorch_optimizer.optimizer.mars import MARS
 from pytorch_optimizer.optimizer.msvag import MSVAG
-from pytorch_optimizer.optimizer.muon import AdaMuon, Muon, prepare_muon_parameters
+from pytorch_optimizer.optimizer.muon import AdaMuon, DistributedMuon, Muon, prepare_muon_parameters
 from pytorch_optimizer.optimizer.nero import Nero
 from pytorch_optimizer.optimizer.novograd import NovoGrad
 from pytorch_optimizer.optimizer.orthograd import OrthoGrad
@@ -164,6 +164,7 @@
     DAdaptSGD,
     DeMo,
     DiffGrad,
+    DistributedMuon,
     EXAdam,
     EmoFact,
     EmoLynx,

pytorch_optimizer/optimizer/muon.py

Lines changed: 204 additions & 0 deletions
@@ -3,6 +3,7 @@

 import torch
 from torch import nn
+from torch.distributed import all_gather, get_rank, get_world_size
 from torch.optim import Optimizer

 from pytorch_optimizer.base.exception import NoComplexParameterError, NoSparseGradientError
@@ -216,6 +217,209 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         return loss


+class DistributedMuon(BaseOptimizer):  # pragma: no cover
+    r"""Distributed Momentum Orthogonalized by Newton-Schulz.
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which
+    each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each
+    update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and
+    scalar or vector parameters should be optimized using AdamW.
+
+    Some warnings:
+    - We believe this optimizer is unlikely to work well for training with a small batch size.
+    - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.
+
+    Example:
+    -------
+        from pytorch_optimizer import DistributedMuon
+
+        hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
+        hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
+        non_hidden_params = [*model.head.parameters(), *model.embed.parameters()]
+
+        param_groups = [
+            dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True),
+            dict(
+                params=hidden_gains_biases + non_hidden_params,
+                lr=3e-4,
+                betas=(0.9, 0.95),
+                weight_decay=0.01,
+                use_muon=False,
+            ),
+        ]
+
+        optimizer = DistributedMuon(param_groups)
+
+    :param params: PARAMETERS. the parameters to be optimized by Muon.
+    :param lr: float. learning rate.
+    :param momentum: float. the momentum used by the internal SGD.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+    :param nesterov: bool. whether to use Nesterov momentum.
+    :param ns_steps: int. the number of Newton-Schulz iterations to run. (5 is probably always enough)
+    :param use_adjusted_lr: bool. whether to use the adjusted learning rate, which is from Moonlight.
+        reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+    :param adamw_lr: float. the learning rate for the internal AdamW.
+    :param adamw_betas: BETAS. the betas for the internal AdamW.
+    :param adamw_wd: float. the weight decay for the internal AdamW.
+    :param adamw_eps: float. the epsilon for the internal AdamW.
+    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 2e-2,
+        momentum: float = 0.95,
+        weight_decay: float = 0.0,
+        weight_decouple: bool = True,
+        nesterov: bool = True,
+        ns_steps: int = 5,
+        use_adjusted_lr: bool = False,
+        adamw_lr: float = 3e-4,
+        adamw_betas: BETAS = (0.9, 0.95),
+        adamw_wd: float = 0.0,
+        adamw_eps: float = 1e-10,
+        maximize: bool = False,
+        **kwargs,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_learning_rate(adamw_lr)
+        self.validate_non_negative(weight_decay, 'weight_decay')
+        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
+        self.validate_positive(ns_steps, 'ns_steps')
+        self.validate_betas(adamw_betas)
+        self.validate_non_negative(adamw_wd, 'adamw_wd')
+        self.validate_non_negative(adamw_eps, 'adamw_eps')
+
+        self.maximize = maximize
+
+        self.world_size: int = get_world_size()
+        self.rank: int = get_rank()
+
+        for group in params:
+            if 'use_muon' not in group:
+                raise ValueError('`use_muon` must be set.')
+
+            if group['use_muon']:
+                group['lr'] = group.get('lr', lr)
+                group['momentum'] = group.get('momentum', momentum)
+                group['nesterov'] = group.get('nesterov', nesterov)
+                group['weight_decay'] = group.get('weight_decay', weight_decay)
+                group['ns_steps'] = group.get('ns_steps', ns_steps)
+                group['use_adjusted_lr'] = group.get('use_adjusted_lr', use_adjusted_lr)
+            else:
+                group['lr'] = group.get('lr', adamw_lr)
+                group['betas'] = group.get('betas', adamw_betas)
+                group['eps'] = group.get('eps', adamw_eps)
+                group['weight_decay'] = group.get('weight_decay', adamw_wd)
+
+            group['weight_decouple'] = group.get('weight_decouple', weight_decouple)
+
+        super().__init__(params, kwargs)
+
+    def __str__(self) -> str:
+        return 'DistributedMuon'
+
+    def init_group(self, group: GROUP, **kwargs) -> None:
+        for p in group['params']:
+            if p.grad is None:
+                p.grad = torch.zeros_like(p)
+
+            grad = p.grad
+            if grad.is_sparse:
+                raise NoSparseGradientError(str(self))
+
+            if torch.is_complex(p):
+                raise NoComplexParameterError(str(self))
+
+            state = self.state[p]
+
+            if len(state) == 0 and not group['use_muon']:
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if 'step' not in group:
+                self.init_group(group)
+                group['step'] = 1
+            else:
+                group['step'] += 1
+
+            if group['use_muon']:
+                params = group['params']
+                padded_params = params + [torch.empty_like(params[-1])] * (
+                    self.world_size - len(params) % self.world_size
+                )
+
+                for i in range(len(params))[:: self.world_size]:
+                    if i + self.rank < len(params):
+                        p = params[i + self.rank]
+
+                        grad = p.grad
+
+                        self.maximize_gradient(grad, maximize=self.maximize)
+
+                        state = self.state[p]
+                        if len(state) == 0:
+                            state['momentum_buffer'] = torch.zeros_like(p)
+
+                        self.apply_weight_decay(
+                            p,
+                            grad=grad,
+                            lr=group['lr'],
+                            weight_decay=group['weight_decay'],
+                            weight_decouple=group['weight_decouple'],
+                            fixed_decay=False,
+                        )
+
+                        buf = state['momentum_buffer']
+                        buf.lerp_(grad, weight=1.0 - group['momentum'])
+
+                        update = grad.lerp_(buf, weight=group['momentum']) if group['nesterov'] else buf
+                        if update.ndim > 2:
+                            update = update.view(len(update), -1)
+
+                        update = zero_power_via_newton_schulz_5(update, num_steps=group['ns_steps'])
+
+                        if group.get('cautious'):
+                            self.apply_cautious(update, grad)
+
+                        lr: float = get_adjusted_lr(group['lr'], p.size(), use_adjusted_lr=group['use_adjusted_lr'])
+
+                        p.add_(update.reshape(p.shape), alpha=-lr)
+
+                    all_gather(padded_params[i:i + self.world_size], padded_params[i + self.rank])  # fmt: skip
+            else:
+                for p in group['params']:
+                    grad = p.grad
+
+                    state = self.state[p]
+
+                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                    beta1, beta2 = group['betas']
+
+                    bias_correction1: float = self.debias(beta1, group['step'])
+                    bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
+
+                    exp_avg.lerp_(grad, weight=1.0 - beta1)
+                    exp_avg_sq.lerp_(grad.square(), weight=1.0 - beta2)
+
+                    de_nom = exp_avg_sq.sqrt().add_(group['eps']).div_(bias_correction2_sq)
+
+                    p.addcdiv_(exp_avg / bias_correction1, de_nom, value=-group['lr'])
+
+        return loss
+
+
 class AdaMuon(BaseOptimizer):
     r"""Adaptive Muon optimizer.
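The docstring above covers parameter grouping; the sketch below adds the distributed plumbing. It is a minimal single-process illustration, assuming a `torch.distributed` process group is already initialized (the constructor calls `get_rank()`/`get_world_size()`); the toy model, the TCP address, and the BCE loss are placeholders, and in a real multi-rank run (e.g. under `torchrun`) every rank must see identical gradients before `step()` so the `all_gather` of updated parameters stays consistent.

```python
import torch
import torch.distributed as dist
from torch import nn

from pytorch_optimizer import DistributedMuon

# Placeholder model: 2D hidden weights go to Muon, 1D biases to the internal AdamW path.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1))

# DistributedMuon queries the default process group at construction time.
dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1)

muon_params = [p for p in model.parameters() if p.ndim >= 2]
adamw_params = [p for p in model.parameters() if p.ndim < 2]

optimizer = DistributedMuon(
    [
        dict(params=muon_params, lr=2e-2, weight_decay=1e-2, use_muon=True),
        dict(params=adamw_params, lr=3e-4, betas=(0.9, 0.95), use_muon=False),
    ]
)

x = torch.randn(8, 64)
target = torch.randint(0, 2, (8, 1)).float()
loss = nn.functional.binary_cross_entropy_with_logits(model(x), target)
loss.backward()

optimizer.step()  # Muon params: momentum + Newton-Schulz; the rest: internal AdamW.
optimizer.zero_grad()

dist.destroy_process_group()
```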

pytorch_optimizer/optimizer/utils.py

Lines changed: 51 additions & 0 deletions
@@ -157,6 +157,57 @@ def load_state_dict(self, state_dict):
         optim.load_state_dict(optim_state_dict)


+class StochasticAccumulator:
+    r"""Stochastic accumulator.
+
+    Example:
+    -------
+        model = YourModel()
+
+        # apply stochastic grad accumulator hooks
+        StochasticAccumulator.assign_hooks(model)
+
+        while True:
+            for _ in range(grad_accum_length):
+                loss = model.loss(*your_model_input)
+                loss.backward()
+
+            StochasticAccumulator.reassign_grad_buffer(model)
+
+            optimizer.step()
+            optimizer.zero_grad()
+    """
+
+    @staticmethod
+    def stochastic_grad_accum(p: torch.Tensor) -> None:
+        if hasattr(p, 'acc_grad'):
+            acc_grad_fp32 = p.acc_grad.clone().to(torch.float32)
+            acc_grad_fp32.add_(p.grad.to(torch.float32))
+
+            copy_stochastic(p.acc_grad, acc_grad_fp32)
+
+            del acc_grad_fp32
+        else:
+            p.acc_grad = p.grad.clone().to(torch.bfloat16)
+
+        del p.grad
+
+    @staticmethod
+    def reassign_grad_buffer(model: nn.Module) -> None:
+        for _, p in model.named_parameters():
+            if p.requires_grad and hasattr(p, 'acc_grad'):
+                p.grad = p.acc_grad
+                del p.acc_grad
+
+    @staticmethod
+    def assign_hooks(model: nn.Module) -> List:
+        return [
+            p.register_post_accumulate_grad_hook(StochasticAccumulator.stochastic_grad_accum)
+            for _, p in model.named_parameters()
+            if p.requires_grad
+        ]
+
+
 def is_valid_parameters(parameters: PARAMETERS) -> bool:
     r"""Check where the parameters are valid."""
     return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict)
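For completeness, here is a hedged end-to-end sketch of the new hook. It assumes a bfloat16 model (so the bf16 `acc_grad` buffer can be reassigned to `.grad`, which must match the parameter dtype), PyTorch ≥ 2.1 for `register_post_accumulate_grad_hook`, and imports `StochasticAccumulator` straight from the module patched above; the toy model, data, and step count are placeholders.

```python
import torch
from torch import nn

from pytorch_optimizer.optimizer.utils import StochasticAccumulator

# Placeholder bf16 model; stochastic rounding matters most for low-precision weights.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1)).bfloat16()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# After each backward, the hook moves p.grad into a bf16 `acc_grad` buffer,
# accumulating in fp32 and writing back with stochastic rounding (copy_stochastic).
hooks = StochasticAccumulator.assign_hooks(model)

grad_accum_length = 4
for _ in range(grad_accum_length):
    x = torch.randn(8, 16, dtype=torch.bfloat16)
    target = torch.randint(0, 2, (8, 1)).float()
    logits = model(x).float()
    loss = nn.functional.binary_cross_entropy_with_logits(logits, target)
    loss.backward()

# Hand the accumulated buffers back to `.grad` so the optimizer can consume them.
StochasticAccumulator.reassign_grad_buffer(model)
optimizer.step()
optimizer.zero_grad()

for hook in hooks:
    hook.remove()
```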
