@@ -34,8 +34,9 @@
 Dtype Behavior
 --------------
 - Newton-Schulz iterations: always bfloat16 (matches official Muon)
+- NS output (bfloat16) directly applied to parameters (PyTorch handles mixed precision)
 - Adam state (exp_avg, exp_avg_sq): always float32 for numerical stability
-- Muon gradients: cast to parameter dtype before momentum update
+- Muon momentum buffer: follows gradient dtype (grad -> buffer -> update)
 - Adam gradients: cast to float32 for update computation
 
 References
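A minimal sketch of the dtype contract listed above; the tensor names here are illustrative, not from the diff:

```python
import torch

# Parameter and gradient in bfloat16: the Muon momentum buffer follows the grad dtype
p = torch.nn.Parameter(torch.randn(128, 256, dtype=torch.bfloat16))
p.grad = torch.randn_like(p)
buf = torch.zeros_like(p.grad)  # bfloat16, per "grad -> buffer -> update"

# Adam state stays float32 regardless of the parameter dtype
exp_avg = torch.zeros_like(p, dtype=torch.float32)

# A bfloat16 NS output can be added straight into a float32 weight;
# in-place add_ promotes the bf16 operand, so no explicit cast is needed
master = torch.randn(128, 256)                        # float32
delta = torch.randn(128, 256, dtype=torch.bfloat16)   # stands in for the NS output
master.add_(delta, alpha=-0.02)
```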
@@ -75,10 +76,8 @@
 
 # Newton-Schulz iteration count
 NS_STEPS: int = 5
-# Numerical stability epsilon for norm clamping
-NS_EPS: float = 1e-7
-# Adam epsilon for numerical stability
-ADAM_EPS: float = 1e-7
+# Numerical stability epsilon for norm clamping and Adam
+EPS: float = 1e-7
 # Quintic Newton-Schulz polynomial coefficients
 NS_COEFF_A: float = 3.4445
 NS_COEFF_B: float = -4.7750
@@ -118,7 +117,7 @@ def _zeropower_via_newtonschulz5_2d(
         X = X.transpose(-2, -1)
 
     # === Step 2. Normalize Frobenius norm to at most 1 ===
-    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS)
+    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)
 
     # === Step 3. Newton-Schulz iterations with fused GEMM ===
     for _ in range(NS_STEPS):
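The diff only shows fragments of `_zeropower_via_newtonschulz5_2d`, so here is a self-contained sketch of the quintic iteration it implements, using the official Muon coefficients (the first two appear above as `NS_COEFF_A`/`NS_COEFF_B`; the third, 2.0315, is not shown in the excerpt); the helper name is mine:

```python
import torch

def newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize G via quintic Newton-Schulz in bfloat16."""
    a, b, c = 3.4445, -4.7750, 2.0315  # official Muon quintic coefficients
    X = G.bfloat16()
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.transpose(-2, -1)  # iterate on the wide orientation
    # Normalize so the spectral norm is <= 1 (Frobenius bound), as in Step 2 above
    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        B = b * A + c * A @ A       # fuse the quintic polynomial into two GEMMs
        X = a * X + B @ X
    if transposed:
        X = X.transpose(-2, -1)
    return X
```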
@@ -152,7 +151,7 @@ def _zeropower_via_newtonschulz5_3d(
         X = X.transpose(-2, -1)
 
     # === Step 2. Normalize Frobenius norm to at most 1 ===
-    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS)
+    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)
 
     # === Step 3. Newton-Schulz iterations with batched fused GEMM ===
     for _ in range(NS_STEPS):
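Because the normalization and transposes operate on `dim=(-2, -1)`, the same iteration runs batched over a leading dimension; a usage sketch with the hypothetical `newtonschulz5` helper from the previous block:

```python
# Stack eight same-shape update matrices and orthogonalize them in one batched call
updates = torch.stack(
    [torch.randn(256, 512, dtype=torch.bfloat16) for _ in range(8)]
)
orth = newtonschulz5(updates)  # shape (8, 256, 512), still bfloat16
```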
@@ -270,7 +269,7 @@ class HybridMuonOptimizer(Optimizer):
     momentum : float
         Momentum coefficient for Muon with default 0.95.
     weight_decay : float
-        Weight decay coefficient (applied only to >=2D params) with default 0.001.
+        Weight decay coefficient (applied only to Muon-routed parameters) with default 0.001.
     adam_betas : tuple[float, float]
         Adam beta coefficients with default (0.9, 0.95).
     lr_adjust : float
@@ -287,6 +286,11 @@ class HybridMuonOptimizer(Optimizer):
         2. For 2D Adam fallback: learning rate multiplier,
            adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1).
            The min(., 0.1) cap ensures conservative updates for small matrices.
+    muon_2d_only : bool
+        If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon);
+        parameters with ndim > 2 use Adam without weight decay.
+        If False, all >=2D parameters use Muon (the previous behavior).
+        Default is True.
     min_2d_dim : int
         Minimum min(m, n) threshold for Muon on 2D matrices.
         Matrices with min(m, n) >= min_2d_dim use Muon;
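A hypothetical construction showing the new flag next to the documented defaults (the `lr` value and the model are illustrative, not from the diff):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4096, 1024), torch.nn.ReLU())
opt = HybridMuonOptimizer(
    model.parameters(),
    lr=0.02,               # illustrative; the default lr is not shown in this excerpt
    momentum=0.95,
    weight_decay=0.001,
    adam_betas=(0.9, 0.95),
    lr_adjust=10.0,
    lr_adjust_coeff=0.2,
    muon_2d_only=True,     # ndim > 2 (e.g. conv kernels) now take the Adam path
)
# Per the docstring: adam_lr = 0.02 / 10.0 = 2e-3, and for the small-matrix
# fallback adam_lr_matrix = 2e-3 * min(0.2, 0.1) = 2e-4 (the 0.1 cap binds).
```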
@@ -313,6 +317,7 @@ def __init__(
         adam_betas: tuple[float, float] = (0.9, 0.95),
         lr_adjust: float = 10.0,
         lr_adjust_coeff: float = 0.2,
+        muon_2d_only: bool = True,
         min_2d_dim: int = 1,
     ) -> None:
         if min_2d_dim < 1:
@@ -325,6 +330,7 @@ def __init__(
             "adam_betas": adam_betas,
             "lr_adjust": lr_adjust,
             "lr_adjust_coeff": lr_adjust_coeff,
+            "muon_2d_only": muon_2d_only,
             "min_2d_dim": min_2d_dim,
         }
         super().__init__(params, defaults)
@@ -337,9 +343,11 @@ def _build_param_routing(self) -> None:
         Classify parameters into Muon and Adam routes (static routing).
 
         Routing logic:
-        - >=2D parameters with min(m, n) >= min_2d_dim → Muon path
-        - 2D parameters with min(m, n) < min_2d_dim → Adam fallback path
         - 1D parameters → Adam path
+        - >2D parameters (when muon_2d_only=True) → Adam path
+        - 2D parameters with min(m, n) < min_2d_dim → Adam fallback path
+        - 2D parameters with min(m, n) >= min_2d_dim → Muon path
+        - >=2D parameters (when muon_2d_only=False) → Muon path
         """
         if self._routing_built:
             return
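The routing rules read more easily as a standalone predicate; a sketch mirroring the documented logic (the function name is mine, not from the diff):

```python
import torch

def route_param(p: torch.Tensor, muon_2d_only: bool = True, min_2d_dim: int = 1) -> str:
    """Mirror of the routing rules documented above."""
    if p.ndim < 2:
        return "adam_1d"       # biases, norm scales
    if muon_2d_only and p.ndim > 2:
        return "adam_nd"       # e.g. conv kernels
    if p.ndim == 2 and min(p.shape) < min_2d_dim:
        return "adam_matrix"   # small-matrix Adam fallback
    return "muon"              # everything else gets Newton-Schulz

assert route_param(torch.empty(4096)) == "adam_1d"
assert route_param(torch.empty(64, 64, 3, 3)) == "adam_nd"
assert route_param(torch.empty(1024, 4096)) == "muon"
```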
@@ -349,14 +357,23 @@ def _build_param_routing(self) -> None:
             muon_params: list[dict[str, Any]] = []
             adam_1d: list[dict[str, Any]] = []
             adam_matrix: list[dict[str, Any]] = []
+            adam_nd: list[dict[str, Any]] = []
 
             min_2d_dim = group["min_2d_dim"]
+            muon_2d_only = group["muon_2d_only"]
 
             for p in group["params"]:
+                # === Step 1. 1D parameters → Adam ===
                 if p.ndim < 2:
                     adam_1d.append({"param": p})
                     continue
 
+                # === Step 2. >2D parameters (when muon_2d_only=True) → Adam ===
+                if muon_2d_only and p.ndim > 2:
+                    adam_nd.append({"param": p})
+                    continue
+
+                # === Step 3. 2D small matrices → Adam fallback ===
                 if (p.ndim == 2) and should_fallback_to_adam_for_matrix(
                     p, min_2d_dim=min_2d_dim
                 ):
@@ -368,6 +385,7 @@ def _build_param_routing(self) -> None:
                     )
                     continue
 
+                # === Step 4. Remaining >=2D parameters → Muon (2D only when muon_2d_only=True) ===
                 muon_params.append(
                     {
                         "param": p,
@@ -381,6 +399,7 @@ def _build_param_routing(self) -> None:
                     "muon_params": muon_params,
                     "adam_1d": adam_1d,
                     "adam_matrix": adam_matrix,
+                    "adam_nd": adam_nd,
                 }
             )
 
@@ -470,12 +489,67 @@ def step(
                     bias_corr2 = 1 - state["beta2_pow"]
                     step_size = adam_lr / bias_corr1
                     # delta = -step_size * m_hat / (sqrt(v_hat) + eps)
-                    denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS)
+                    denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS)
                     delta_fp32 = -step_size * (adam_exp_avgs[i] / denom)
                     p.add_(delta_fp32.to(p.dtype))
 
-            # === Step 2. Adam update for small 2D matrices (fallback path) ===
+            # === Step 2. Adam update for >2D parameters (when muon_2d_only=True) ===
             # === Step 2.1. Collect gradients and initialize state ===
+            adam_nd_params: list[torch.Tensor] = []
+            adam_nd_grads_fp32: list[torch.Tensor] = []
+            adam_nd_exp_avgs: list[torch.Tensor] = []
+            adam_nd_exp_avg_sqs: list[torch.Tensor] = []
+            adam_nd_states: list[dict[str, Any]] = []
+
+            for entry in route.get("adam_nd", []):
+                p = entry["param"]
+                grad = p.grad
+                if grad is None:
+                    continue
+
+                grad_fp32 = grad.float()
+
+                state = self.state[p]
+                if "exp_avg" not in state:
+                    state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
+                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)
+                    state["beta1_pow"] = 1.0
+                    state["beta2_pow"] = 1.0
+
+                state["beta1_pow"] *= adam_betas[0]
+                state["beta2_pow"] *= adam_betas[1]
+
+                adam_nd_params.append(p)
+                adam_nd_grads_fp32.append(grad_fp32)
+                adam_nd_exp_avgs.append(state["exp_avg"])
+                adam_nd_exp_avg_sqs.append(state["exp_avg_sq"])
+                adam_nd_states.append(state)
+
+            if adam_nd_params:
+                # === Step 2.2. Update exp_avg / exp_avg_sq ===
+                adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust
+
+                # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
+                # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
+                torch._foreach_lerp_(
+                    adam_nd_exp_avgs, adam_nd_grads_fp32, 1 - adam_betas[0]
+                )
+                grad_sq = torch._foreach_mul(adam_nd_grads_fp32, adam_nd_grads_fp32)
+                torch._foreach_lerp_(adam_nd_exp_avg_sqs, grad_sq, 1 - adam_betas[1])
+
+                # === Step 2.3. Bias correction and parameter update ===
+                for i, p in enumerate(adam_nd_params):
+                    state = adam_nd_states[i]
+                    bias_corr1 = 1 - state["beta1_pow"]
+                    bias_corr2 = 1 - state["beta2_pow"]
+                    step_size = adam_lr / bias_corr1
+                    # delta = -step_size * m_hat / (sqrt(v_hat) + eps)
+                    denom = (adam_nd_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS)
+                    delta_fp32 = -step_size * (adam_nd_exp_avgs[i] / denom)
+                    p.add_(delta_fp32.to(p.dtype))
+
+            # === Step 3. Adam update for small 2D matrices (fallback path) ===
+            # === Step 3.1. Collect gradients and initialize state ===
             adam_matrix_params: list[torch.Tensor] = []
             adam_matrix_grads_fp32: list[torch.Tensor] = []
             adam_matrix_exp_avgs: list[torch.Tensor] = []
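The same bias-corrected delta is computed three times in `step()` (1D, >2D, and matrix paths); as a single-tensor sketch, with a helper name that is mine:

```python
import torch

def adam_delta(exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor,
               beta1_pow: float, beta2_pow: float,
               lr: float, eps: float = 1e-7) -> torch.Tensor:
    """delta = -lr/(1-beta1^t) * m / (sqrt(v/(1-beta2^t)) + eps), as above."""
    step_size = lr / (1 - beta1_pow)
    denom = (exp_avg_sq / (1 - beta2_pow)).sqrt().add_(eps)
    return -step_size * (exp_avg / denom)
```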
@@ -509,7 +583,7 @@ def step(
                 adam_matrix_abs_floor.append(entry["abs_floor"])
 
             if adam_matrix_params:
-                # === Step 2.2. Update exp_avg / exp_avg_sq with scaled lr ===
+                # === Step 3.2. Update exp_avg / exp_avg_sq with scaled lr ===
                 adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust
                 adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1)
 
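The `torch._foreach_lerp_` calls above compute the exponential moving averages in one fused pass; the single-tensor identity, as a runnable check:

```python
import torch

beta1 = 0.9
exp_avg = torch.zeros(3)
grad = torch.ones(3)
# lerp_(end, weight): self + weight*(end - self) == beta1*exp_avg + (1-beta1)*grad
exp_avg.lerp_(grad, 1 - beta1)
assert torch.allclose(exp_avg, torch.full((3,), 1 - beta1))
```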
@@ -525,19 +599,17 @@ def step(
                     adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1]
                 )
 
-                # === Step 2.3. Compute unclipped deltas ===
+                # === Step 3.3. Compute unclipped deltas ===
                 raw_deltas: list[torch.Tensor] = []
                 for i in range(len(adam_matrix_params)):
                     state = adam_matrix_states[i]
                     bias_corr1 = 1 - state["beta1_pow"]
                     bias_corr2 = 1 - state["beta2_pow"]
                     step_size = adam_lr_matrix / bias_corr1
-                    denom = (
-                        (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS)
-                    )
+                    denom = (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS)
                     raw_deltas.append(-step_size * (adam_matrix_exp_avgs[i] / denom))
 
-                # === Step 2.4. Clip updates by relative norm and apply ===
+                # === Step 3.4. Clip updates by relative norm and apply ===
                 max_rel_change = 0.05
                 p_norms = torch.stack(torch._foreach_norm(adam_matrix_params))
                 delta_norms = torch.stack(torch._foreach_norm(raw_deltas))
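Step 3.4 caps each fallback delta at 5% of the parameter norm; a single-tensor sketch of that clip (the diff does it batched via `torch.stack`/`torch._foreach_norm`, and how the collected `abs_floor` values enter is not shown in this excerpt):

```python
import torch

max_rel_change = 0.05
p = torch.randn(8, 8)
delta = torch.randn(8, 8)

# Scale down (never up) so that ||applied delta|| <= 0.05 * ||p||
scale = (max_rel_change * p.norm() / delta.norm().clamp(min=1e-12)).clamp(max=1.0)
p.add_(delta * scale)
```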
@@ -553,8 +625,8 @@ def step(
                 ):
                     p.add_(delta.mul_(scales_tensor[i]).to(p.dtype))
 
-            # === Step 3. Muon update for >=2D parameters (weight matrices) ===
-            # === Step 3.1. Collect gradients and initialize momentum ===
+            # === Step 4. Muon update for >=2D parameters (weight matrices) ===
+            # === Step 4.1. Collect gradients and initialize momentum ===
             muon_params_for_decay: list[torch.Tensor] = []
             muon_grads: list[torch.Tensor] = []
             muon_momentum_buffers: list[torch.Tensor] = []
@@ -579,22 +651,22 @@ def step(
                 muon_momentum_buffers.append(buf)
                 active_entries.append((entry, grad))
 
-            # === Step 3.2. Apply weight decay (Muon path only) ===
+            # === Step 4.2. Apply weight decay (Muon path only) ===
             if weight_decay > 0 and muon_params_for_decay:
                 torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay)
 
             if not active_entries:
                 continue
 
-            # === Step 3.3. Momentum update (Nesterov) ===
+            # === Step 4.3. Momentum update (Nesterov) ===
             # m_t = beta * m_{t-1} + (1 - beta) * g_t
             torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum)
             # update = beta * m_t + (1 - beta) * g_t
             muon_updates = torch._foreach_lerp(
                 muon_grads, muon_momentum_buffers, momentum
             )
 
-            # === Step 3.4. Bucket by shape/device/dtype for batched NS ===
+            # === Step 4.4. Bucket by shape/device/dtype for batched NS ===
             buckets: dict[
                 tuple[int, int, torch.device, torch.dtype],
                 list[tuple[dict[str, Any], torch.Tensor]],
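The two lerp calls in Step 4.3 implement the Nesterov-style update without extra temporaries; the single-tensor equivalent, as a runnable check:

```python
import torch

momentum = 0.95
buf = torch.zeros(3)
grad = torch.ones(3)

buf.lerp_(grad, 1 - momentum)      # m_t = beta * m_{t-1} + (1 - beta) * g_t
update = grad.lerp(buf, momentum)  # update = (1 - beta) * g_t + beta * m_t
assert torch.allclose(update, (1 - momentum) * grad + momentum * buf)
```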
@@ -608,8 +680,8 @@ def step(
                     buckets[bucket_key] = []
                 buckets[bucket_key].append((entry, muon_updates[idx]))
 
-            # === Step 3.5. Newton-Schulz orthogonalization and update ===
-            for (rows, cols, _device, dtype), bucket_entries in buckets.items():
+            # === Step 4.5. Newton-Schulz orthogonalization and update ===
+            for (rows, cols, _device, _), bucket_entries in buckets.items():
                 # scale = coeff * sqrt(max(m, n))  [match-RMS mode]
                 # scale = sqrt(max(1, m/n))        [rectangular mode]
                 if lr_adjust <= 0:
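Worked values for the two scale modes named in the comments above (the branch that selects each mode is cut off in this excerpt, and the `coeff` value here is illustrative):

```python
import math

rows, cols = 1024, 4096
print(math.sqrt(max(1.0, rows / cols)))  # rectangular mode -> 1.0 (since m <= n)
print(0.2 * math.sqrt(max(rows, cols)))  # match-RMS mode with coeff = 0.2 -> 12.8
```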
@@ -626,8 +698,6 @@ def step(
                     orth = _zeropower_via_newtonschulz5_2d(update_matrix)
                     orth.mul_(scale)
                     delta = orth.reshape(entry["param"].shape)
-                    if delta.dtype != dtype:
-                        delta = delta.to(dtype)
                     entry["param"].add_(delta, alpha=-lr)
                     continue
 
@@ -648,8 +718,6 @@ def step(
                 stacked = torch.stack(matrices, dim=0)
                 orth = _zeropower_via_newtonschulz5_3d(stacked)
                 orth.mul_(scale)
-                if orth.dtype != dtype:
-                    orth = orth.to(dtype)
 
                 for i, _ in enumerate(bucket_entries):
                     params[i].add_(orth[i].reshape(orig_shapes[i]), alpha=-lr)
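End to end, the optimizer drops into a standard training loop; a hypothetical usage sketch (`model`, `loader`, and `loss_fn` are assumed, and the `lr` value is illustrative):

```python
opt = HybridMuonOptimizer(model.parameters(), lr=0.02)

for x, y in loader:
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()  # parameters take the Muon or Adam path per the routing above
```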