1414# limitations under the License.
1515
1616
17- from typing import Optional
17+ from typing import Literal , Optional
1818
1919import torch
2020from torch .optim .optimizer import ParamsT
@@ -36,7 +36,7 @@ class MOP(OrthogonalizedOptimizer):
3636
3737 Args:
3838 {_args_doc}
39- scale_mode: The type of scale factor to use for the update. Defaults to "spectral" style scaling.
39+ scale_mode: The type of scale factor to use for the update. Defaults to "nuclear_norm" style scaling.
4040 extra_scale_factor: The additional scale factor to use for the update.
4141 """
4242
@@ -50,21 +50,25 @@ def __init__(
5050 use_nesterov : bool = False ,
5151 weight_decay_method : WeightDecayT = "decoupled" ,
5252 fp32_matmul_prec : str = "highest" ,
53- scale_mode : str = "spectral " ,
53+ scale_mode : muon . MuonScaleT | Literal [ "nuclear_norm" ] = "nuclear_norm " ,
5454 extra_scale_factor : float = 1.0 ,
5555 ) -> None :
5656 def scaled_orthogonalize_fn (grad : torch .Tensor ) -> torch .Tensor :
57- orth_grad , _ = polar_via_svd (grad , False )
57+ orth_grad , _ , S = polar_via_svd (grad , False )
5858
59- scale_factor = muon .get_muon_scale_factor (grad .size (- 2 ), grad .size (- 1 ), mode = scale_mode )
59+ if scale_mode != "nuclear_norm" :
60+ scale_factor = muon .get_muon_scale_factor (grad .size (- 2 ), grad .size (- 1 ), mode = scale_mode )
61+ else :
62+ # Scale by the square root of the nuclear norm (sum of singular values); nuclear-norm scaling is suggested by the PolarGrad paper (https://arxiv.org/pdf/2505.21799).
63+ scale_factor = S .sum ().sqrt ()
6064 return orth_grad * scale_factor * extra_scale_factor
6165
6266 super ().__init__ (
6367 params ,
6468 lr ,
6569 momentum_beta ,
70+ weight_decay ,
6671 use_nesterov = use_nesterov ,
67- weight_decay = weight_decay ,
6872 weight_decay_method = weight_decay_method ,
6973 fp32_matmul_prec = fp32_matmul_prec ,
7074 scaled_orthogonalize_fn = scaled_orthogonalize_fn ,
@@ -74,7 +78,9 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
7478MOP .__doc__ = MOP .__doc__ .format (_args_doc = _args_doc ) # type: ignore[union-attr]
7579
7680
77- def polar_via_svd (A : torch .Tensor , return_p : bool = False ) -> tuple [torch .Tensor , Optional [torch .Tensor ]]:
81+ def polar_via_svd (
82+ A : torch .Tensor , return_p : bool = False
83+ ) -> tuple [torch .Tensor , Optional [torch .Tensor ], torch .Tensor ]:
7884 """Compute polar decomposition via SVD
7985
8086 Args:
@@ -87,12 +93,13 @@ def polar_via_svd(A: torch.Tensor, return_p: bool = False) -> tuple[torch.Tensor
8793 A tuple containing:
8894 - The unitary part of the polar decomposition.
8995 - The positive-semidefinite part of the polar decomposition, if return_p is True.
96+ - The singular values of the input tensor.
9097 """
9198 U_svd , S , Vh = torch .linalg .svd (A , full_matrices = False )
9299 U_polar = U_svd @ Vh
93100
94101 if not return_p :
95- return U_polar , None
102+ return U_polar , None , S
96103 else :
97104 p = Vh .mH @ torch .diag (S ) @ Vh
98- return U_polar , p
105+ return U_polar , p , S
0 commit comments