
Commit 255bfc6

Add Mop (#82)
* add mop

Signed-off-by: Hao Wu <[email protected]>
1 parent 93d9eb3

File tree: 4 files changed, +126 -1 lines changed


docs/apidocs/orthogonalized-optimizers.md

Lines changed: 6 additions & 0 deletions
@@ -27,6 +27,12 @@ emerging_optimizers.orthogonalized_optimizers
 .. autoclass:: Scion
     :members:
 
+:hidden:`Mop`
+~~~~~~~~~~~~~~~
+
+.. autoclass:: MOP
+    :members:
+
 
 :hidden:`Newton-Schulz`
 ~~~~~~~~~~~~~~~~~~~~~~~~

emerging_optimizers/orthogonalized_optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import *
+from emerging_optimizers.orthogonalized_optimizers.mop import *
 from emerging_optimizers.orthogonalized_optimizers.muon import *
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
 from emerging_optimizers.orthogonalized_optimizers.scion import *
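With the wildcard re-export above, and given that the new module's __all__ is ["MOP"], the class should be importable from the subpackage directly. A minimal check, assuming the package is installed:

    # MOP is re-exported via `from ...mop import *` together with mop.__all__
    from emerging_optimizers.orthogonalized_optimizers import MOP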
emerging_optimizers/orthogonalized_optimizers/mop.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional
+
+import torch
+from torch.optim.optimizer import ParamsT
+
+from emerging_optimizers.mixin import WeightDecayT
+from emerging_optimizers.orthogonalized_optimizers import muon
+from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
+
+
+__all__ = ["MOP"]
+
+
+class MOP(OrthogonalizedOptimizer):
+    """MOP: Momentum Orthogonalized by Polar decomposition.
+
+    .. warning::
+        This optimizer is experimental and not yet thoroughly tested.
+
+    Args:
+        {_args_doc}
+        scale_mode: The type of scale factor to use for the update. Defaults to "spectral" style scaling.
+        extra_scale_factor: The additional scale factor to use for the update.
+    """
+
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: float = 3e-4,
+        momentum_beta: float = 0.95,
+        weight_decay: float = 0.01,
+        *,
+        use_nesterov: bool = False,
+        weight_decay_method: WeightDecayT = "decoupled",
+        fp32_matmul_prec: str = "highest",
+        scale_mode: str = "spectral",
+        extra_scale_factor: float = 1.0,
+    ) -> None:
+        def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
+            orth_grad, _ = polar_via_svd(grad, False)
+
+            scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
+            return orth_grad * scale_factor * extra_scale_factor
+
+        super().__init__(
+            params,
+            lr,
+            momentum_beta,
+            use_nesterov=use_nesterov,
+            weight_decay=weight_decay,
+            weight_decay_method=weight_decay_method,
+            fp32_matmul_prec=fp32_matmul_prec,
+            scaled_orthogonalize_fn=scaled_orthogonalize_fn,
+        )
+
+
+MOP.__doc__ = MOP.__doc__.format(_args_doc=_args_doc)  # type: ignore[union-attr]
+
+
+def polar_via_svd(A: torch.Tensor, return_p: bool = False) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """Compute polar decomposition via SVD.
+
+    Args:
+        A: The input tensor to compute the polar decomposition of.
+        return_p: Whether to return the positive-semidefinite part of the polar decomposition. p is not needed
+            by the MOP optimizer, so by default it is not calculated, to save computation. The option is provided
+            to return the full polar decomposition, to match the function name.
+
+    Returns:
+        A tuple containing:
+            - The unitary part of the polar decomposition.
+            - The positive-semidefinite part of the polar decomposition, if return_p is True.
+    """
+    U_svd, S, Vh = torch.linalg.svd(A, full_matrices=False)
+    U_polar = U_svd @ Vh
+
+    if not return_p:
+        return U_polar, None
+    else:
+        p = Vh.mH @ torch.diag(S) @ Vh
+        return U_polar, p
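Since polar_via_svd is the numerical core of this commit, it is worth spelling out the identities it relies on: with the reduced SVD A = U_svd S Vh, the unitary factor is U = U_svd @ Vh and the positive-semidefinite factor is P = Vh.mH @ diag(S) @ Vh, so U @ P reconstructs A and U is semi-orthogonal. A minimal sanity-check sketch, assuming the module path shown in the diff and using float64 so default tolerances are comfortable:

    import torch

    from emerging_optimizers.orthogonalized_optimizers.mop import polar_via_svd

    A = torch.randn(5, 7, dtype=torch.float64)
    U, P = polar_via_svd(A, return_p=True)

    # The polar factors reconstruct the input: A = U @ P
    torch.testing.assert_close(U @ P, A)
    # U is semi-orthogonal: for a wide input (m <= n), its rows are orthonormal
    torch.testing.assert_close(U @ U.mH, torch.eye(5, dtype=torch.float64))
    # P is symmetric (and positive semi-definite by construction)
    torch.testing.assert_close(P, P.mH)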

tests/test_orthogonalized_optimizer.py

Lines changed: 21 additions & 1 deletion
@@ -17,7 +17,7 @@
 import torch.nn as nn
 from absl.testing import absltest, parameterized
 
-from emerging_optimizers.orthogonalized_optimizers import muon, scion
+from emerging_optimizers.orthogonalized_optimizers import mop, muon, scion
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer
 
 
@@ -249,5 +249,25 @@ def test_smoke(self, shape) -> None:
         scion_opt.step()
 
 
+class MopTest(parameterized.TestCase):
+    @parameterized.product(
+        shape=[(5, 7), (33, 65), (127, 257)],
+        weight_decay_method=["decoupled", "independent"],
+        use_nesterov=[True, False],
+        extra_scale_factor=[1.0, 2.0],
+    )
+    def test_smoke(self, shape, weight_decay_method, use_nesterov, extra_scale_factor) -> None:
+        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda"))
+        test_param.grad = torch.randint_like(test_param, -5, 5)
+
+        mop_opt = mop.MOP(
+            [test_param],
+            weight_decay_method=weight_decay_method,
+            use_nesterov=use_nesterov,
+            extra_scale_factor=extra_scale_factor,
+        )
+        mop_opt.step()
+
+
 if __name__ == "__main__":
     absltest.main()
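Beyond the smoke test, here is a minimal end-to-end usage sketch. The layer shape and loss are illustrative, it runs on CPU rather than the CUDA device the committed test targets, and per the constructor signature above only params is required:

    import torch
    import torch.nn as nn

    from emerging_optimizers.orthogonalized_optimizers import mop

    layer = nn.Linear(65, 33)
    # Hand MOP the 2-D weight only; the update orthogonalizes matrix-shaped momentum,
    # so the 1-D bias is left to another optimizer or excluded entirely.
    opt = mop.MOP([layer.weight], lr=3e-4, momentum_beta=0.95)

    loss = layer(torch.randn(8, 65)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()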
