Skip to content

Commit 0eaad00

Browse files
committed
support polar grad scale in mop
Signed-off-by: Hao Wu <[email protected]>
1 parent 255bfc6 commit 0eaad00

File tree

4 files changed

+30
-16
lines changed

4 files changed

+30
-16
lines changed

emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626

2727
from emerging_optimizers import mixin as opt_mixin
2828
from emerging_optimizers import utils
29-
from emerging_optimizers.orthogonalized_optimizers.muon import Muon
29+
from emerging_optimizers.orthogonalized_optimizers import muon
3030

3131

32-
class AdaptiveMuon(Muon):
32+
class AdaptiveMuon(muon.Muon):
3333
"""Adaptive Muon optimizer with adaptive second moment (AdaMuon/NorMuon variants).
3434
3535
This class extends Muon by adding AdamW-style or NorMuon-style second moment
@@ -68,7 +68,7 @@ def __init__(
6868
fp32_matmul_prec: str,
6969
coefficient_type: str = "quintic",
7070
num_ns_steps: int = 5,
71-
scale_mode: str = "spectral",
71+
scale_mode: muon.MuonScaleT = "spectral",
7272
extra_scale_factor: float = 1.0,
7373
use_syrk: bool = False,
7474
moment2_method: Literal["adamuon", "normuon"] = "adamuon",

emerging_optimizers/orthogonalized_optimizers/mop.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616

17-
from typing import Optional
17+
from typing import Literal, Optional
1818

1919
import torch
2020
from torch.optim.optimizer import ParamsT
@@ -36,7 +36,7 @@ class MOP(OrthogonalizedOptimizer):
3636
3737
Args:
3838
{_args_doc}
39-
scale_mode: The type of scale factor to use for the update. Defaults to "spectral" style scaling.
39+
scale_mode: The type of scale factor to use for the update. One of "shape_scaling", "spectral", "unit_rms_norm" or "nuclear_norm". Defaults to "nuclear_norm" style scaling.
4040
extra_scale_factor: The additional scale factor to use for the update.
4141
"""
4242

@@ -50,21 +50,25 @@ def __init__(
5050
use_nesterov: bool = False,
5151
weight_decay_method: WeightDecayT = "decoupled",
5252
fp32_matmul_prec: str = "highest",
53-
scale_mode: str = "spectral",
53+
scale_mode: muon.MuonScaleT | Literal["nuclear_norm"] = "nuclear_norm",
5454
extra_scale_factor: float = 1.0,
5555
) -> None:
5656
def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
57-
orth_grad, _ = polar_via_svd(grad, False)
57+
orth_grad, _, S = polar_via_svd(grad, False)
5858

59-
scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
59+
if scale_mode != "nuclear_norm":
60+
scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
61+
else:
62+
# Nuclear norm scaling suggested by the PolarGrad paper (https://arxiv.org/pdf/2505.21799); the factor used here is the square root of the sum of the singular values.
63+
scale_factor = S.sum().sqrt()
6064
return orth_grad * scale_factor * extra_scale_factor
6165

6266
super().__init__(
6367
params,
6468
lr,
6569
momentum_beta,
70+
weight_decay,
6671
use_nesterov=use_nesterov,
67-
weight_decay=weight_decay,
6872
weight_decay_method=weight_decay_method,
6973
fp32_matmul_prec=fp32_matmul_prec,
7074
scaled_orthogonalize_fn=scaled_orthogonalize_fn,
@@ -74,7 +78,9 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
7478
MOP.__doc__ = MOP.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr]
7579

7680

77-
def polar_via_svd(A: torch.Tensor, return_p: bool = False) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
81+
def polar_via_svd(
82+
A: torch.Tensor, return_p: bool = False
83+
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
7884
"""Compute polar decomposition via SVD
7985
8086
Args:
@@ -87,12 +93,13 @@ def polar_via_svd(A: torch.Tensor, return_p: bool = False) -> tuple[torch.Tensor
8793
A tuple containing:
8894
- The unitary part of the polar decomposition.
8995
- The positive-semidefinite part of the polar decomposition, if return_p is True.
96+
- The singular values of the input tensor.
9097
"""
9198
U_svd, S, Vh = torch.linalg.svd(A, full_matrices=False)
9299
U_polar = U_svd @ Vh
93100

94101
if not return_p:
95-
return U_polar, None
102+
return U_polar, None, S
96103
else:
97104
p = Vh.mH @ torch.diag(S) @ Vh
98-
return U_polar, p
105+
return U_polar, p, S

emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from typing import Literal
17+
1618
import torch
1719
from absl import logging
1820
from torch.optim.optimizer import ParamsT
@@ -23,6 +25,9 @@
2325
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
2426

2527

28+
MuonScaleT = Literal["shape_scaling", "spectral", "unit_rms_norm"]
29+
30+
2631
class Muon(OrthogonalizedOptimizer):
2732
"""Muon: MomentUm Orthogonalized by Newton-schulz
2833
@@ -72,7 +77,7 @@ def __init__(
7277
fp32_matmul_prec: str = "medium",
7378
coefficient_type: str = "quintic",
7479
num_ns_steps: int = 5,
75-
scale_mode: str = "spectral",
80+
scale_mode: MuonScaleT = "spectral",
7681
extra_scale_factor: float = 1.0,
7782
use_syrk: bool = False,
7883
) -> None:
@@ -122,7 +127,7 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
122127
Muon.__doc__ = Muon.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr]
123128

124129

125-
def get_muon_scale_factor(size_out: int, size_in: int, mode: str = "spectral") -> float:
130+
def get_muon_scale_factor(size_out: int, size_in: int, mode: MuonScaleT = "spectral") -> float:
126131
"""Get the scale for the update.
127132
128133
Default mode is "spectral", which is the mode that allows for learning rate transferability from AdamW.

tests/test_orthogonalized_optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,16 +254,18 @@ class MopTest(parameterized.TestCase):
254254
shape=[(5, 7), (33, 65), (127, 257)],
255255
weight_decay_method=["decoupled", "independent"],
256256
use_nesterov=[True, False],
257-
extra_scale_factor=[1.0, 2.0],
257+
scale_mode=["spectral", "nuclear_norm"],
258+
extra_scale_factor=[1.0, 0.2],
258259
)
259-
def test_smoke(self, shape, weight_decay_method, use_nesterov, extra_scale_factor) -> None:
260+
def test_smoke(self, shape, weight_decay_method, use_nesterov, scale_mode, extra_scale_factor) -> None:
260261
test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda"))
261262
test_param.grad = torch.randint_like(test_param, -5, 5)
262263

263264
mop_opt = mop.MOP(
264265
[test_param],
265266
weight_decay_method=weight_decay_method,
266267
use_nesterov=use_nesterov,
268+
scale_mode=scale_mode,
267269
extra_scale_factor=extra_scale_factor,
268270
)
269271
mop_opt.step()

0 commit comments

Comments
 (0)