+import math
import os
from typing import List, Optional



class Muon(BaseOptimizer):
-    r"""MomentUm Orthogonalized by Newton-schulz.
+    r"""Momentum Orthogonalized by Newton-schulz.

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which
    each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each
    update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

+    Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and
+    scalar or vector parameters should be optimized using AdamW.
+
    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

    :param params: PARAMETERS. the parameters to be optimized by Muon.
    :param lr: float. learning rate.
    :param momentum: float. the momentum used by the internal SGD.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param betas: The betas for the internal AdamW.
    :param nesterov: bool. whether to use nesterov momentum.
-    :param ns_steps: int. the number of Newton-Schulz iterations to run. (6 is probably always enough)
-    :param adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or
-        are detected as being the embed or lm_head will be optimized by AdamW as well.
-    :param adamw_lr: The learning rate for the internal AdamW.
-    :param adamw_wd: The weight decay for the internal AdamW.
-    :param adamw_eps: The epsilon for the internal AdamW.
+    :param ns_steps: int. the number of Newton-Schulz iterations to run. (5 is probably always enough)
+    :param use_adjusted_lr: bool. whether to use the adjusted learning rate from Moonlight.
+        reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+    :param adamw_params: Optional[PARAMETERS]. The parameters to be optimized by AdamW. Any parameters in `muon_params`
+        which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. It'd be
+        better to create a separate AdamW optimizer instead of using this.
+    :param adamw_lr: float. The learning rate for the internal AdamW.
+    :param adamw_wd: float. The weight decay for the internal AdamW.
+    :param adamw_eps: float. The epsilon for the internal AdamW.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 2e-2,
        momentum: float = 0.95,
-        betas: BETAS = (0.95, 0.95),
+        weight_decay: float = 1e-2,
+        weight_decouple: bool = True,
+        betas: BETAS = (0.9, 0.95),
        nesterov: bool = True,
-        ns_steps: int = 6,
+        ns_steps: int = 5,
+        use_adjusted_lr: bool = False,
        adamw_params: Optional[PARAMETERS] = None,
        adamw_lr: float = 3e-4,
-        adamw_wd: float = 0,
+        adamw_wd: float = 0.0,
        adamw_eps: float = 1e-8,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(adamw_lr)
+        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_positive(ns_steps, 'ns_steps')
        self.validate_betas(betas)
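As the docstring notes, Muon is meant only for the internal ≥2D weight matrices, with everything else handed to AdamW. A minimal, hypothetical usage sketch of that split (the module shapes, hyperparameter values, and import path below are illustrative, not taken from this commit):

```python
import torch
from torch import nn

from pytorch_optimizer import Muon  # assumed import path for the class in this diff

# Hypothetical toy model: an embedding, a hidden linear layer, and a classifier head.
embed = nn.Embedding(1000, 64)
hidden = nn.Linear(64, 64)
head = nn.Linear(64, 10)

# Internal >=2D weights go to Muon; the embedding, the head, and all 0/1-D
# parameters (e.g. biases) go to a separate AdamW, as recommended above.
muon_params = [hidden.weight]
adamw_params = list(embed.parameters()) + list(head.parameters()) + [hidden.bias]

optimizer = Muon(muon_params, lr=2e-2, weight_decay=1e-2)
adamw = torch.optim.AdamW(adamw_params, lr=3e-4, weight_decay=0.0)
```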
@@ -66,8 +79,11 @@ def __init__(
        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
            'nesterov': nesterov,
            'ns_steps': ns_steps,
+            'use_adjusted_lr': use_adjusted_lr,
            'adamw_lr': adamw_lr,
            'adamw_lr_ratio': adamw_lr / lr,
            'adamw_betas': betas,
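The defaults also store `adamw_lr_ratio = adamw_lr / lr`. Presumably this lets the internal AdamW branch follow any external learning-rate schedule applied to `group['lr']`; a hedged sketch of that idea (the helper below is hypothetical, and the actual consumption point in `step()` is outside this diff):

```python
def internal_adamw_lr(group: dict) -> float:
    # Hypothetical helper: if a scheduler rescales group['lr'], the AdamW branch's
    # learning rate is rescaled by the same factor via the stored ratio.
    return group['adamw_lr_ratio'] * group['lr']
```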
@@ -114,6 +130,11 @@ def reset(self):
                state['moment1'] = torch.zeros_like(p)
                state['moment2'] = torch.zeros_like(p)

+    @staticmethod
+    def adjust_lr_for_muon(lr: float, param_shape) -> float:
+        adjusted_ratio: float = 0.2 * math.sqrt(max(param_shape[0], param_shape[1]))
+        return lr * adjusted_ratio
+
    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
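A quick worked example of `adjust_lr_for_muon` above, which scales the base learning rate by 0.2 * sqrt(max(m, n)) following the Moonlight reference (the matrix shape and base rate are made up for illustration):

```python
import math

# For a hypothetical 4096 x 1024 weight matrix with base lr = 2e-2:
#   adjusted lr = 2e-2 * 0.2 * sqrt(max(4096, 1024)) = 2e-2 * 0.2 * 64 = 0.256
base_lr = 2e-2
adjusted_lr = base_lr * 0.2 * math.sqrt(max(4096, 1024))
assert abs(adjusted_lr - 0.256) < 1e-12
```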
@@ -137,7 +158,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
            if len(params) == 0:
                continue

-            lr = group['lr']
            momentum = group['momentum']

            total_params: int = sum(p.numel() for p in params)
@@ -149,34 +169,42 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                    curr_idx += p.numel()
                    continue

-                g = p.grad
-                if g.ndim > 2:
-                    g = g.view(g.size(0), -1)
+                grad = p.grad
+                if grad.ndim > 2:
+                    grad = grad.view(grad.size(0), -1)

                state = self.state[p]
                if 'momentum_buffer' not in state:
-                    state['momentum_buffer'] = torch.zeros_like(g)
+                    state['momentum_buffer'] = torch.zeros_like(grad)

                buf = state['momentum_buffer']
-                buf.mul_(momentum).add_(g)
+                buf.lerp_(grad, weight=1.0 - momentum)

-                if group['nesterov']:
-                    g.add_(buf, alpha=momentum)
-                else:
-                    g = buf
+                grad = grad.lerp_(buf, momentum) if group['nesterov'] else buf

-                g = zero_power_via_newton_schulz_5(g, num_steps=group['ns_steps'])
-                g.mul_(max(1.0, g.size(0) / g.size(1)) ** 0.5)
+                grad = zero_power_via_newton_schulz_5(grad, num_steps=group['ns_steps']).flatten()

-                updates_flat[curr_idx:curr_idx + p.numel()] = g.flatten()  # fmt: skip
+                updates_flat[curr_idx:curr_idx + p.numel()] = grad  # fmt: skip

            if self.world_size > 1:  # pragma: no cover
                all_reduce(updates_flat, op=ReduceOp.SUM)

            curr_idx: int = 0
            for p in params:
-                g = updates_flat[curr_idx:curr_idx + p.numel()].view_as(p).type_as(p)  # fmt: skip
-                p.add_(g, alpha=-lr)
+                g = updates_flat[curr_idx:curr_idx + p.numel()].view_as(p)  # fmt: skip
+
+                self.apply_weight_decay(
+                    p,
+                    grad=g,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=group['weight_decouple'],
+                    fixed_decay=False,
+                )
+
+                lr: float = self.adjust_lr_for_muon(group['lr'], p.size()) if group['use_adjusted_lr'] else group['lr']
+
+                p.add_(g, alpha=-lr * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5))
                curr_idx += p.numel()

            params = [p for p in group['params'] if p.grad is not None and not self.state[p]['use_muon']]
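`zero_power_via_newton_schulz_5` itself is not part of this diff. For context, a quintic Newton-Schulz orthogonalization typically looks like the sketch below; the coefficients (3.4445, -4.7750, 2.0315) are the ones used in the original Muon reference implementation and are assumed here rather than taken from this repository's helper. The hunk above then rescales the orthogonalized update by sqrt(max(1, rows / cols)) and applies it with `p.add_(g, alpha=-lr * ...)`.

```python
import torch


def newton_schulz_orthogonalize(g: torch.Tensor, num_steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Sketch of a quintic Newton-Schulz iteration: it pushes the singular values of g
    # toward 1, i.e. approximates the nearest (semi-)orthogonal matrix, and can run in bfloat16.
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g.bfloat16()
    if x.size(-2) > x.size(-1):  # iterate on the wide orientation
        x = x.mT
    x = x / (x.norm(dim=(-2, -1), keepdim=True) + eps)  # Frobenius normalization bounds the spectral norm by 1
    for _ in range(num_steps):
        s = x @ x.mT
        x = a * x + (b * s + c * s @ s) @ x
    if g.size(-2) > g.size(-1):
        x = x.mT
    return x.to(g.dtype)
```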