
Commit 7d604f1

Add adaptive learning rate for Muon: NorMuon and AdaMuon (#76)
* support adaptive learning rate for Muon: NorMuon and AdaMuon

Signed-off-by: mikail <[email protected]>
1 parent fe29e56 commit 7d604f1

File tree

5 files changed: +363 −0 lines changed

emerging_optimizers/orthogonalized_optimizers/__init__.py
emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py
tests/ci/L0_Tests_GPU.sh
tests/ci/L1_Tests_GPU.sh
tests/test_adaptive_muon.py
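
At a glance, both variants rescale Muon's orthogonalized update by an exponential moving average of squared gradients; they differ only in the granularity of the second moment buffer. The following standalone sketch is distilled from the new adaptive_muon.py below; the function names here are illustrative and not part of the package:

    import torch

    def adamuon_scale(orth_grad: torch.Tensor, moment2: torch.Tensor,
                      beta2: float = 0.95, eps: float = 1e-8) -> torch.Tensor:
        # AdaMuon: elementwise EMA of the squared orthogonalized gradient,
        # followed by AdamW-style division.
        moment2.lerp_(orth_grad.square(), 1 - beta2)
        return orth_grad / (moment2.sqrt() + eps)

    def normuon_scale(orth_grad: torch.Tensor, moment2: torch.Tensor,
                      beta2: float = 0.95, eps: float = 1e-8) -> torch.Tensor:
        # NorMuon: average squares along the longer dimension, keep a
        # row/column-wise buffer, then scale by the reciprocal square root.
        avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2
        moment2.lerp_(orth_grad.square().mean(dim=avg_dim, keepdim=True), 1 - beta2)
        return orth_grad * moment2.clamp_min(eps).rsqrt()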

emerging_optimizers/orthogonalized_optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import *
 from emerging_optimizers.orthogonalized_optimizers.muon import *
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
 from emerging_optimizers.orthogonalized_optimizers.scion import *
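
Assuming adaptive_muon exposes AdaptiveMuon among its public names (e.g. via __all__), the wildcard import added above makes the class available at the package level:

    # Hypothetical import path enabled by the re-export; assumes AdaptiveMuon
    # is in adaptive_muon's public namespace.
    from emerging_optimizers.orthogonalized_optimizers import AdaptiveMuon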
emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Literal


# TODO(@boxiangw): remove this once bump to python 3.12
try:
    from typing import override
except ImportError:
    from typing_extensions import override

import torch
from torch.optim.optimizer import ParamsT

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import utils
from emerging_optimizers.orthogonalized_optimizers.muon import Muon


class AdaptiveMuon(Muon):
    """Adaptive Muon optimizer with adaptive second moment (AdaMuon/NorMuon variants).

    This class extends Muon by adding AdamW-style or NorMuon-style second moment
    accumulation after orthogonalization. This idea was first explored in D.E. Carlson,
    E. Collins, Ya-Ping Hsieh, L. Carin, and V. Cevher, *Preconditioned spectral
    descent for deep learning*, Advances in Neural Information Processing Systems 28 (2015).
    The step() method is overridden to include the second moment normalization logic.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        momentum_beta: The exponential decay rate for momentum.
        weight_decay: Weight decay coefficient.
        use_nesterov: Whether to use Nesterov momentum.
        weight_decay_method: The weight decay method to use.
        fp32_matmul_prec: Precision for FP32 matrix multiplication.
        coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration.
        num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration.
        scale_mode: The type of scale factor to use for the update.
        extra_scale_factor: The additional scale factor to use for the update.
        use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
        moment2_method: Method for second moment accumulation ("adamuon" or "normuon").
        beta2: The exponential decay rate for the second moment.
        eps: Small constant for numerical stability.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float,
        momentum_beta: float,
        weight_decay: float,
        *,
        use_nesterov: bool,
        weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
        fp32_matmul_prec: str,
        coefficient_type: str = "quintic",
        num_ns_steps: int = 5,
        scale_mode: str = "spectral",
        extra_scale_factor: float = 1.0,
        use_syrk: bool = False,
        moment2_method: Literal["adamuon", "normuon"] = "adamuon",
        beta2: float = 0.95,
        eps: float = 1e-8,
    ):
        super().__init__(
            params,
            lr=lr,
            momentum_beta=momentum_beta,
            weight_decay=weight_decay,
            use_nesterov=use_nesterov,
            weight_decay_method=weight_decay_method,
            fp32_matmul_prec=fp32_matmul_prec,
            coefficient_type=coefficient_type,
            num_ns_steps=num_ns_steps,
            scale_mode=scale_mode,
            extra_scale_factor=extra_scale_factor,
            use_syrk=use_syrk,
        )
        self.moment2_method = moment2_method

        for group in self.param_groups:
            group.setdefault("beta2", beta2)
            group.setdefault("eps", eps)

    def _initialize_moment2(
        self,
        state: dict[str, torch.Tensor],
        grad: torch.Tensor,
    ) -> None:
        """Initialize the second moment buffer if it doesn't exist.

        The shape of the buffer depends on the moment2_method:

        - "adamuon": Full elementwise buffer with the same shape as grad.
        - "normuon": Reduced-shape buffer (averaged along -1 if shape[-2] >= shape[-1], else -2).

        Args:
            state: The optimizer state dict for a parameter.
            grad: The gradient tensor (used for shape/dtype).
        """
        if "moment2_buffer" not in state:
            if self.moment2_method == "adamuon":
                # Full elementwise second moment
                moment2 = torch.zeros_like(grad)
            elif self.moment2_method == "normuon":
                # Row/column-wise second moment - reduced along one dimension
                # Determine which dimension to reduce based on parameter shape
                avg_dim = -1 if grad.shape[-2] >= grad.shape[-1] else -2
                # Specify the shape with the reduced dimension
                moment2_shape = list(grad.shape)
                moment2_shape[avg_dim] = 1
                moment2 = torch.zeros(moment2_shape, dtype=grad.dtype, device=grad.device)
            else:
                raise TypeError(f"Invalid second moment method: {self.moment2_method}")

            state["moment2_buffer"] = moment2

    def _apply_moment2_normalization(
        self,
        orth_grad: torch.Tensor,
        moment2: torch.Tensor,
        beta2: float,
        eps: float,
    ) -> torch.Tensor:
        """Apply AdamW-style second moment accumulation and normalization.

        This method supports two variants:

        - "adamuon": Full elementwise second moment (like AdamW, https://arxiv.org/abs/2507.11005)
        - "normuon": Row- or column-wise second moment (https://arxiv.org/abs/2510.05491)

        For both methods it:

        1. Updates the second moment as an EMA of squared gradients.
        2. Returns the adaptively scaled gradient.

        Args:
            orth_grad: The orthogonalized gradient tensor.
            moment2: The second moment buffer from state.
            beta2: The exponential decay rate for the second moment.
            eps: Small constant for numerical stability.

        Returns:
            The adaptively scaled weight update tensor.
        """
        if self.moment2_method == "adamuon":
            # AdaMuon: Full elementwise second moment like AdamW
            # Update second moment with EMA of squared orthogonalized gradient
            moment2.lerp_(orth_grad.square(), 1 - beta2)

            # AdamW-style division: grad / (sqrt(moment2) + eps)
            denom = moment2.sqrt() + eps
            return orth_grad / denom

        elif self.moment2_method == "normuon":
            # NorMuon: Row- or column-wise second moment
            # Compute mean of squared gradients along one dimension based on shape
            # Average along the longer dimension to preserve structure along the shorter one
            avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2
            v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True)

            # Update second moment with EMA
            moment2.lerp_(v_mean, 1 - beta2)

            # NorMuon uses the reciprocal square root with clamping
            step_size = moment2.clamp_min(eps).rsqrt_()
            return orth_grad * step_size

        else:
            raise TypeError(f"Invalid second moment method: {self.moment2_method}")

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Single optimization step.

        Args:
            closure: A closure that reevaluates the model and returns the loss.
        """
        if closure is None:
            loss = None
        else:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.dim() != 2:
                    raise ValueError("AdaptiveMuon only supports 2D parameters")
                grad = p.grad
                if grad is None:
                    continue
                state = self.state[p]

                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(grad)
                self._initialize_moment2(state, grad)

                exp_avg = state["momentum_buffer"]

                self._apply_weight_decay_inplace(
                    p,
                    grad,
                    group["lr"],
                    group["weight_decay"],
                )

                # update momentum buffer with EMA of gradient
                exp_avg.lerp_(grad, 1 - group["momentum_beta"])

                if self.use_nesterov:
                    grad = grad.lerp(exp_avg, group["momentum_beta"])
                else:
                    grad = exp_avg

                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                    orth_grad = self.scaled_orthogonalize_fn(grad)

                update = self._apply_moment2_normalization(
                    orth_grad=orth_grad,
                    moment2=state["moment2_buffer"],
                    beta2=group["beta2"],
                    eps=group["eps"],
                )

                # perform weight update
                # scale is applied to have update RMS == 1
                p.add_(update, alpha=-group["lr"])

        return loss
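
For reference, a minimal usage sketch of the class above, mirroring the constructor arguments exercised by the smoke test further down (the hyperparameter values are illustrative only):

    import torch
    import torch.nn as nn

    from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import AdaptiveMuon

    # AdaptiveMuon only supports 2D parameters, so use a weight matrix.
    param = nn.Parameter(torch.randn(32, 64))
    param.grad = torch.randn_like(param)

    opt = AdaptiveMuon(
        [param],
        lr=0.01,
        momentum_beta=0.9,
        weight_decay=0.01,
        use_nesterov=True,
        weight_decay_method="decoupled",
        fp32_matmul_prec="highest",
        moment2_method="normuon",  # or "adamuon"
        beta2=0.95,
        eps=1e-8,
    )
    opt.step()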

tests/ci/L0_Tests_GPU.sh

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0

 error=0
 coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1
+coverage run -p --source=emerging_optimizers tests/test_adaptive_muon.py -v -2 || error=1
 coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1
 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1
 coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1

tests/ci/L1_Tests_GPU.sh

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0

 error=0
 python tests/test_muon_utils.py || error=1
+python tests/test_adaptive_muon.py || error=1
 python tests/test_orthogonalized_optimizer.py || error=1
 python tests/test_soap_utils.py || error=1
 python tests/test_soap.py || error=1

tests/test_adaptive_muon.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from absl import flags
from absl.testing import absltest, parameterized

from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import (
    AdaptiveMuon,
)


flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")

FLAGS = flags.FLAGS


class AdaptiveMuonTest(parameterized.TestCase):
    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
        second_moment_method=["adamuon", "normuon"],
        use_nesterov=[True, False],
    )
    def test_smoke(self, shape, second_moment_method, use_nesterov) -> None:
        """Smoke test AdaptiveMuon with both second moment methods."""
        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=FLAGS.device))
        test_param.grad = torch.randint_like(test_param, -5, 5)

        adaptive_opt = AdaptiveMuon(
            [test_param],
            lr=0.01,
            momentum_beta=0.9,
            weight_decay=0.01,
            use_nesterov=use_nesterov,
            moment2_method=second_moment_method,
            beta2=0.999,
            eps=1e-8,
            weight_decay_method="decoupled",
            fp32_matmul_prec="highest",
        )
        adaptive_opt.step()

    @parameterized.parameters(
        {"shape": (8, 16), "second_moment_method": "adamuon"},
        {"shape": (16, 8), "second_moment_method": "normuon"},
    )
    def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None:
        """Test that second moment buffers are properly initialized."""
        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=FLAGS.device))
        test_param.grad = torch.randint_like(test_param, -5, 5)

        adaptive_opt = AdaptiveMuon(
            [test_param],
            lr=0.01,
            momentum_beta=0.9,
            weight_decay=0.0,
            use_nesterov=False,
            moment2_method=second_moment_method,
            beta2=0.999,
            eps=1e-8,
            weight_decay_method="decoupled",
            fp32_matmul_prec="highest",
        )

        # Run one step to initialize buffers
        adaptive_opt.step()

        # Check that second moment buffer was created
        state = adaptive_opt.state[test_param]
        self.assertIn("moment2_buffer", state)
        self.assertIn("momentum_buffer", state)

        # Check second moment buffer shape
        second_moment = state["moment2_buffer"]
        if second_moment_method == "adamuon":
            # Full elementwise buffer
            self.assertEqual(second_moment.shape, test_param.shape)
        elif second_moment_method == "normuon":
            # Reduced shape buffer
            avg_dim = -1 if shape[-2] >= shape[-1] else -2
            expected_shape = list(shape)
            expected_shape[avg_dim] = 1
            self.assertEqual(list(second_moment.shape), expected_shape)

    def test_unknown_moment2_method_raise_type_error(self) -> None:
        """Test that AdaptiveMuon raises a TypeError for an unknown moment2_method."""
        test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device=FLAGS.device))
        test_param.grad = torch.randint_like(test_param, -5, 5)

        adaptive_opt = AdaptiveMuon(
            [test_param],
            lr=0.01,
            momentum_beta=0.9,
            weight_decay=0.0,
            use_nesterov=False,
            moment2_method=None,
            beta2=0.999,
            eps=1e-8,
            weight_decay_method="decoupled",
            fp32_matmul_prec="highest",
        )

        # TypeError is raised during step() when initializing moment2_buffer
        with self.assertRaises(TypeError):
            adaptive_opt.step()


if __name__ == "__main__":
    absltest.main()
