Merged
Commits
32 commits
35862cc
support adaptive learning rate for Muon: normuon and adamuon
mkhona-nvidia Nov 13, 2025
c4f51d5
removed adaptive orthogonalized optimizer as separate class, supporte…
mkhona-nvidia Nov 13, 2025
a996f3f
removed extra literal
mkhona-nvidia Nov 13, 2025
0913b63
subclassed orthogonalized optimizer and override step instead of mixin
mkhona-nvidia Nov 13, 2025
1b87e06
removed mixin test
mkhona-nvidia Nov 13, 2025
fb802b8
cleaned up adaptive orthogonalized optimizer
mkhona-nvidia Nov 13, 2025
890e7f6
changed second moment to moment2, addressed other MR comments
mkhona-nvidia Nov 14, 2025
a47a325
removed args doc
mkhona-nvidia Nov 14, 2025
b9736d8
made a separate test file for adaptive orthogonalized optimizer
mkhona-nvidia Nov 15, 2025
a626188
added device flag
mkhona-nvidia Nov 15, 2025
159505e
changed test name
mkhona-nvidia Nov 15, 2025
c02621f
changed value error to type error
mkhona-nvidia Nov 15, 2025
3a5a9ba
removed adaptive test
mkhona-nvidia Nov 15, 2025
377f510
changed subclass to Muon instead of OrthogonalizedOptimizer
mkhona-nvidia Nov 18, 2025
c1446c2
changed name to adaptivemuon, updated tests
mkhona-nvidia Nov 18, 2025
6c88ac7
updated test to import from adaptive_muon
mkhona-nvidia Nov 18, 2025
530bff1
added missing copyright
mkhona-nvidia Nov 18, 2025
8f46a91
changed scale mode to 1.0, added it as a scale mode
mkhona-nvidia Nov 18, 2025
f4da4f1
Revert "changed scale mode to 1.0, added it as a scale mode"
mkhona-nvidia Nov 18, 2025
939d9bf
addressed MR comments
mkhona-nvidia Nov 18, 2025
6f71c86
use consistent orth_grad naming
mkhona-nvidia Nov 18, 2025
ed365d9
addressed MR comment
mkhona-nvidia Nov 18, 2025
60c6903
addressed MR comments
mkhona-nvidia Nov 18, 2025
b3bb0f3
addressed MR comments
mkhona-nvidia Nov 18, 2025
c26805d
added adaptive_muon test to CI
mkhona-nvidia Nov 18, 2025
4f5d01b
changed name to adaptive_muon
mkhona-nvidia Nov 18, 2025
dd0f282
changed b2 default to 0.95
mkhona-nvidia Nov 18, 2025
e4d70c9
added ref
mkhona-nvidia Nov 18, 2025
6d96345
Revert "added ref"
mkhona-nvidia Nov 18, 2025
cc72e34
added ref
mkhona-nvidia Nov 18, 2025
cf0a275
fixed test
mkhona-nvidia Nov 18, 2025
a677a3a
raised TypeError in the correct place
mkhona-nvidia Nov 18, 2025
1 change: 1 addition & 0 deletions emerging_optimizers/orthogonalized_optimizers/__init__.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import *
from emerging_optimizers.orthogonalized_optimizers.muon import *
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
from emerging_optimizers.orthogonalized_optimizers.scion import *
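With this re-export in place, the new optimizer should be importable directly from the subpackage (a minimal sketch, assuming AdaptiveMuon is picked up by the star import added above):

# assumes adaptive_muon.py exposes AdaptiveMuon via the star import added in this file
from emerging_optimizers.orthogonalized_optimizers import AdaptiveMuon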
239 changes: 239 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py
@@ -0,0 +1,239 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Literal


# TODO(@boxiangw): remove this once we bump to Python 3.12
try:
from typing import override
except ImportError:
from typing_extensions import override

import torch
from torch.optim.optimizer import ParamsT

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import utils
from emerging_optimizers.orthogonalized_optimizers.muon import Muon


class AdaptiveMuon(Muon):
"""Adaptive Muon optimizer with adaptive second moment (AdaMuon/NorMuon variants).

This class extends Muon by adding AdamW-style or NorMuon-style second moment
accumulation after orthogonalization. This idea was first explored in D.E. Carlson,
E. Collins, Ya-Ping Hsieh, L. Carin, and V. Cevher. *Preconditioned Spectral
Descent for Deep Learning.* In Advances in Neural Information Processing Systems 28 (2015).
The step() method is overridden to include second moment normalization logic.

Args:
params: Iterable of parameters to optimize or dicts defining parameter groups.
lr: Learning rate.
momentum_beta: The exponential decay rate for momentum.
weight_decay: Weight decay coefficient.
use_nesterov: Whether to use Nesterov momentum.
weight_decay_method: The weight decay method to use.
fp32_matmul_prec: Precision for FP32 matrix multiplication.
coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration.
num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration.
scale_mode: The type of scale factor to use for the update.
extra_scale_factor: The additional scale factor to use for the update.
use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
moment2_method: Method for second moment accumulation ("adamuon" or "normuon").
beta2: The exponential decay rate for second moment.
eps: Small constant for numerical stability.
"""

def __init__(
self,
params: ParamsT,
lr: float,
momentum_beta: float,
weight_decay: float,
*,
use_nesterov: bool,
weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
fp32_matmul_prec: str,
coefficient_type: str = "quintic",
num_ns_steps: int = 5,
scale_mode: str = "spectral",
extra_scale_factor: float = 1.0,
use_syrk: bool = False,
moment2_method: Literal["adamuon", "normuon"] = "adamuon",
beta2: float = 0.95,
eps: float = 1e-8,
):
super().__init__(
params,
lr=lr,
momentum_beta=momentum_beta,
weight_decay=weight_decay,
use_nesterov=use_nesterov,
weight_decay_method=weight_decay_method,
fp32_matmul_prec=fp32_matmul_prec,
coefficient_type=coefficient_type,
num_ns_steps=num_ns_steps,
scale_mode=scale_mode,
extra_scale_factor=extra_scale_factor,
use_syrk=use_syrk,
)
self.moment2_method = moment2_method

for group in self.param_groups:
group.setdefault("beta2", beta2)
group.setdefault("eps", eps)

def _initialize_moment2(
self,
state: dict[str, torch.Tensor],
grad: torch.Tensor,
) -> None:
"""Initialize the second moment buffer if it doesn't exist.

The shape of the buffer depends on the moment2_method:
- "adamuon": Full elementwise buffer with same shape as grad
- "normuon": Reduced shape buffer (averaged along -1 if shape[-2] >= shape[-1], else -2)

Args:
state: The optimizer state dict for a parameter.
grad: The gradient tensor (used for shape/dtype).
"""
if "moment2_buffer" not in state:
if self.moment2_method == "adamuon":
# Full elementwise second moment
moment2 = torch.zeros_like(grad)
elif self.moment2_method == "normuon":
# Row/column-wise second moment - reduced along one dimension
# Determine which dimension to reduce based on parameter shape
avg_dim = -1 if grad.shape[-2] >= grad.shape[-1] else -2
# Specify the shape with reduced dimension
moment2_shape = list(grad.shape)
moment2_shape[avg_dim] = 1
moment2 = torch.zeros(moment2_shape, dtype=grad.dtype, device=grad.device)
else:
raise TypeError(f"Invalid second moment method: {self.moment2_method}")

state["moment2_buffer"] = moment2

def _apply_moment2_normalization(
self,
orth_grad: torch.Tensor,
moment2: torch.Tensor,
beta2: float,
eps: float,
) -> torch.Tensor:
"""Apply AdamW-style second moment accumulation and normalization.

This method supports two variants:
- "adamuon": Full elementwise second moment (like AdamW, https://arxiv.org/abs/2507.11005)
- "normuon": Row or column-wise second moment (https://arxiv.org/abs/2510.05491)

For both methods:
1. Updates the second moment as an EMA of squared gradients
2. Returns the adaptively scaled gradient

Args:
orth_grad: The orthogonalized gradient tensor.
moment2: The second moment buffer from state.
beta2: The exponential decay rate for second moment.
eps: Small constant for numerical stability.

Returns:
The adaptively scaled weight update tensor.
"""
if self.moment2_method == "adamuon":
# AdaMuon: Full elementwise second moment like AdamW
# Update second moment with EMA of squared orthogonalized gradient
moment2.lerp_(orth_grad.square(), 1 - beta2)

# AdamW-style division: grad / (sqrt(moment2) + eps)
denom = moment2.sqrt() + eps
return orth_grad / denom

elif self.moment2_method == "normuon":
# NorMuon: Row or column-wise second moment
# Compute mean of squared gradients along one dimension based on shape
# Average along the longer dimension to preserve structure along shorter dim
avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2
v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True)

# Update second moment with EMA
moment2.lerp_(v_mean, 1 - beta2)

# NorMuon uses reciprocal square root with clamping
step_size = moment2.clamp_min(eps).rsqrt_()
return orth_grad * step_size

else:
raise TypeError(f"Invalid second moment method: {self.moment2_method}")

@torch.no_grad() # type: ignore[misc]
@override
def step(self, closure: Callable[[], float] | None = None) -> float | None:
"""Single optimization step.

Args:
closure: A closure that reevaluates the model and returns the loss.
"""
if closure is None:
loss = None
else:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.dim() != 2:
raise ValueError("AdaptiveMuon only supports 2D parameters")
grad = p.grad
if grad is None:
continue
state = self.state[p]

if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
self._initialize_moment2(state, grad)

exp_avg = state["momentum_buffer"]

self._apply_weight_decay_inplace(
p,
grad,
group["lr"],
group["weight_decay"],
)

# update momentum buffer with EMA of gradient
exp_avg.lerp_(grad, 1 - group["momentum_beta"])

if self.use_nesterov:
grad = grad.lerp(exp_avg, group["momentum_beta"])
else:
grad = exp_avg

with utils.fp32_matmul_precision(self.fp32_matmul_prec):
orth_grad = self.scaled_orthogonalize_fn(grad)

update = self._apply_moment2_normalization(
orth_grad=orth_grad,
moment2=state["moment2_buffer"],
beta2=group["beta2"],
eps=group["eps"],
)

# perform weight update
# scale is applied to have update RMS == 1
p.add_(update, alpha=-group["lr"])

return loss
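For orientation, a minimal usage sketch of the optimizer defined above (not part of the diff; it assumes only the constructor signature and step() shown in this file and mirrors the smoke test added below):

import torch
import torch.nn as nn

from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import AdaptiveMuon

# AdaptiveMuon only supports 2D parameters; step() raises ValueError otherwise.
param = nn.Parameter(torch.randn(33, 65))
param.grad = torch.randn_like(param)

opt = AdaptiveMuon(
    [param],
    lr=0.01,
    momentum_beta=0.9,
    weight_decay=0.01,
    use_nesterov=True,
    weight_decay_method="decoupled",
    fp32_matmul_prec="highest",
    moment2_method="adamuon",  # or "normuon" for row/column-wise second moments
    beta2=0.95,
    eps=1e-8,
)
opt.step()

# After one step, the per-parameter state holds a "momentum_buffer" plus a
# "moment2_buffer" that rescales the orthogonalized update, AdamW-style.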
1 change: 1 addition & 0 deletions tests/ci/L0_Tests_GPU.sh
@@ -17,6 +17,7 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0

error=0
coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_adaptive_muon.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1
1 change: 1 addition & 0 deletions tests/ci/L1_Tests_GPU.sh
@@ -16,6 +16,7 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0

error=0
python tests/test_muon_utils.py || error=1
python tests/test_adaptive_muon.py || error=1
python tests/test_orthogonalized_optimizer.py || error=1
python tests/test_soap_utils.py || error=1
python tests/test_soap.py || error=1
121 changes: 121 additions & 0 deletions tests/test_adaptive_muon.py
@@ -0,0 +1,121 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from absl import flags
from absl.testing import absltest, parameterized

from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import (
AdaptiveMuon,
)


flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")

FLAGS = flags.FLAGS


class AdaptiveMuonTest(parameterized.TestCase):
@parameterized.product(
shape=[(5, 7), (33, 65), (127, 257)],
second_moment_method=["adamuon", "normuon"],
use_nesterov=[True, False],
)
def test_smoke(self, shape, second_moment_method, use_nesterov) -> None:
"""Smoke test AdaptiveMuon with both second moment methods."""
test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=FLAGS.device))
test_param.grad = torch.randint_like(test_param, -5, 5)

adaptive_opt = AdaptiveMuon(
[test_param],
lr=0.01,
momentum_beta=0.9,
weight_decay=0.01,
use_nesterov=use_nesterov,
moment2_method=second_moment_method,
beta2=0.999,
eps=1e-8,
weight_decay_method="decoupled",
fp32_matmul_prec="highest",
)
adaptive_opt.step()

@parameterized.parameters(
{"shape": (8, 16), "second_moment_method": "adamuon"},
{"shape": (16, 8), "second_moment_method": "normuon"},
)
def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None:
"""Test that second moment buffers are properly initialized."""
test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=FLAGS.device))
test_param.grad = torch.randint_like(test_param, -5, 5)

adaptive_opt = AdaptiveMuon(
[test_param],
lr=0.01,
momentum_beta=0.9,
weight_decay=0.0,
use_nesterov=False,
moment2_method=second_moment_method,
beta2=0.999,
eps=1e-8,
weight_decay_method="decoupled",
fp32_matmul_prec="highest",
)

# Run one step to initialize buffers
adaptive_opt.step()

# Check that second moment buffer was created
state = adaptive_opt.state[test_param]
self.assertIn("moment2_buffer", state)
self.assertIn("momentum_buffer", state)

# Check second moment buffer shape
second_moment = state["moment2_buffer"]
if second_moment_method == "adamuon":
# Full elementwise buffer
self.assertEqual(second_moment.shape, test_param.shape)
elif second_moment_method == "normuon":
# Reduced shape buffer
avg_dim = -1 if shape[-2] >= shape[-1] else -2
expected_shape = list(shape)
expected_shape[avg_dim] = 1
self.assertEqual(list(second_moment.shape), expected_shape)

def test_unknown_moment2_method_raise_type_error(self) -> None:
"""Test that AdaptiveMuon raises TypeError for unknown moment2_method."""
test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device=FLAGS.device))
test_param.grad = torch.randint_like(test_param, -5, 5)

adaptive_opt = AdaptiveMuon(
[test_param],
lr=0.01,
momentum_beta=0.9,
weight_decay=0.0,
use_nesterov=False,
moment2_method=None,
beta2=0.999,
eps=1e-8,
weight_decay_method="decoupled",
fp32_matmul_prec="highest",
)

# TypeError is raised during step() when initializing moment2_buffer
with self.assertRaises(TypeError):
adaptive_opt.step()


if __name__ == "__main__":
absltest.main()
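
As a follow-up to the shape test above, a small sketch (not part of the PR) of the buffer shapes it asserts, assuming the reduction rule from _initialize_moment2 in adaptive_muon.py:

import torch
import torch.nn as nn

from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import AdaptiveMuon

for shape, method in [((8, 16), "adamuon"), ((16, 8), "normuon")]:
    p = nn.Parameter(torch.randn(shape))
    p.grad = torch.randn_like(p)
    opt = AdaptiveMuon(
        [p],
        lr=0.01,
        momentum_beta=0.9,
        weight_decay=0.0,
        use_nesterov=False,
        weight_decay_method="decoupled",
        fp32_matmul_prec="highest",
        moment2_method=method,
    )
    opt.step()
    print(method, tuple(opt.state[p]["moment2_buffer"].shape))
    # "adamuon" keeps a full (8, 16) elementwise buffer; "normuon" reduces the
    # (16, 8) parameter along dim -1 (since 16 >= 8), giving a (16, 1) buffer.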