From 35862cc713c3d766f42ec9de865797a8603434fd Mon Sep 17 00:00:00 2001 From: mikail Date: Wed, 12 Nov 2025 18:18:25 -0800 Subject: [PATCH 01/32] support adaptive learning rate for Muon: normuon and adamuon Signed-off-by: mikail --- emerging_optimizers/mixin.py | 118 +++++++++ .../adaptive_orthogonalized_optimizer.py | 248 ++++++++++++++++++ tests/test_mixin.py | 123 +++++++++ 3 files changed, 489 insertions(+) create mode 100644 emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py create mode 100644 tests/test_mixin.py diff --git a/emerging_optimizers/mixin.py b/emerging_optimizers/mixin.py index 4ad2dc6..226af0e 100644 --- a/emerging_optimizers/mixin.py +++ b/emerging_optimizers/mixin.py @@ -19,6 +19,7 @@ WeightDecayT = Literal["decoupled", "independent", "l2"] +SecondMomentT = Literal["adamuon", "normuon"] class WeightDecayMixin: @@ -51,3 +52,120 @@ def _apply_weight_decay_inplace( grad.add_(p, alpha=weight_decay) else: raise ValueError(f"Invalid weight decay method: {weight_decay_method}") + + +class SecondMomentMixin: + """Mixin for second moment accumulation and adaptive learning rates. + + This mixin provides functionality similar to Adam's second moment (exp_avg_sq), + which can be applied after other transformations (e.g., orthogonalization). + It maintains an exponential moving average of squared gradients and applies + element-wise adaptive scaling. + """ + + def _initialize_second_moment( + self, + state: dict[str, torch.Tensor], + grad: torch.Tensor, + ) -> None: + """Initialize the second moment buffer if it doesn't exist. + + The shape of the buffer depends on the second_moment_method: + - "adamuon": Full elementwise buffer with same shape as grad + - "normuon": Reduced shape buffer (averaged along -1 if shape[-2] >= shape[-1], else -2) + + Args: + state: The optimizer state dict for a parameter. + grad: The gradient tensor (used for shape/dtype). + """ + second_moment_method = getattr(self, "second_moment_method", "adamuon") + if "second_moment_buffer" not in state: + if second_moment_method == "adamuon": + # Full elementwise second moment + second_moment = torch.zeros_like(grad) + elif second_moment_method == "normuon": + # Row/column-wise second moment - reduced along one dimension + # Determine which dimension to reduce based on parameter shape + avg_dim = -1 if grad.shape[-2] >= grad.shape[-1] else -2 + # Specify the shape with reduced dimension + second_moment_shape = list(grad.shape) + second_moment_shape[avg_dim] = 1 + second_moment = torch.zeros(second_moment_shape, dtype=grad.dtype, device=grad.device) + else: + raise ValueError(f"Invalid second moment method: {second_moment_method}") + + state["second_moment_buffer"] = second_moment + + def _apply_second_moment_normalization( + self, + orth_grad: torch.Tensor, + second_moment: torch.Tensor, + beta2: float, + eps: float, + correct_bias: bool = False, + step: int = 1, + ) -> torch.Tensor: + """Apply AdamW-style second moment accumulation and normalization. + + This method supports two variants: + - "adamuon": Full elementwise second moment (like AdamW, https://arxiv.org/abs/2507.11005) + - "normuon": Row or column-wise second moment (https://arxiv.org/abs/2510.05491) + + For both methods: + 1. Updates the second moment as an EMA of squared gradients + 2. Optionally applies bias correction + 3. Returns the adaptively scaled gradient + + Args: + orth_grad: The orthogonalized gradient tensor. + second_moment: The second moment buffer from state. 
+ beta2: The exponential decay rate for second moment. + eps: Small constant for numerical stability. + correct_bias: Whether to apply bias correction (default: False). + step: Current optimization step (1-based), used for bias correction. + + Returns: + The adaptively scaled weight update tensor. + """ + + second_moment_method = getattr(self, "second_moment_method", "adamuon") + + if second_moment_method == "adamuon": + # AdamMuon: Full elementwise second moment like AdamW + # Update second moment with EMA of squared gradient + second_moment.lerp_(orth_grad.square(), 1 - beta2) + + # Optional bias correction + if correct_bias: + bias_correction2 = 1.0 - beta2**step + corrected_second_moment = second_moment / bias_correction2 + else: + corrected_second_moment = second_moment + + # AdamW-style division: grad / (sqrt(second_moment) + eps) + denom = corrected_second_moment.sqrt() + eps + return orth_grad / denom + + elif second_moment_method == "normuon": + # NorMuon: Row or column-wise second moment + # Compute mean of squared gradients along one dimension based on shape + # Average along the longer dimension to preserve structure along shorter dim + avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2 + v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True) + + # Update second moment with EMA + second_moment.lerp_(v_mean, 1 - beta2) + + # Optional bias correction + if correct_bias: + bias_correction2 = 1.0 - beta2**step + corrected_second_moment = second_moment / bias_correction2 + else: + corrected_second_moment = second_moment + + # NorMuon uses reciprocal square root with clamping + step_size = corrected_second_moment.clamp_min(eps).rsqrt_() + return orth_grad * step_size + + else: + raise ValueError(f"Invalid second moment method: {second_moment_method}") diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py new file mode 100644 index 0000000..9b6dd3d --- /dev/null +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable + + +# TODO(@boxiangw): remove this once bump to python 3.12 +try: + from typing import override +except ImportError: + from typing_extensions import override + +import torch +import torch.optim as optim +from absl import logging +from torch.optim.optimizer import ParamsT + +from emerging_optimizers import mixin as opt_mixin +from emerging_optimizers import utils + + +_args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups + lr: The learning rate used by the internal SGD. + momentum_beta: The momentum used by the internal SGD. + beta2: The exponential decay rate for second moment (like AdamW's beta2). 
+ eps: Small constant for numerical stability in second moment division. + weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. + See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 + use_nesterov: Whether to use Nesterov-style momentum in the internal SGD. + correct_bias: Whether to apply bias correction to second moment. + second_moment_method: Method to apply second moment, see :class:`~emerging_optimizers.mixin.SecondMomentMixin` + for more details. Options: "adamuon" (elementwise like AdamW), "normuon" (row/column-wise). + weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` + for more details. + fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. +""" + + +class AdaptiveOrthogonalizedOptimizer( + opt_mixin.SecondMomentMixin, + opt_mixin.WeightDecayMixin, + optim.Optimizer, +): + """Adaptive orthogonalized optimizer with second moment accumulation. + + This optimizer extends the orthogonalized optimizer framework by adding AdamW-style + second moment accumulation and adaptive learning rates. The optimizer performs: + + 1. First moment (momentum) accumulation with optional Nesterov acceleration + 2. Orthogonalization/preconditioning of the momentum + 3. Second moment accumulation of the orthogonalized gradients + 4. Adaptive scaling using the second moment (like AdamW) + + The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the + following papers: + + - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.* + In International Conference on Artificial Intelligence and Statistics (2015a). + - Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. + *Stochastic Spectral Descent for Discrete Graphical Models.* + In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016). + - Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. + *Preconditioned spectral descent for deep learning.* + In Neural Information Processing Systems (2015b). + - Flynn, T. *The duality structure gradient descent algorithm: analysis and applications to neural networks.* + arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 `_] + + Note: + AdaptiveOrthogonalizedOptimizer as base class doesn't directly support orthogonalizing fused parameters separately. + Subclass can override the orthogonalize function to support this, see example below. + + .. code-block:: python + :caption: Split QKV example + + class SplitQkvAdaptiveOrthogonalizedOptimizer(AdaptiveOrthogonalizedOptimizer): + def __init__(..., split_qkv_shapes): + super().__init__(...) + self.qkv_split_shapes = split_qkv_shapes + + def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor: + + # Alternative is passing "is_qkv" to scaled_orthogonalize_fn and split inside the + # scaled_orthogonalize_fn. + if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False): + qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0) + qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads] + grad = torch.cat([orthogonalized for orthogonalized in qkv_orthogonalized]) + else: + grad = self.scaled_orthogonalize_fn(grad) + + return grad + + Args: + {_args_doc} + scaled_orthogonalize_fn: Function to orthogonalize and scale the updates. + **kwargs: Arguments passed through to the base optimizer. 
+ + Note: + Keyword arguments passed through are not checked here. Optimizer inherited from this class should check them. + """ + + def __init__( + self, + params: ParamsT, + lr: float, + momentum_beta: float, + weight_decay: float, + *, + beta2: float = 0.999, + eps: float = 1e-8, + use_nesterov: bool, + correct_bias: bool = False, + second_moment_method: opt_mixin.SecondMomentT = "adamuon", + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", + fp32_matmul_prec: str, + scaled_orthogonalize_fn: Callable | None = None, + **kwargs: Any, + ): + if scaled_orthogonalize_fn is None: + logging.warning("scaled_orthogonalize_fn not provided. Using noop") + scaled_orthogonalize_fn = torch.nn.Identity() + + self.fp32_matmul_prec = fp32_matmul_prec + self.use_nesterov = use_nesterov + self.correct_bias = correct_bias + self.second_moment_method = second_moment_method + self.weight_decay_method = weight_decay_method + + default_args_dict = dict( + lr=lr, + momentum_beta=momentum_beta, + beta2=beta2, + eps=eps, + weight_decay=weight_decay, + **kwargs, + ) + + super().__init__(params, default_args_dict) + self.scaled_orthogonalize_fn = scaled_orthogonalize_fn + + @torch.no_grad() # type: ignore[misc] + @override + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Performs a single optimization step. + + Args: + closure: A closure that reevaluates the model and returns the loss. + """ + if closure is None: + loss = None + else: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.dim() == 1: + raise ValueError(f"{self.__class__.__name__} does not support 1D parameters") + grad = p.grad + if grad is None: + continue + state = self.state[p] + + # Initialize step counter + if "step" not in state: + state["step"] = 0 + state["step"] += 1 + + # initialize momentum buffer + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(grad) + + # Initialize second moment buffer using mixin + self._initialize_second_moment(state, grad) + + # Subsequent update to exp_avg are all inplace, so it is not assigned back to state. + exp_avg = state["momentum_buffer"] + second_moment = state["second_moment_buffer"] + + self._apply_weight_decay_inplace( + p, + grad, + group["lr"], + group["weight_decay"], + ) + + # update momentum buffer with EMA of gradient + exp_avg.lerp_(grad, 1 - group["momentum_beta"]) + + # include nesterov momentum + if self.use_nesterov: + grad = grad.lerp(exp_avg, group["momentum_beta"]) + else: + grad = exp_avg + + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + group_kwargs = {k: v for k, v in group.items() if k != "params"} + grad = self.orthogonalize(p, grad, **group_kwargs) + + # Apply second moment accumulation and normalization using mixin + grad = self._apply_second_moment_normalization( + orth_grad=grad, + second_moment=second_moment, + beta2=group["beta2"], + eps=group["eps"], + correct_bias=self.correct_bias, + step=state["step"], + ) + + # perform weight update + # scale is applied to have update RMS == 1 + p.add_(grad, alpha=-group["lr"]) + + return loss + + def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Orthogonalize the momentum. + + The default orthogonalize function calls the scaled_orthogonalize_fn with the gradient. Subclass can + override this function to implement different orthogonalization logic as well as split fused parameters. 
+ For example, a scaled_orthogonalize_fn function can get attributes from p or from kwargs to determine if + the parameter is a fused parameter and should be split for preconditioning. + + Args: + p: The parameter tensor. It is necessary to pass param tensor in addition to momentum because a lot of + information is only available in the param tensor, attributes for example. Although not used in + this default orthogonalize function. + grad: The momentum tensor. + **kwargs: keyword arguments of the param_group that p was belonged to. + + Returns: + The orthogonalized gradient tensor. + """ + grad = self.scaled_orthogonalize_fn(grad) + return grad + + +AdaptiveOrthogonalizedOptimizer.__doc__ = AdaptiveOrthogonalizedOptimizer.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr] diff --git a/tests/test_mixin.py b/tests/test_mixin.py new file mode 100644 index 0000000..85f5327 --- /dev/null +++ b/tests/test_mixin.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from absl import flags +from absl.testing import absltest, parameterized + +from emerging_optimizers import mixin as opt_mixin + + +# Define command line flags +flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") +flags.DEFINE_integer("seed", 42, "Random seed for reproducible tests") + +FLAGS = flags.FLAGS + + +# Create a dummy class that uses SecondMomentMixin for testing +class TestOptimizer(opt_mixin.SecondMomentMixin): + """Test optimizer that inherits from SecondMomentMixin.""" + + def __init__(self, second_moment_method: str = "adamuon"): + self.second_moment_method = second_moment_method + + +class SecondMomentMixinTest(parameterized.TestCase): + def setUp(self): + """Set random seed and device before each test.""" + torch.manual_seed(FLAGS.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(FLAGS.seed) + self.device = FLAGS.device + + @parameterized.parameters( + {"shape": (8, 16), "beta2": 0.999, "eps": 1e-8}, + {"shape": (32, 64), "beta2": 0.99, "eps": 1e-6}, + {"shape": (4, 4), "beta2": 0.9, "eps": 1e-10}, + ) + def test_adamuon_method(self, shape, beta2, eps): + """Test AdamMuon (elementwise) second moment method.""" + optimizer = TestOptimizer(second_moment_method="adamuon") + + orth_grad = torch.randn(shape, device=self.device) + second_moment = torch.zeros_like(orth_grad) + + # Apply second moment division + result = optimizer._apply_second_moment_normalization( + orth_grad=orth_grad, + second_moment=second_moment, + beta2=beta2, + eps=eps, + correct_bias=False, + step=1, + ) + + # Check that second moment was updated + expected_second_moment = (1 - beta2) * orth_grad.square() + torch.testing.assert_close(second_moment, expected_second_moment, rtol=1e-5, atol=1e-7) + + # Check result shape + self.assertEqual(result.shape, orth_grad.shape) + + # Check that result is computed correctly (elementwise 
division) + expected_result = orth_grad / (expected_second_moment.sqrt() + eps) + torch.testing.assert_close(result, expected_result, rtol=1e-5, atol=1e-7) + + @parameterized.parameters( + {"shape": (16, 8)}, # rows > cols, should average along -1 + {"shape": (8, 16)}, # cols > rows, should average along -2 + {"shape": (32, 32)}, # square, should average along -1 + ) + def test_normuon_method(self, shape): + """Test NorMuon (row/column-wise) second moment method.""" + optimizer = TestOptimizer(second_moment_method="normuon") + + orth_grad = torch.randn(shape, device=self.device) + + # Determine which dimension should be averaged + avg_dim = -1 if shape[-2] >= shape[-1] else -2 + expected_v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True) + + # Initialize second moment to zeros with correct shape + second_moment = torch.zeros_like(expected_v_mean) + + beta2 = 0.999 + eps = 1e-8 + + # Apply second moment division + result = optimizer._apply_second_moment_normalization( + orth_grad=orth_grad, + second_moment=second_moment, + beta2=beta2, + eps=eps, + correct_bias=False, + step=1, + ) + + # Check that second moment was updated with correct shape + expected_second_moment = (1 - beta2) * expected_v_mean + torch.testing.assert_close(second_moment, expected_second_moment, rtol=1e-5, atol=1e-7) + + # Check result shape matches input + self.assertEqual(result.shape, orth_grad.shape) + + # Check that result uses reciprocal square root + step_size = expected_second_moment.clamp_min(eps).rsqrt_() + expected_result = orth_grad * step_size + torch.testing.assert_close(result, expected_result, rtol=1e-5, atol=1e-7) + + +if __name__ == "__main__": + absltest.main() From c4f51d5a292013ad0694bc574c325083ae8a02e9 Mon Sep 17 00:00:00 2001 From: mikail Date: Wed, 12 Nov 2025 18:43:09 -0800 Subject: [PATCH 02/32] removed adaptive orthogonalized optimizer as separate class, supported second moment computations within same code Signed-off-by: mikail --- emerging_optimizers/mixin.py | 26 +- .../adaptive_orthogonalized_optimizer.py | 248 ------------------ .../orthogonalized_optimizer.py | 41 ++- tests/test_mixin.py | 8 +- 4 files changed, 44 insertions(+), 279 deletions(-) delete mode 100644 emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py diff --git a/emerging_optimizers/mixin.py b/emerging_optimizers/mixin.py index 226af0e..568b7c0 100644 --- a/emerging_optimizers/mixin.py +++ b/emerging_optimizers/mixin.py @@ -20,6 +20,7 @@ WeightDecayT = Literal["decoupled", "independent", "l2"] SecondMomentT = Literal["adamuon", "normuon"] +SecondMomentOptionalT = Literal["adamuon", "normuon", None] class WeightDecayMixin: @@ -102,8 +103,6 @@ def _apply_second_moment_normalization( second_moment: torch.Tensor, beta2: float, eps: float, - correct_bias: bool = False, - step: int = 1, ) -> torch.Tensor: """Apply AdamW-style second moment accumulation and normalization. @@ -113,16 +112,13 @@ def _apply_second_moment_normalization( For both methods: 1. Updates the second moment as an EMA of squared gradients - 2. Optionally applies bias correction - 3. Returns the adaptively scaled gradient + 2. Returns the adaptively scaled gradient Args: orth_grad: The orthogonalized gradient tensor. second_moment: The second moment buffer from state. beta2: The exponential decay rate for second moment. eps: Small constant for numerical stability. - correct_bias: Whether to apply bias correction (default: False). - step: Current optimization step (1-based), used for bias correction. 
Returns: The adaptively scaled weight update tensor. @@ -135,15 +131,8 @@ def _apply_second_moment_normalization( # Update second moment with EMA of squared gradient second_moment.lerp_(orth_grad.square(), 1 - beta2) - # Optional bias correction - if correct_bias: - bias_correction2 = 1.0 - beta2**step - corrected_second_moment = second_moment / bias_correction2 - else: - corrected_second_moment = second_moment - # AdamW-style division: grad / (sqrt(second_moment) + eps) - denom = corrected_second_moment.sqrt() + eps + denom = second_moment.sqrt() + eps return orth_grad / denom elif second_moment_method == "normuon": @@ -156,15 +145,8 @@ def _apply_second_moment_normalization( # Update second moment with EMA second_moment.lerp_(v_mean, 1 - beta2) - # Optional bias correction - if correct_bias: - bias_correction2 = 1.0 - beta2**step - corrected_second_moment = second_moment / bias_correction2 - else: - corrected_second_moment = second_moment - # NorMuon uses reciprocal square root with clamping - step_size = corrected_second_moment.clamp_min(eps).rsqrt_() + step_size = second_moment.clamp_min(eps).rsqrt_() return orth_grad * step_size else: diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py deleted file mode 100644 index 9b6dd3d..0000000 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py +++ /dev/null @@ -1,248 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable - - -# TODO(@boxiangw): remove this once bump to python 3.12 -try: - from typing import override -except ImportError: - from typing_extensions import override - -import torch -import torch.optim as optim -from absl import logging -from torch.optim.optimizer import ParamsT - -from emerging_optimizers import mixin as opt_mixin -from emerging_optimizers import utils - - -_args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups - lr: The learning rate used by the internal SGD. - momentum_beta: The momentum used by the internal SGD. - beta2: The exponential decay rate for second moment (like AdamW's beta2). - eps: Small constant for numerical stability in second moment division. - weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. - See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 - use_nesterov: Whether to use Nesterov-style momentum in the internal SGD. - correct_bias: Whether to apply bias correction to second moment. - second_moment_method: Method to apply second moment, see :class:`~emerging_optimizers.mixin.SecondMomentMixin` - for more details. Options: "adamuon" (elementwise like AdamW), "normuon" (row/column-wise). 
- weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` - for more details. - fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. -""" - - -class AdaptiveOrthogonalizedOptimizer( - opt_mixin.SecondMomentMixin, - opt_mixin.WeightDecayMixin, - optim.Optimizer, -): - """Adaptive orthogonalized optimizer with second moment accumulation. - - This optimizer extends the orthogonalized optimizer framework by adding AdamW-style - second moment accumulation and adaptive learning rates. The optimizer performs: - - 1. First moment (momentum) accumulation with optional Nesterov acceleration - 2. Orthogonalization/preconditioning of the momentum - 3. Second moment accumulation of the orthogonalized gradients - 4. Adaptive scaling using the second moment (like AdamW) - - The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the - following papers: - - - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.* - In International Conference on Artificial Intelligence and Statistics (2015a). - - Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. - *Stochastic Spectral Descent for Discrete Graphical Models.* - In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016). - - Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. - *Preconditioned spectral descent for deep learning.* - In Neural Information Processing Systems (2015b). - - Flynn, T. *The duality structure gradient descent algorithm: analysis and applications to neural networks.* - arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 `_] - - Note: - AdaptiveOrthogonalizedOptimizer as base class doesn't directly support orthogonalizing fused parameters separately. - Subclass can override the orthogonalize function to support this, see example below. - - .. code-block:: python - :caption: Split QKV example - - class SplitQkvAdaptiveOrthogonalizedOptimizer(AdaptiveOrthogonalizedOptimizer): - def __init__(..., split_qkv_shapes): - super().__init__(...) - self.qkv_split_shapes = split_qkv_shapes - - def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor: - - # Alternative is passing "is_qkv" to scaled_orthogonalize_fn and split inside the - # scaled_orthogonalize_fn. - if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False): - qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0) - qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads] - grad = torch.cat([orthogonalized for orthogonalized in qkv_orthogonalized]) - else: - grad = self.scaled_orthogonalize_fn(grad) - - return grad - - Args: - {_args_doc} - scaled_orthogonalize_fn: Function to orthogonalize and scale the updates. - **kwargs: Arguments passed through to the base optimizer. - - Note: - Keyword arguments passed through are not checked here. Optimizer inherited from this class should check them. 
- """ - - def __init__( - self, - params: ParamsT, - lr: float, - momentum_beta: float, - weight_decay: float, - *, - beta2: float = 0.999, - eps: float = 1e-8, - use_nesterov: bool, - correct_bias: bool = False, - second_moment_method: opt_mixin.SecondMomentT = "adamuon", - weight_decay_method: opt_mixin.WeightDecayT = "decoupled", - fp32_matmul_prec: str, - scaled_orthogonalize_fn: Callable | None = None, - **kwargs: Any, - ): - if scaled_orthogonalize_fn is None: - logging.warning("scaled_orthogonalize_fn not provided. Using noop") - scaled_orthogonalize_fn = torch.nn.Identity() - - self.fp32_matmul_prec = fp32_matmul_prec - self.use_nesterov = use_nesterov - self.correct_bias = correct_bias - self.second_moment_method = second_moment_method - self.weight_decay_method = weight_decay_method - - default_args_dict = dict( - lr=lr, - momentum_beta=momentum_beta, - beta2=beta2, - eps=eps, - weight_decay=weight_decay, - **kwargs, - ) - - super().__init__(params, default_args_dict) - self.scaled_orthogonalize_fn = scaled_orthogonalize_fn - - @torch.no_grad() # type: ignore[misc] - @override - def step(self, closure: Callable[[], float] | None = None) -> float | None: - """Performs a single optimization step. - - Args: - closure: A closure that reevaluates the model and returns the loss. - """ - if closure is None: - loss = None - else: - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.dim() == 1: - raise ValueError(f"{self.__class__.__name__} does not support 1D parameters") - grad = p.grad - if grad is None: - continue - state = self.state[p] - - # Initialize step counter - if "step" not in state: - state["step"] = 0 - state["step"] += 1 - - # initialize momentum buffer - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(grad) - - # Initialize second moment buffer using mixin - self._initialize_second_moment(state, grad) - - # Subsequent update to exp_avg are all inplace, so it is not assigned back to state. - exp_avg = state["momentum_buffer"] - second_moment = state["second_moment_buffer"] - - self._apply_weight_decay_inplace( - p, - grad, - group["lr"], - group["weight_decay"], - ) - - # update momentum buffer with EMA of gradient - exp_avg.lerp_(grad, 1 - group["momentum_beta"]) - - # include nesterov momentum - if self.use_nesterov: - grad = grad.lerp(exp_avg, group["momentum_beta"]) - else: - grad = exp_avg - - with utils.fp32_matmul_precision(self.fp32_matmul_prec): - group_kwargs = {k: v for k, v in group.items() if k != "params"} - grad = self.orthogonalize(p, grad, **group_kwargs) - - # Apply second moment accumulation and normalization using mixin - grad = self._apply_second_moment_normalization( - orth_grad=grad, - second_moment=second_moment, - beta2=group["beta2"], - eps=group["eps"], - correct_bias=self.correct_bias, - step=state["step"], - ) - - # perform weight update - # scale is applied to have update RMS == 1 - p.add_(grad, alpha=-group["lr"]) - - return loss - - def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor: - """Orthogonalize the momentum. - - The default orthogonalize function calls the scaled_orthogonalize_fn with the gradient. Subclass can - override this function to implement different orthogonalization logic as well as split fused parameters. - For example, a scaled_orthogonalize_fn function can get attributes from p or from kwargs to determine if - the parameter is a fused parameter and should be split for preconditioning. 
- - Args: - p: The parameter tensor. It is necessary to pass param tensor in addition to momentum because a lot of - information is only available in the param tensor, attributes for example. Although not used in - this default orthogonalize function. - grad: The momentum tensor. - **kwargs: keyword arguments of the param_group that p was belonged to. - - Returns: - The orthogonalized gradient tensor. - """ - grad = self.scaled_orthogonalize_fn(grad) - return grad - - -AdaptiveOrthogonalizedOptimizer.__doc__ = AdaptiveOrthogonalizedOptimizer.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr] diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index 78dcae1..4711e7b 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -36,16 +36,28 @@ weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 use_nesterov: Whether to use Nesterov-style momentum in the internal SGD. + second_moment_method: Method to apply second moment, see :class:`~emerging_optimizers.mixin.SecondMomentMixin` + for more details. Options: None (disabled), "adamuon" (elementwise like AdamW), "normuon" (row/column-wise). + beta2: The exponential decay rate for second moment (like AdamW's beta2). Only used if second_moment_method is not None. + eps: Small constant for numerical stability in second moment normalization. Only used if second_moment_method is not None. weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` for more details. fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. """ -class OrthogonalizedOptimizer(opt_mixin.WeightDecayMixin, optim.Optimizer): - """Base class for orthogonalized optimizers. +class OrthogonalizedOptimizer( + opt_mixin.SecondMomentMixin, + opt_mixin.WeightDecayMixin, + optim.Optimizer, +): + """Base class for orthogonalized optimizers with optional adaptive second moment. This class is a wrapper around a base optimizer that performs orthogonalization on the updates. + Optionally, it can apply AdamW-style or NorMuon-style second moment accumulation after orthogonalization + by setting `second_moment_method` to "adamuon" or "normuon". When `second_moment_method=None` (default), + the optimizer behaves as standard Muon without second moment. 
+ The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the following papers: @@ -102,7 +114,10 @@ def __init__( weight_decay: float, *, use_nesterov: bool, - weight_decay_method: opt_mixin.WeightDecayT, + second_moment_method: opt_mixin.SecondMomentOptionalT = None, + beta2: float = 0.999, + eps: float = 1e-8, + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", fp32_matmul_prec: str, scaled_orthogonalize_fn: Callable | None = None, **kwargs: Any, @@ -113,6 +128,8 @@ def __init__( self.fp32_matmul_prec = fp32_matmul_prec self.use_nesterov = use_nesterov + self.use_second_moment = second_moment_method is not None + self.second_moment_method = second_moment_method self.weight_decay_method = weight_decay_method default_args_dict = dict( @@ -122,6 +139,11 @@ def __init__( **kwargs, ) + # Only add second moment params if second_moment_method is not None + if self.use_second_moment: + default_args_dict["beta2"] = beta2 + default_args_dict["eps"] = eps + super().__init__(params, default_args_dict) self.scaled_orthogonalize_fn = scaled_orthogonalize_fn @@ -150,6 +172,9 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # initialize momentum buffer if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(grad) + if self.use_second_moment: + # Initialize second moment buffer using mixin + self._initialize_second_moment(state, grad) # Subsequent update to exp_avg are all inplace, so it is not assigned back to state. exp_avg = state["momentum_buffer"] @@ -174,6 +199,16 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: group_kwargs = {k: v for k, v in group.items() if k != "params"} grad = self.orthogonalize(p, grad, **group_kwargs) + # Apply second moment normalization if enabled + if self.use_second_moment: + # Apply second moment accumulation and normalization + grad = self._apply_second_moment_normalization( + orth_grad=grad, + second_moment=state["second_moment_buffer"], + beta2=group["beta2"], + eps=group["eps"], + ) + # perform weight update # scale is applied to have update RMS == 1 p.add_(grad, alpha=-group["lr"]) diff --git a/tests/test_mixin.py b/tests/test_mixin.py index 85f5327..4736895 100644 --- a/tests/test_mixin.py +++ b/tests/test_mixin.py @@ -54,14 +54,12 @@ def test_adamuon_method(self, shape, beta2, eps): orth_grad = torch.randn(shape, device=self.device) second_moment = torch.zeros_like(orth_grad) - # Apply second moment division + # Apply second moment normalization result = optimizer._apply_second_moment_normalization( orth_grad=orth_grad, second_moment=second_moment, beta2=beta2, eps=eps, - correct_bias=False, - step=1, ) # Check that second moment was updated @@ -96,14 +94,12 @@ def test_normuon_method(self, shape): beta2 = 0.999 eps = 1e-8 - # Apply second moment division + # Apply second moment normalization result = optimizer._apply_second_moment_normalization( orth_grad=orth_grad, second_moment=second_moment, beta2=beta2, eps=eps, - correct_bias=False, - step=1, ) # Check that second moment was updated with correct shape From a996f3f615a5afd47ba5ce443a9f0c52e08157bb Mon Sep 17 00:00:00 2001 From: mikail Date: Wed, 12 Nov 2025 18:45:41 -0800 Subject: [PATCH 03/32] removed extra literal Signed-off-by: mikail --- emerging_optimizers/mixin.py | 3 +-- .../orthogonalized_optimizers/orthogonalized_optimizer.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/emerging_optimizers/mixin.py b/emerging_optimizers/mixin.py index 
568b7c0..0db16b5 100644 --- a/emerging_optimizers/mixin.py +++ b/emerging_optimizers/mixin.py @@ -19,8 +19,7 @@ WeightDecayT = Literal["decoupled", "independent", "l2"] -SecondMomentT = Literal["adamuon", "normuon"] -SecondMomentOptionalT = Literal["adamuon", "normuon", None] +SecondMomentT = Literal["adamuon", "normuon", None] class WeightDecayMixin: diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index 4711e7b..0c39ad7 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -114,7 +114,7 @@ def __init__( weight_decay: float, *, use_nesterov: bool, - second_moment_method: opt_mixin.SecondMomentOptionalT = None, + second_moment_method: opt_mixin.SecondMomentT = None, beta2: float = 0.999, eps: float = 1e-8, weight_decay_method: opt_mixin.WeightDecayT = "decoupled", From 0913b63be458663732ebd4b202b3ee45f2f1c87a Mon Sep 17 00:00:00 2001 From: mikail Date: Wed, 12 Nov 2025 22:11:08 -0800 Subject: [PATCH 04/32] subclassed orthogonalized optimizer and override step instead of mixin Signed-off-by: mikail --- emerging_optimizers/mixin.py | 99 ------- .../orthogonalized_optimizers/__init__.py | 1 + .../adaptive_orthogonalized_optimizer.py | 247 ++++++++++++++++++ .../orthogonalized_optimizer.py | 41 +-- tests/test_mixin.py | 36 ++- tests/test_orthogonalized_optimizer.py | 89 +++++++ 6 files changed, 366 insertions(+), 147 deletions(-) create mode 100644 emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py diff --git a/emerging_optimizers/mixin.py b/emerging_optimizers/mixin.py index 0db16b5..4ad2dc6 100644 --- a/emerging_optimizers/mixin.py +++ b/emerging_optimizers/mixin.py @@ -19,7 +19,6 @@ WeightDecayT = Literal["decoupled", "independent", "l2"] -SecondMomentT = Literal["adamuon", "normuon", None] class WeightDecayMixin: @@ -52,101 +51,3 @@ def _apply_weight_decay_inplace( grad.add_(p, alpha=weight_decay) else: raise ValueError(f"Invalid weight decay method: {weight_decay_method}") - - -class SecondMomentMixin: - """Mixin for second moment accumulation and adaptive learning rates. - - This mixin provides functionality similar to Adam's second moment (exp_avg_sq), - which can be applied after other transformations (e.g., orthogonalization). - It maintains an exponential moving average of squared gradients and applies - element-wise adaptive scaling. - """ - - def _initialize_second_moment( - self, - state: dict[str, torch.Tensor], - grad: torch.Tensor, - ) -> None: - """Initialize the second moment buffer if it doesn't exist. - - The shape of the buffer depends on the second_moment_method: - - "adamuon": Full elementwise buffer with same shape as grad - - "normuon": Reduced shape buffer (averaged along -1 if shape[-2] >= shape[-1], else -2) - - Args: - state: The optimizer state dict for a parameter. - grad: The gradient tensor (used for shape/dtype). 
- """ - second_moment_method = getattr(self, "second_moment_method", "adamuon") - if "second_moment_buffer" not in state: - if second_moment_method == "adamuon": - # Full elementwise second moment - second_moment = torch.zeros_like(grad) - elif second_moment_method == "normuon": - # Row/column-wise second moment - reduced along one dimension - # Determine which dimension to reduce based on parameter shape - avg_dim = -1 if grad.shape[-2] >= grad.shape[-1] else -2 - # Specify the shape with reduced dimension - second_moment_shape = list(grad.shape) - second_moment_shape[avg_dim] = 1 - second_moment = torch.zeros(second_moment_shape, dtype=grad.dtype, device=grad.device) - else: - raise ValueError(f"Invalid second moment method: {second_moment_method}") - - state["second_moment_buffer"] = second_moment - - def _apply_second_moment_normalization( - self, - orth_grad: torch.Tensor, - second_moment: torch.Tensor, - beta2: float, - eps: float, - ) -> torch.Tensor: - """Apply AdamW-style second moment accumulation and normalization. - - This method supports two variants: - - "adamuon": Full elementwise second moment (like AdamW, https://arxiv.org/abs/2507.11005) - - "normuon": Row or column-wise second moment (https://arxiv.org/abs/2510.05491) - - For both methods: - 1. Updates the second moment as an EMA of squared gradients - 2. Returns the adaptively scaled gradient - - Args: - orth_grad: The orthogonalized gradient tensor. - second_moment: The second moment buffer from state. - beta2: The exponential decay rate for second moment. - eps: Small constant for numerical stability. - - Returns: - The adaptively scaled weight update tensor. - """ - - second_moment_method = getattr(self, "second_moment_method", "adamuon") - - if second_moment_method == "adamuon": - # AdamMuon: Full elementwise second moment like AdamW - # Update second moment with EMA of squared gradient - second_moment.lerp_(orth_grad.square(), 1 - beta2) - - # AdamW-style division: grad / (sqrt(second_moment) + eps) - denom = second_moment.sqrt() + eps - return orth_grad / denom - - elif second_moment_method == "normuon": - # NorMuon: Row or column-wise second moment - # Compute mean of squared gradients along one dimension based on shape - # Average along the longer dimension to preserve structure along shorter dim - avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2 - v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True) - - # Update second moment with EMA - second_moment.lerp_(v_mean, 1 - beta2) - - # NorMuon uses reciprocal square root with clamping - step_size = second_moment.clamp_min(eps).rsqrt_() - return orth_grad * step_size - - else: - raise ValueError(f"Invalid second moment method: {second_moment_method}") diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py index c809ebb..274defa 100644 --- a/emerging_optimizers/orthogonalized_optimizers/__init__.py +++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from emerging_optimizers.orthogonalized_optimizers.adaptive_orthogonalized_optimizer import * from emerging_optimizers.orthogonalized_optimizers.muon import * from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import * from emerging_optimizers.orthogonalized_optimizers.scion import * diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py new file mode 100644 index 0000000..b5ef901 --- /dev/null +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Literal + + +# TODO(@boxiangw): remove this once bump to python 3.12 +try: + from typing import override +except ImportError: + from typing_extensions import override + +import torch +from torch.optim.optimizer import ParamsT + +from emerging_optimizers import mixin as opt_mixin +from emerging_optimizers import utils +from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer + + +_adaptive_args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups + lr: The learning rate used by the internal SGD. + momentum_beta: The momentum used by the internal SGD. + weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. + See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 + use_nesterov: Whether to use Nesterov-style momentum in the internal SGD. + second_moment_method: Method to apply second moment. Options: "adamuon" (elementwise like AdamW), "normuon" (row/column-wise). + beta2: The exponential decay rate for second moment (like AdamW's beta2). + eps: Small constant for numerical stability in second moment normalization. + weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` + for more details. + fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. +""" + + +class AdaptiveOrthogonalizedOptimizer(OrthogonalizedOptimizer): + """Orthogonalized optimizer with adaptive second moment (AdaMuon/NorMuon variants). + + This class extends OrthogonalizedOptimizer by adding AdamW-style or NorMuon-style second moment + accumulation after orthogonalization. The step() method is overridden to include second moment + normalization logic. + + Args: + {_adaptive_args_doc} + scaled_orthogonalize_fn: Function to orthogonalize and scale the updates. + **kwargs: Arguments passed through to the base optimizer. + + Note: + Keyword arguments passed through are not checked here. Optimizer inherited from this class should check them. 
+ """ + + def __init__( + self, + params: ParamsT, + lr: float, + momentum_beta: float, + weight_decay: float, + *, + use_nesterov: bool, + second_moment_method: Literal["adamuon", "normuon"], + beta2: float = 0.999, + eps: float = 1e-8, + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", + fp32_matmul_prec: str, + scaled_orthogonalize_fn: Callable | None = None, + **kwargs: Any, + ): + self.second_moment_method = second_moment_method + + super().__init__( + params=params, + lr=lr, + momentum_beta=momentum_beta, + weight_decay=weight_decay, + use_nesterov=use_nesterov, + weight_decay_method=weight_decay_method, + fp32_matmul_prec=fp32_matmul_prec, + scaled_orthogonalize_fn=scaled_orthogonalize_fn, + **kwargs, + ) + + for group in self.param_groups: + group["beta2"] = beta2 + group["eps"] = eps + + def _initialize_second_moment( + self, + state: dict[str, torch.Tensor], + grad: torch.Tensor, + ) -> None: + """Initialize the second moment buffer if it doesn't exist. + + The shape of the buffer depends on the second_moment_method: + - "adamuon": Full elementwise buffer with same shape as grad + - "normuon": Reduced shape buffer (averaged along -1 if shape[-2] >= shape[-1], else -2) + + Args: + state: The optimizer state dict for a parameter. + grad: The gradient tensor (used for shape/dtype). + """ + if "second_moment_buffer" not in state: + if self.second_moment_method == "adamuon": + # Full elementwise second moment + second_moment = torch.zeros_like(grad) + elif self.second_moment_method == "normuon": + # Row/column-wise second moment - reduced along one dimension + # Determine which dimension to reduce based on parameter shape + avg_dim = -1 if grad.shape[-2] >= grad.shape[-1] else -2 + # Specify the shape with reduced dimension + second_moment_shape = list(grad.shape) + second_moment_shape[avg_dim] = 1 + second_moment = torch.zeros(second_moment_shape, dtype=grad.dtype, device=grad.device) + else: + raise ValueError(f"Invalid second moment method: {self.second_moment_method}") + + state["second_moment_buffer"] = second_moment + + def _apply_second_moment_normalization( + self, + orth_grad: torch.Tensor, + second_moment: torch.Tensor, + beta2: float, + eps: float, + ) -> torch.Tensor: + """Apply AdamW-style second moment accumulation and normalization. + + This method supports two variants: + - "adamuon": Full elementwise second moment (like AdamW, https://arxiv.org/abs/2507.11005) + - "normuon": Row or column-wise second moment (https://arxiv.org/abs/2510.05491) + + For both methods: + 1. Updates the second moment as an EMA of squared gradients + 2. Returns the adaptively scaled gradient + + Args: + orth_grad: The orthogonalized gradient tensor. + second_moment: The second moment buffer from state. + beta2: The exponential decay rate for second moment. + eps: Small constant for numerical stability. + + Returns: + The adaptively scaled weight update tensor. 
+ """ + if self.second_moment_method == "adamuon": + # AdamMuon: Full elementwise second moment like AdamW + # Update second moment with EMA of squared gradient + second_moment.lerp_(orth_grad.square(), 1 - beta2) + + # AdamW-style division: grad / (sqrt(second_moment) + eps) + denom = second_moment.sqrt() + eps + return orth_grad / denom + + elif self.second_moment_method == "normuon": + # NorMuon: Row or column-wise second moment + # Compute mean of squared gradients along one dimension based on shape + # Average along the longer dimension to preserve structure along shorter dim + avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2 + v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True) + + # Update second moment with EMA + second_moment.lerp_(v_mean, 1 - beta2) + + # NorMuon uses reciprocal square root with clamping + step_size = second_moment.clamp_min(eps).rsqrt_() + return orth_grad * step_size + + else: + raise ValueError(f"Invalid second moment method: {self.second_moment_method}") + + @torch.no_grad() # type: ignore[misc] + @override + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Performs a single optimization step with second moment normalization. + + Args: + closure: A closure that reevaluates the model and returns the loss. + """ + if closure is None: + loss = None + else: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.dim() == 1: + raise ValueError(f"{self.__class__.__name__} does not support 1D parameters") + grad = p.grad + if grad is None: + continue + state = self.state[p] + + # initialize momentum buffer and second moment buffer + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(grad) + self._initialize_second_moment(state, grad) + + exp_avg = state["momentum_buffer"] + + self._apply_weight_decay_inplace( + p, + grad, + group["lr"], + group["weight_decay"], + ) + + # update momentum buffer with EMA of gradient + exp_avg.lerp_(grad, 1 - group["momentum_beta"]) + + # include nesterov momentum + if self.use_nesterov: + grad = grad.lerp(exp_avg, group["momentum_beta"]) + else: + grad = exp_avg + + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + group_kwargs = {k: v for k, v in group.items() if k != "params"} + grad = self.orthogonalize(p, grad, **group_kwargs) + + # Apply second moment normalization + grad = self._apply_second_moment_normalization( + orth_grad=grad, + second_moment=state["second_moment_buffer"], + beta2=group["beta2"], + eps=group["eps"], + ) + + # perform weight update + # scale is applied to have update RMS == 1 + p.add_(grad, alpha=-group["lr"]) + + return loss + + +AdaptiveOrthogonalizedOptimizer.__doc__ = AdaptiveOrthogonalizedOptimizer.__doc__.format( # type: ignore[union-attr] + _adaptive_args_doc=_adaptive_args_doc +) diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index 0c39ad7..78dcae1 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -36,28 +36,16 @@ weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 use_nesterov: Whether to use Nesterov-style momentum in the internal SGD. 
- second_moment_method: Method to apply second moment, see :class:`~emerging_optimizers.mixin.SecondMomentMixin` - for more details. Options: None (disabled), "adamuon" (elementwise like AdamW), "normuon" (row/column-wise). - beta2: The exponential decay rate for second moment (like AdamW's beta2). Only used if second_moment_method is not None. - eps: Small constant for numerical stability in second moment normalization. Only used if second_moment_method is not None. weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` for more details. fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. """ -class OrthogonalizedOptimizer( - opt_mixin.SecondMomentMixin, - opt_mixin.WeightDecayMixin, - optim.Optimizer, -): - """Base class for orthogonalized optimizers with optional adaptive second moment. +class OrthogonalizedOptimizer(opt_mixin.WeightDecayMixin, optim.Optimizer): + """Base class for orthogonalized optimizers. This class is a wrapper around a base optimizer that performs orthogonalization on the updates. - Optionally, it can apply AdamW-style or NorMuon-style second moment accumulation after orthogonalization - by setting `second_moment_method` to "adamuon" or "normuon". When `second_moment_method=None` (default), - the optimizer behaves as standard Muon without second moment. - The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the following papers: @@ -114,10 +102,7 @@ def __init__( weight_decay: float, *, use_nesterov: bool, - second_moment_method: opt_mixin.SecondMomentT = None, - beta2: float = 0.999, - eps: float = 1e-8, - weight_decay_method: opt_mixin.WeightDecayT = "decoupled", + weight_decay_method: opt_mixin.WeightDecayT, fp32_matmul_prec: str, scaled_orthogonalize_fn: Callable | None = None, **kwargs: Any, @@ -128,8 +113,6 @@ def __init__( self.fp32_matmul_prec = fp32_matmul_prec self.use_nesterov = use_nesterov - self.use_second_moment = second_moment_method is not None - self.second_moment_method = second_moment_method self.weight_decay_method = weight_decay_method default_args_dict = dict( @@ -139,11 +122,6 @@ def __init__( **kwargs, ) - # Only add second moment params if second_moment_method is not None - if self.use_second_moment: - default_args_dict["beta2"] = beta2 - default_args_dict["eps"] = eps - super().__init__(params, default_args_dict) self.scaled_orthogonalize_fn = scaled_orthogonalize_fn @@ -172,9 +150,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # initialize momentum buffer if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(grad) - if self.use_second_moment: - # Initialize second moment buffer using mixin - self._initialize_second_moment(state, grad) # Subsequent update to exp_avg are all inplace, so it is not assigned back to state. 
exp_avg = state["momentum_buffer"] @@ -199,16 +174,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: group_kwargs = {k: v for k, v in group.items() if k != "params"} grad = self.orthogonalize(p, grad, **group_kwargs) - # Apply second moment normalization if enabled - if self.use_second_moment: - # Apply second moment accumulation and normalization - grad = self._apply_second_moment_normalization( - orth_grad=grad, - second_moment=state["second_moment_buffer"], - beta2=group["beta2"], - eps=group["eps"], - ) - # perform weight update # scale is applied to have update RMS == 1 p.add_(grad, alpha=-group["lr"]) diff --git a/tests/test_mixin.py b/tests/test_mixin.py index 4736895..c0bcf82 100644 --- a/tests/test_mixin.py +++ b/tests/test_mixin.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +import torch.nn as nn from absl import flags from absl.testing import absltest, parameterized -from emerging_optimizers import mixin as opt_mixin +from emerging_optimizers.orthogonalized_optimizers import AdaptiveOrthogonalizedOptimizer # Define command line flags @@ -26,15 +27,26 @@ FLAGS = flags.FLAGS -# Create a dummy class that uses SecondMomentMixin for testing -class TestOptimizer(opt_mixin.SecondMomentMixin): - """Test optimizer that inherits from SecondMomentMixin.""" - - def __init__(self, second_moment_method: str = "adamuon"): - self.second_moment_method = second_moment_method +# Create a dummy class for testing second moment methods +class TestAdaptiveOptimizer(AdaptiveOrthogonalizedOptimizer): + """Test optimizer that uses AdaptiveOrthogonalizedOptimizer.""" + + def __init__(self, params, second_moment_method: str = "adamuon"): + super().__init__( + params=params, + lr=0.001, + momentum_beta=0.9, + weight_decay=0.0, + use_nesterov=False, + second_moment_method=second_moment_method, + beta2=0.999, + eps=1e-8, + weight_decay_method="decoupled", + fp32_matmul_prec="highest", + ) -class SecondMomentMixinTest(parameterized.TestCase): +class SecondMomentTest(parameterized.TestCase): def setUp(self): """Set random seed and device before each test.""" torch.manual_seed(FLAGS.seed) @@ -49,7 +61,9 @@ def setUp(self): ) def test_adamuon_method(self, shape, beta2, eps): """Test AdamMuon (elementwise) second moment method.""" - optimizer = TestOptimizer(second_moment_method="adamuon") + # Create a dummy parameter for the optimizer + dummy_param = nn.Parameter(torch.randn(shape, device=self.device)) + optimizer = TestAdaptiveOptimizer([dummy_param], second_moment_method="adamuon") orth_grad = torch.randn(shape, device=self.device) second_moment = torch.zeros_like(orth_grad) @@ -80,7 +94,9 @@ def test_adamuon_method(self, shape, beta2, eps): ) def test_normuon_method(self, shape): """Test NorMuon (row/column-wise) second moment method.""" - optimizer = TestOptimizer(second_moment_method="normuon") + # Create a dummy parameter for the optimizer + dummy_param = nn.Parameter(torch.randn(shape, device=self.device)) + optimizer = TestAdaptiveOptimizer([dummy_param], second_moment_method="normuon") orth_grad = torch.randn(shape, device=self.device) diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index 39312c1..ee3c5bf 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -18,6 +18,9 @@ from absl.testing import absltest, parameterized from emerging_optimizers.orthogonalized_optimizers import muon, scion 
+from emerging_optimizers.orthogonalized_optimizers.adaptive_orthogonalized_optimizer import ( + AdaptiveOrthogonalizedOptimizer, +) from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer @@ -249,5 +252,91 @@ def test_smoke(self, shape) -> None: scion_opt.step() +class AdaptiveOrthogonalizedOptimizerTest(parameterized.TestCase): + @parameterized.product( + shape=[(5, 7), (33, 65), (127, 257)], + second_moment_method=["adamuon", "normuon"], + use_nesterov=[True, False], + ) + def test_smoke(self, shape, second_moment_method, use_nesterov) -> None: + """Smoke test AdaptiveOrthogonalizedOptimizer with both second moment methods.""" + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) + test_param.grad = torch.randint_like(test_param, -5, 5) + + adaptive_opt = AdaptiveOrthogonalizedOptimizer( + [test_param], + lr=0.01, + momentum_beta=0.9, + weight_decay=0.01, + use_nesterov=use_nesterov, + second_moment_method=second_moment_method, + beta2=0.999, + eps=1e-8, + weight_decay_method="decoupled", + fp32_matmul_prec="highest", + ) + adaptive_opt.step() + + @parameterized.parameters( + {"shape": (8, 16), "second_moment_method": "adamuon"}, + {"shape": (16, 8), "second_moment_method": "normuon"}, + ) + def test_second_moment_initialization(self, shape, second_moment_method) -> None: + """Test that second moment buffers are properly initialized.""" + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) + test_param.grad = torch.randint_like(test_param, -5, 5) + + adaptive_opt = AdaptiveOrthogonalizedOptimizer( + [test_param], + lr=0.01, + momentum_beta=0.9, + weight_decay=0.0, + use_nesterov=False, + second_moment_method=second_moment_method, + beta2=0.999, + eps=1e-8, + weight_decay_method="decoupled", + fp32_matmul_prec="highest", + ) + + # Run one step to initialize buffers + adaptive_opt.step() + + # Check that second moment buffer was created + state = adaptive_opt.state[test_param] + self.assertIn("second_moment_buffer", state) + self.assertIn("momentum_buffer", state) + + # Check second moment buffer shape + second_moment = state["second_moment_buffer"] + if second_moment_method == "adamuon": + # Full elementwise buffer + self.assertEqual(second_moment.shape, test_param.shape) + elif second_moment_method == "normuon": + # Reduced shape buffer + avg_dim = -1 if shape[-2] >= shape[-1] else -2 + expected_shape = list(shape) + expected_shape[avg_dim] = 1 + self.assertEqual(list(second_moment.shape), expected_shape) + + def test_requires_second_moment_method(self) -> None: + """Test that AdaptiveOrthogonalizedOptimizer requires second_moment_method.""" + test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device="cuda")) + + with self.assertRaises(ValueError): + AdaptiveOrthogonalizedOptimizer( + [test_param], + lr=0.01, + momentum_beta=0.9, + weight_decay=0.0, + use_nesterov=False, + second_moment_method=None, # Should raise error + beta2=0.999, + eps=1e-8, + weight_decay_method="decoupled", + fp32_matmul_prec="highest", + ) + + if __name__ == "__main__": absltest.main() From 1b87e0691891f1fac11995914ef52c37eada8ecc Mon Sep 17 00:00:00 2001 From: mikail Date: Wed, 12 Nov 2025 22:14:56 -0800 Subject: [PATCH 05/32] removed mixin test Signed-off-by: mikail --- tests/test_mixin.py | 135 -------------------------------------------- 1 file changed, 135 deletions(-) delete mode 100644 tests/test_mixin.py diff --git a/tests/test_mixin.py 
b/tests/test_mixin.py deleted file mode 100644 index c0bcf82..0000000 --- a/tests/test_mixin.py +++ /dev/null @@ -1,135 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -import torch.nn as nn -from absl import flags -from absl.testing import absltest, parameterized - -from emerging_optimizers.orthogonalized_optimizers import AdaptiveOrthogonalizedOptimizer - - -# Define command line flags -flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") -flags.DEFINE_integer("seed", 42, "Random seed for reproducible tests") - -FLAGS = flags.FLAGS - - -# Create a dummy class for testing second moment methods -class TestAdaptiveOptimizer(AdaptiveOrthogonalizedOptimizer): - """Test optimizer that uses AdaptiveOrthogonalizedOptimizer.""" - - def __init__(self, params, second_moment_method: str = "adamuon"): - super().__init__( - params=params, - lr=0.001, - momentum_beta=0.9, - weight_decay=0.0, - use_nesterov=False, - second_moment_method=second_moment_method, - beta2=0.999, - eps=1e-8, - weight_decay_method="decoupled", - fp32_matmul_prec="highest", - ) - - -class SecondMomentTest(parameterized.TestCase): - def setUp(self): - """Set random seed and device before each test.""" - torch.manual_seed(FLAGS.seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(FLAGS.seed) - self.device = FLAGS.device - - @parameterized.parameters( - {"shape": (8, 16), "beta2": 0.999, "eps": 1e-8}, - {"shape": (32, 64), "beta2": 0.99, "eps": 1e-6}, - {"shape": (4, 4), "beta2": 0.9, "eps": 1e-10}, - ) - def test_adamuon_method(self, shape, beta2, eps): - """Test AdamMuon (elementwise) second moment method.""" - # Create a dummy parameter for the optimizer - dummy_param = nn.Parameter(torch.randn(shape, device=self.device)) - optimizer = TestAdaptiveOptimizer([dummy_param], second_moment_method="adamuon") - - orth_grad = torch.randn(shape, device=self.device) - second_moment = torch.zeros_like(orth_grad) - - # Apply second moment normalization - result = optimizer._apply_second_moment_normalization( - orth_grad=orth_grad, - second_moment=second_moment, - beta2=beta2, - eps=eps, - ) - - # Check that second moment was updated - expected_second_moment = (1 - beta2) * orth_grad.square() - torch.testing.assert_close(second_moment, expected_second_moment, rtol=1e-5, atol=1e-7) - - # Check result shape - self.assertEqual(result.shape, orth_grad.shape) - - # Check that result is computed correctly (elementwise division) - expected_result = orth_grad / (expected_second_moment.sqrt() + eps) - torch.testing.assert_close(result, expected_result, rtol=1e-5, atol=1e-7) - - @parameterized.parameters( - {"shape": (16, 8)}, # rows > cols, should average along -1 - {"shape": (8, 16)}, # cols > rows, should average along -2 - {"shape": (32, 32)}, # square, should average along -1 - ) - def test_normuon_method(self, shape): - """Test NorMuon 
(row/column-wise) second moment method.""" - # Create a dummy parameter for the optimizer - dummy_param = nn.Parameter(torch.randn(shape, device=self.device)) - optimizer = TestAdaptiveOptimizer([dummy_param], second_moment_method="normuon") - - orth_grad = torch.randn(shape, device=self.device) - - # Determine which dimension should be averaged - avg_dim = -1 if shape[-2] >= shape[-1] else -2 - expected_v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True) - - # Initialize second moment to zeros with correct shape - second_moment = torch.zeros_like(expected_v_mean) - - beta2 = 0.999 - eps = 1e-8 - - # Apply second moment normalization - result = optimizer._apply_second_moment_normalization( - orth_grad=orth_grad, - second_moment=second_moment, - beta2=beta2, - eps=eps, - ) - - # Check that second moment was updated with correct shape - expected_second_moment = (1 - beta2) * expected_v_mean - torch.testing.assert_close(second_moment, expected_second_moment, rtol=1e-5, atol=1e-7) - - # Check result shape matches input - self.assertEqual(result.shape, orth_grad.shape) - - # Check that result uses reciprocal square root - step_size = expected_second_moment.clamp_min(eps).rsqrt_() - expected_result = orth_grad * step_size - torch.testing.assert_close(result, expected_result, rtol=1e-5, atol=1e-7) - - -if __name__ == "__main__": - absltest.main() From fb802b88a0f6f1416b2dc5a60a97c1bd0bb7958a Mon Sep 17 00:00:00 2001 From: mikail Date: Wed, 12 Nov 2025 22:32:22 -0800 Subject: [PATCH 06/32] cleaned up adaptive orthogonalized optimizer Signed-off-by: mikail --- .../adaptive_orthogonalized_optimizer.py | 40 ++++++++++++++----- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py index b5ef901..b37f62e 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Literal +from typing import Callable, Literal # TODO(@boxiangw): remove this once bump to python 3.12 @@ -22,10 +22,13 @@ from typing_extensions import override import torch +from absl import logging from torch.optim.optimizer import ParamsT from emerging_optimizers import mixin as opt_mixin from emerging_optimizers import utils +from emerging_optimizers.orthogonalized_optimizers import muon_utils +from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer @@ -73,21 +76,37 @@ def __init__( eps: float = 1e-8, weight_decay_method: opt_mixin.WeightDecayT = "decoupled", fp32_matmul_prec: str, - scaled_orthogonalize_fn: Callable | None = None, - **kwargs: Any, + coefficient_type: str = "quintic", + num_ns_steps: int = 5, + scale_mode: str = "spectral", + extra_scale_factor: float = 1.0, + use_syrk: bool = False, ): self.second_moment_method = second_moment_method + def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: + logging.debug( + f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, " + f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}" + ) + orth_grad = muon_utils.newton_schulz( + grad, + steps=num_ns_steps, + coefficient_type=coefficient_type, + use_syrk=use_syrk, + ) + scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode) + return orth_grad * scale_factor * extra_scale_factor + super().__init__( - params=params, - lr=lr, - momentum_beta=momentum_beta, - weight_decay=weight_decay, + params, + lr, + momentum_beta, use_nesterov=use_nesterov, + weight_decay=weight_decay, weight_decay_method=weight_decay_method, fp32_matmul_prec=fp32_matmul_prec, scaled_orthogonalize_fn=scaled_orthogonalize_fn, - **kwargs, ) for group in self.param_groups: @@ -154,7 +173,7 @@ def _apply_second_moment_normalization( """ if self.second_moment_method == "adamuon": # AdamMuon: Full elementwise second moment like AdamW - # Update second moment with EMA of squared gradient + # Update second moment with EMA of squared orthogonalized gradient second_moment.lerp_(orth_grad.square(), 1 - beta2) # AdamW-style division: grad / (sqrt(second_moment) + eps) @@ -224,8 +243,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: grad = exp_avg with utils.fp32_matmul_precision(self.fp32_matmul_prec): - group_kwargs = {k: v for k, v in group.items() if k != "params"} - grad = self.orthogonalize(p, grad, **group_kwargs) + grad = self.scaled_orthogonalize_fn(grad) # Apply second moment normalization grad = self._apply_second_moment_normalization( From 890e7f632ebcc92a147ab2f8c5004102d9953de6 Mon Sep 17 00:00:00 2001 From: mikail Date: Fri, 14 Nov 2025 09:30:26 -0800 Subject: [PATCH 07/32] changed second moment to moment2, addressed other MR comments Signed-off-by: mikail --- .../adaptive_orthogonalized_optimizer.py | 63 +++++++++---------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py index b37f62e..2c1c4f4 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py @@ -38,7 +38,7 @@ weight_decay: The weight decay used by the optimizer, default to 
be decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 use_nesterov: Whether to use Nesterov-style momentum in the internal SGD. - second_moment_method: Method to apply second moment. Options: "adamuon" (elementwise like AdamW), "normuon" (row/column-wise). + moment2_method: Method to apply second moment. Options: "adamuon" (elementwise like AdamW), "normuon" (row/column-wise). beta2: The exponential decay rate for second moment (like AdamW's beta2). eps: Small constant for numerical stability in second moment normalization. weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` @@ -55,7 +55,6 @@ class AdaptiveOrthogonalizedOptimizer(OrthogonalizedOptimizer): normalization logic. Args: - {_adaptive_args_doc} scaled_orthogonalize_fn: Function to orthogonalize and scale the updates. **kwargs: Arguments passed through to the base optimizer. @@ -71,7 +70,7 @@ def __init__( weight_decay: float, *, use_nesterov: bool, - second_moment_method: Literal["adamuon", "normuon"], + moment2_method: Literal["adamuon", "normuon"], beta2: float = 0.999, eps: float = 1e-8, weight_decay_method: opt_mixin.WeightDecayT = "decoupled", @@ -82,7 +81,7 @@ def __init__( extra_scale_factor: float = 1.0, use_syrk: bool = False, ): - self.second_moment_method = second_moment_method + self.moment2_method = moment2_method def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: logging.debug( @@ -107,20 +106,18 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: weight_decay_method=weight_decay_method, fp32_matmul_prec=fp32_matmul_prec, scaled_orthogonalize_fn=scaled_orthogonalize_fn, + beta2=beta2, + eps=eps, ) - for group in self.param_groups: - group["beta2"] = beta2 - group["eps"] = eps - - def _initialize_second_moment( + def _initialize_moment2( self, state: dict[str, torch.Tensor], grad: torch.Tensor, ) -> None: """Initialize the second moment buffer if it doesn't exist. - The shape of the buffer depends on the second_moment_method: + The shape of the buffer depends on the moment2_method: - "adamuon": Full elementwise buffer with same shape as grad - "normuon": Reduced shape buffer (averaged along -1 if shape[-2] >= shape[-1], else -2) @@ -128,27 +125,27 @@ def _initialize_second_moment( state: The optimizer state dict for a parameter. grad: The gradient tensor (used for shape/dtype). 
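The shape rule described above is easiest to see with concrete examples. Below is a minimal standalone sketch (not part of the patch; moment2_shape is a hypothetical helper) that reproduces the rule for 2D gradients:

import torch

def moment2_shape(grad: torch.Tensor, method: str) -> list[int]:
    # "adamuon" keeps a full elementwise buffer; "normuon" collapses the
    # longer of the last two dimensions to size 1.
    if method == "adamuon":
        return list(grad.shape)
    avg_dim = -1 if grad.shape[-2] >= grad.shape[-1] else -2
    reduced = list(grad.shape)
    reduced[avg_dim] = 1
    return reduced

print(moment2_shape(torch.empty(16, 8), "adamuon"))  # [16, 8]
print(moment2_shape(torch.empty(16, 8), "normuon"))  # [16, 1]
print(moment2_shape(torch.empty(8, 16), "normuon"))  # [1, 16]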
""" - if "second_moment_buffer" not in state: - if self.second_moment_method == "adamuon": + if "moment2_buffer" not in state: + if self.moment2_method == "adamuon": # Full elementwise second moment - second_moment = torch.zeros_like(grad) - elif self.second_moment_method == "normuon": + moment2 = torch.zeros_like(grad) + elif self.moment2_method == "normuon": # Row/column-wise second moment - reduced along one dimension # Determine which dimension to reduce based on parameter shape avg_dim = -1 if grad.shape[-2] >= grad.shape[-1] else -2 # Specify the shape with reduced dimension - second_moment_shape = list(grad.shape) - second_moment_shape[avg_dim] = 1 - second_moment = torch.zeros(second_moment_shape, dtype=grad.dtype, device=grad.device) + moment2_shape = list(grad.shape) + moment2_shape[avg_dim] = 1 + moment2 = torch.zeros(moment2_shape, dtype=grad.dtype, device=grad.device) else: - raise ValueError(f"Invalid second moment method: {self.second_moment_method}") + raise ValueError(f"Invalid second moment method: {self.moment2_method}") - state["second_moment_buffer"] = second_moment + state["moment2_buffer"] = moment2 - def _apply_second_moment_normalization( + def _apply_moment2_normalization( self, orth_grad: torch.Tensor, - second_moment: torch.Tensor, + moment2: torch.Tensor, beta2: float, eps: float, ) -> torch.Tensor: @@ -164,23 +161,23 @@ def _apply_second_moment_normalization( Args: orth_grad: The orthogonalized gradient tensor. - second_moment: The second moment buffer from state. + moment2: The second moment buffer from state. beta2: The exponential decay rate for second moment. eps: Small constant for numerical stability. Returns: The adaptively scaled weight update tensor. """ - if self.second_moment_method == "adamuon": + if self.moment2_method == "adamuon": # AdamMuon: Full elementwise second moment like AdamW # Update second moment with EMA of squared orthogonalized gradient - second_moment.lerp_(orth_grad.square(), 1 - beta2) + moment2.lerp_(orth_grad.square(), 1 - beta2) - # AdamW-style division: grad / (sqrt(second_moment) + eps) - denom = second_moment.sqrt() + eps + # AdamW-style division: grad / (sqrt(moment2) + eps) + denom = moment2.sqrt() + eps return orth_grad / denom - elif self.second_moment_method == "normuon": + elif self.moment2_method == "normuon": # NorMuon: Row or column-wise second moment # Compute mean of squared gradients along one dimension based on shape # Average along the longer dimension to preserve structure along shorter dim @@ -188,14 +185,14 @@ def _apply_second_moment_normalization( v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True) # Update second moment with EMA - second_moment.lerp_(v_mean, 1 - beta2) + moment2.lerp_(v_mean, 1 - beta2) # NorMuon uses reciprocal square root with clamping - step_size = second_moment.clamp_min(eps).rsqrt_() + step_size = moment2.clamp_min(eps).rsqrt_() return orth_grad * step_size else: - raise ValueError(f"Invalid second moment method: {self.second_moment_method}") + raise ValueError(f"Invalid second moment method: {self.moment2_method}") @torch.no_grad() # type: ignore[misc] @override @@ -222,7 +219,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # initialize momentum buffer and second moment buffer if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(grad) - self._initialize_second_moment(state, grad) + self._initialize_moment2(state, grad) exp_avg = state["momentum_buffer"] @@ -246,9 +243,9 @@ def step(self, closure: Callable[[], float] 
| None = None) -> float | None: grad = self.scaled_orthogonalize_fn(grad) # Apply second moment normalization - grad = self._apply_second_moment_normalization( + grad = self._apply_moment2_normalization( orth_grad=grad, - second_moment=state["second_moment_buffer"], + moment2=state["moment2_buffer"], beta2=group["beta2"], eps=group["eps"], ) From a47a325d8bd4e5fbe6a59dc7e1811ded87dc81b6 Mon Sep 17 00:00:00 2001 From: mikail Date: Fri, 14 Nov 2025 10:37:01 -0800 Subject: [PATCH 08/32] removed args doc Signed-off-by: mikail --- .../adaptive_orthogonalized_optimizer.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py index 2c1c4f4..efc3603 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py @@ -32,21 +32,6 @@ from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer -_adaptive_args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups - lr: The learning rate used by the internal SGD. - momentum_beta: The momentum used by the internal SGD. - weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. - See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 - use_nesterov: Whether to use Nesterov-style momentum in the internal SGD. - moment2_method: Method to apply second moment. Options: "adamuon" (elementwise like AdamW), "normuon" (row/column-wise). - beta2: The exponential decay rate for second moment (like AdamW's beta2). - eps: Small constant for numerical stability in second moment normalization. - weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` - for more details. - fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. -""" - - class AdaptiveOrthogonalizedOptimizer(OrthogonalizedOptimizer): """Orthogonalized optimizer with adaptive second moment (AdaMuon/NorMuon variants). 
@@ -255,8 +240,3 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: p.add_(grad, alpha=-group["lr"]) return loss - - -AdaptiveOrthogonalizedOptimizer.__doc__ = AdaptiveOrthogonalizedOptimizer.__doc__.format( # type: ignore[union-attr] - _adaptive_args_doc=_adaptive_args_doc -) From b9736d8354823f4ccd0aa90403579042662b911b Mon Sep 17 00:00:00 2001 From: mikail Date: Sat, 15 Nov 2025 10:10:58 -0800 Subject: [PATCH 09/32] made a separate test file for adaptive orthogonalized optimizer Signed-off-by: mikail --- .../test_adaptive_orthogonalized_optimizer.py | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 tests/test_adaptive_orthogonalized_optimizer.py diff --git a/tests/test_adaptive_orthogonalized_optimizer.py b/tests/test_adaptive_orthogonalized_optimizer.py new file mode 100644 index 0000000..5ef680e --- /dev/null +++ b/tests/test_adaptive_orthogonalized_optimizer.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +from absl.testing import absltest, parameterized + +from emerging_optimizers.orthogonalized_optimizers.adaptive_orthogonalized_optimizer import ( + AdaptiveOrthogonalizedOptimizer, +) + + +class AdaptiveOrthogonalizedOptimizerTest(parameterized.TestCase): + @parameterized.product( + shape=[(5, 7), (33, 65), (127, 257)], + second_moment_method=["adamuon", "normuon"], + use_nesterov=[True, False], + ) + def test_smoke(self, shape, second_moment_method, use_nesterov) -> None: + """Smoke test AdaptiveOrthogonalizedOptimizer with both second moment methods.""" + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) + test_param.grad = torch.randint_like(test_param, -5, 5) + + adaptive_opt = AdaptiveOrthogonalizedOptimizer( + [test_param], + lr=0.01, + momentum_beta=0.9, + weight_decay=0.01, + use_nesterov=use_nesterov, + moment2_method=second_moment_method, + beta2=0.999, + eps=1e-8, + weight_decay_method="decoupled", + fp32_matmul_prec="highest", + ) + adaptive_opt.step() + + @parameterized.parameters( + {"shape": (8, 16), "second_moment_method": "adamuon"}, + {"shape": (16, 8), "second_moment_method": "normuon"}, + ) + def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None: + """Test that second moment buffers are properly initialized.""" + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) + test_param.grad = torch.randint_like(test_param, -5, 5) + + adaptive_opt = AdaptiveOrthogonalizedOptimizer( + [test_param], + lr=0.01, + momentum_beta=0.9, + weight_decay=0.0, + use_nesterov=False, + moment2_method=second_moment_method, + beta2=0.999, + eps=1e-8, + weight_decay_method="decoupled", + fp32_matmul_prec="highest", + ) + + # Run one step to initialize buffers + adaptive_opt.step() + + # Check that second moment buffer was created + state = adaptive_opt.state[test_param] + self.assertIn("moment2_buffer", state) + self.assertIn("momentum_buffer", state) + + # Check second moment buffer shape + second_moment = state["moment2_buffer"] + if second_moment_method == "adamuon": + # Full elementwise buffer + self.assertEqual(second_moment.shape, test_param.shape) + elif second_moment_method == "normuon": + # Reduced shape buffer + avg_dim = -1 if shape[-2] >= shape[-1] else -2 + expected_shape = list(shape) + expected_shape[avg_dim] = 1 + self.assertEqual(list(second_moment.shape), expected_shape) + + def test_requires_second_moment_method(self) -> None: + """Test that AdaptiveOrthogonalizedOptimizer requires 
second_moment_method.""" + test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device="cuda")) + + with self.assertRaises(TypeError): + AdaptiveOrthogonalizedOptimizer( + [test_param], + lr=0.01, + momentum_beta=0.9, + weight_decay=0.0, + use_nesterov=False, + moment2_method=None, # Should raise error + beta2=0.999, + eps=1e-8, + weight_decay_method="decoupled", + fp32_matmul_prec="highest", + ) + + +if __name__ == "__main__": + absltest.main() From a626188310cac96ac2a1a83d602e8884dccc3902 Mon Sep 17 00:00:00 2001 From: mikail Date: Sat, 15 Nov 2025 10:12:18 -0800 Subject: [PATCH 10/32] added device flag Signed-off-by: mikail --- tests/test_adaptive_orthogonalized_optimizer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/test_adaptive_orthogonalized_optimizer.py b/tests/test_adaptive_orthogonalized_optimizer.py index 5ef680e..c347787 100644 --- a/tests/test_adaptive_orthogonalized_optimizer.py +++ b/tests/test_adaptive_orthogonalized_optimizer.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from absl import flags from absl.testing import absltest, parameterized from emerging_optimizers.orthogonalized_optimizers.adaptive_orthogonalized_optimizer import ( @@ -7,6 +8,14 @@ ) +# Define command line flags +flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") + +FLAGS = flags.FLAGS + +device = FLAGS.device + + class AdaptiveOrthogonalizedOptimizerTest(parameterized.TestCase): @parameterized.product( shape=[(5, 7), (33, 65), (127, 257)], @@ -15,7 +24,7 @@ class AdaptiveOrthogonalizedOptimizerTest(parameterized.TestCase): ) def test_smoke(self, shape, second_moment_method, use_nesterov) -> None: """Smoke test AdaptiveOrthogonalizedOptimizer with both second moment methods.""" - test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=device)) test_param.grad = torch.randint_like(test_param, -5, 5) adaptive_opt = AdaptiveOrthogonalizedOptimizer( @@ -38,7 +47,7 @@ def test_smoke(self, shape, second_moment_method, use_nesterov) -> None: ) def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None: """Test that second moment buffers are properly initialized.""" - test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=device)) test_param.grad = torch.randint_like(test_param, -5, 5) adaptive_opt = AdaptiveOrthogonalizedOptimizer( @@ -76,7 +85,7 @@ def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None def test_requires_second_moment_method(self) -> None: """Test that AdaptiveOrthogonalizedOptimizer requires second_moment_method.""" - test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device="cuda")) + test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device=device)) with self.assertRaises(TypeError): AdaptiveOrthogonalizedOptimizer( From 159505e944edb71356afba92ddc90a9a199b3bf0 Mon Sep 17 00:00:00 2001 From: mikail Date: Sat, 15 Nov 2025 10:15:20 -0800 Subject: [PATCH 11/32] changed test name Signed-off-by: mikail --- tests/test_adaptive_orthogonalized_optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_adaptive_orthogonalized_optimizer.py b/tests/test_adaptive_orthogonalized_optimizer.py index c347787..9cfb4ab 100644 --- 
a/tests/test_adaptive_orthogonalized_optimizer.py +++ b/tests/test_adaptive_orthogonalized_optimizer.py @@ -83,8 +83,8 @@ def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None expected_shape[avg_dim] = 1 self.assertEqual(list(second_moment.shape), expected_shape) - def test_requires_second_moment_method(self) -> None: - """Test that AdaptiveOrthogonalizedOptimizer requires second_moment_method.""" + def test_unknown_moment2_method_raise_type_error(self) -> None: + """Test that AdaptiveOrthogonalizedOptimizer raises TypeError for unknown moment2_method.""" test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device=device)) with self.assertRaises(TypeError): @@ -94,7 +94,7 @@ def test_requires_second_moment_method(self) -> None: momentum_beta=0.9, weight_decay=0.0, use_nesterov=False, - moment2_method=None, # Should raise error + moment2_method=None, beta2=0.999, eps=1e-8, weight_decay_method="decoupled", From c02621f6b9f7004e76995569fa227081a95c3f0c Mon Sep 17 00:00:00 2001 From: mikail Date: Sat, 15 Nov 2025 10:16:10 -0800 Subject: [PATCH 12/32] changed value error to type error Signed-off-by: mikail --- .../adaptive_orthogonalized_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py index efc3603..5a4a42e 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py @@ -123,7 +123,7 @@ def _initialize_moment2( moment2_shape[avg_dim] = 1 moment2 = torch.zeros(moment2_shape, dtype=grad.dtype, device=grad.device) else: - raise ValueError(f"Invalid second moment method: {self.moment2_method}") + raise TypeError(f"Invalid second moment method: {self.moment2_method}") state["moment2_buffer"] = moment2 @@ -177,7 +177,7 @@ def _apply_moment2_normalization( return orth_grad * step_size else: - raise ValueError(f"Invalid second moment method: {self.moment2_method}") + raise TypeError(f"Invalid second moment method: {self.moment2_method}") @torch.no_grad() # type: ignore[misc] @override From 3a5a9ba6f3721c835ac17b842ec834f8fbb2cc61 Mon Sep 17 00:00:00 2001 From: mikail Date: Sat, 15 Nov 2025 11:22:26 -0800 Subject: [PATCH 13/32] removed adaptive test Signed-off-by: mikail --- tests/test_orthogonalized_optimizer.py | 89 -------------------------- 1 file changed, 89 deletions(-) diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index ee3c5bf..39312c1 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -18,9 +18,6 @@ from absl.testing import absltest, parameterized from emerging_optimizers.orthogonalized_optimizers import muon, scion -from emerging_optimizers.orthogonalized_optimizers.adaptive_orthogonalized_optimizer import ( - AdaptiveOrthogonalizedOptimizer, -) from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer @@ -252,91 +249,5 @@ def test_smoke(self, shape) -> None: scion_opt.step() -class AdaptiveOrthogonalizedOptimizerTest(parameterized.TestCase): - @parameterized.product( - shape=[(5, 7), (33, 65), (127, 257)], - second_moment_method=["adamuon", "normuon"], - use_nesterov=[True, False], - ) - def test_smoke(self, shape, second_moment_method, use_nesterov) -> None: - """Smoke test 
AdaptiveOrthogonalizedOptimizer with both second moment methods.""" - test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) - test_param.grad = torch.randint_like(test_param, -5, 5) - - adaptive_opt = AdaptiveOrthogonalizedOptimizer( - [test_param], - lr=0.01, - momentum_beta=0.9, - weight_decay=0.01, - use_nesterov=use_nesterov, - second_moment_method=second_moment_method, - beta2=0.999, - eps=1e-8, - weight_decay_method="decoupled", - fp32_matmul_prec="highest", - ) - adaptive_opt.step() - - @parameterized.parameters( - {"shape": (8, 16), "second_moment_method": "adamuon"}, - {"shape": (16, 8), "second_moment_method": "normuon"}, - ) - def test_second_moment_initialization(self, shape, second_moment_method) -> None: - """Test that second moment buffers are properly initialized.""" - test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) - test_param.grad = torch.randint_like(test_param, -5, 5) - - adaptive_opt = AdaptiveOrthogonalizedOptimizer( - [test_param], - lr=0.01, - momentum_beta=0.9, - weight_decay=0.0, - use_nesterov=False, - second_moment_method=second_moment_method, - beta2=0.999, - eps=1e-8, - weight_decay_method="decoupled", - fp32_matmul_prec="highest", - ) - - # Run one step to initialize buffers - adaptive_opt.step() - - # Check that second moment buffer was created - state = adaptive_opt.state[test_param] - self.assertIn("second_moment_buffer", state) - self.assertIn("momentum_buffer", state) - - # Check second moment buffer shape - second_moment = state["second_moment_buffer"] - if second_moment_method == "adamuon": - # Full elementwise buffer - self.assertEqual(second_moment.shape, test_param.shape) - elif second_moment_method == "normuon": - # Reduced shape buffer - avg_dim = -1 if shape[-2] >= shape[-1] else -2 - expected_shape = list(shape) - expected_shape[avg_dim] = 1 - self.assertEqual(list(second_moment.shape), expected_shape) - - def test_requires_second_moment_method(self) -> None: - """Test that AdaptiveOrthogonalizedOptimizer requires second_moment_method.""" - test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device="cuda")) - - with self.assertRaises(ValueError): - AdaptiveOrthogonalizedOptimizer( - [test_param], - lr=0.01, - momentum_beta=0.9, - weight_decay=0.0, - use_nesterov=False, - second_moment_method=None, # Should raise error - beta2=0.999, - eps=1e-8, - weight_decay_method="decoupled", - fp32_matmul_prec="highest", - ) - - if __name__ == "__main__": absltest.main() From 377f510ea2deed7b6e069989da238ecbb74ec8d4 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 17 Nov 2025 17:06:21 -0800 Subject: [PATCH 14/32] changed subclass to Muon instead of OrthogonalizedOptimizer Signed-off-by: mikail --- .../adaptive_orthogonalized_optimizer.py | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py index 5a4a42e..72f0ead 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py @@ -22,29 +22,36 @@ from typing_extensions import override import torch -from absl import logging from torch.optim.optimizer import ParamsT from emerging_optimizers import mixin as opt_mixin from emerging_optimizers import utils -from 
emerging_optimizers.orthogonalized_optimizers import muon_utils -from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor -from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer +from emerging_optimizers.orthogonalized_optimizers.muon import Muon -class AdaptiveOrthogonalizedOptimizer(OrthogonalizedOptimizer): +class AdaptiveOrthogonalizedOptimizer(Muon): """Orthogonalized optimizer with adaptive second moment (AdaMuon/NorMuon variants). - This class extends OrthogonalizedOptimizer by adding AdamW-style or NorMuon-style second moment + This class extends Muon by adding AdamW-style or NorMuon-style second moment accumulation after orthogonalization. The step() method is overridden to include second moment normalization logic. Args: - scaled_orthogonalize_fn: Function to orthogonalize and scale the updates. - **kwargs: Arguments passed through to the base optimizer. - - Note: - Keyword arguments passed through are not checked here. Optimizer inherited from this class should check them. + params: Iterable of parameters to optimize or dicts defining parameter groups. + lr: Learning rate. + momentum_beta: The exponential decay rate for momentum. + weight_decay: Weight decay coefficient. + use_nesterov: Whether to use Nesterov momentum. + moment2_method: Method for second moment accumulation ("adamuon" or "normuon"). + beta2: The exponential decay rate for second moment. + eps: Small constant for numerical stability. + weight_decay_method: The weight decay method to use. + fp32_matmul_prec: Precision for FP32 matrix multiplication. + coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. + num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. + scale_mode: The type of scale factor to use for the update. + extra_scale_factor: The additional scale factor to use for the update. + use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. 
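To see what moment2_method, beta2, and eps do to the orthogonalized update, the following plain-PyTorch sketch mirrors the two branches of _apply_moment2_normalization shown earlier in this series (illustrative only; the function names are hypothetical and bias correction is omitted, as in the patch):

import torch

def adamuon_scale(orth_grad: torch.Tensor, moment2: torch.Tensor, beta2: float, eps: float) -> torch.Tensor:
    # elementwise EMA of squared updates, then AdamW-style division
    moment2.lerp_(orth_grad.square(), 1 - beta2)
    return orth_grad / (moment2.sqrt() + eps)

def normuon_scale(orth_grad: torch.Tensor, moment2: torch.Tensor, beta2: float, eps: float) -> torch.Tensor:
    # row/column-wise EMA: average squares along the longer dimension,
    # then scale by the clamped reciprocal square root
    avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2
    moment2.lerp_(orth_grad.square().mean(dim=avg_dim, keepdim=True), 1 - beta2)
    return orth_grad * moment2.clamp_min(eps).rsqrt()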
""" def __init__( @@ -68,33 +75,26 @@ def __init__( ): self.moment2_method = moment2_method - def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: - logging.debug( - f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, " - f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}" - ) - orth_grad = muon_utils.newton_schulz( - grad, - steps=num_ns_steps, - coefficient_type=coefficient_type, - use_syrk=use_syrk, - ) - scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode) - return orth_grad * scale_factor * extra_scale_factor - super().__init__( params, - lr, - momentum_beta, - use_nesterov=use_nesterov, + lr=lr, + momentum_beta=momentum_beta, weight_decay=weight_decay, + use_nesterov=use_nesterov, weight_decay_method=weight_decay_method, fp32_matmul_prec=fp32_matmul_prec, - scaled_orthogonalize_fn=scaled_orthogonalize_fn, - beta2=beta2, - eps=eps, + coefficient_type=coefficient_type, + num_ns_steps=num_ns_steps, + scale_mode=scale_mode, + extra_scale_factor=extra_scale_factor, + use_syrk=use_syrk, ) + # Add beta2 and eps to param_groups defaults + for group in self.param_groups: + group.setdefault("beta2", beta2) + group.setdefault("eps", eps) + def _initialize_moment2( self, state: dict[str, torch.Tensor], From c1446c2bd6cfc090669961c2a5fe52b9cf525137 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 17 Nov 2025 17:28:24 -0800 Subject: [PATCH 15/32] changed name to adaptivemuon, updated tests Signed-off-by: mikail --- .../orthogonalized_optimizers/__init__.py | 2 +- ...rthogonalized_optimizer.py => adaptive_muon.py} | 6 +++--- tests/test_adaptive_orthogonalized_optimizer.py | 14 +++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) rename emerging_optimizers/orthogonalized_optimizers/{adaptive_orthogonalized_optimizer.py => adaptive_muon.py} (97%) diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py index 274defa..d0eddf8 100644 --- a/emerging_optimizers/orthogonalized_optimizers/__init__.py +++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from emerging_optimizers.orthogonalized_optimizers.adaptive_orthogonalized_optimizer import * +from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import * from emerging_optimizers.orthogonalized_optimizers.muon import * from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import * from emerging_optimizers.orthogonalized_optimizers.scion import * diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py similarity index 97% rename from emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py rename to emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index 72f0ead..fcdf058 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -29,8 +29,8 @@ from emerging_optimizers.orthogonalized_optimizers.muon import Muon -class AdaptiveOrthogonalizedOptimizer(Muon): - """Orthogonalized optimizer with adaptive second moment (AdaMuon/NorMuon variants). 
+class AdaptiveMuon(Muon): + """Adaptive Muon optimizer with adaptive second moment (AdaMuon/NorMuon variants). This class extends Muon by adding AdamW-style or NorMuon-style second moment accumulation after orthogonalization. The step() method is overridden to include second moment @@ -195,7 +195,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: for group in self.param_groups: for p in group["params"]: if p.dim() == 1: - raise ValueError(f"{self.__class__.__name__} does not support 1D parameters") + raise ValueError("AdaptiveMuon does not support 1D parameters") grad = p.grad if grad is None: continue diff --git a/tests/test_adaptive_orthogonalized_optimizer.py b/tests/test_adaptive_orthogonalized_optimizer.py index 9cfb4ab..82778fc 100644 --- a/tests/test_adaptive_orthogonalized_optimizer.py +++ b/tests/test_adaptive_orthogonalized_optimizer.py @@ -4,7 +4,7 @@ from absl.testing import absltest, parameterized from emerging_optimizers.orthogonalized_optimizers.adaptive_orthogonalized_optimizer import ( - AdaptiveOrthogonalizedOptimizer, + AdaptiveMuon, ) @@ -16,18 +16,18 @@ device = FLAGS.device -class AdaptiveOrthogonalizedOptimizerTest(parameterized.TestCase): +class AdaptiveMuonTest(parameterized.TestCase): @parameterized.product( shape=[(5, 7), (33, 65), (127, 257)], second_moment_method=["adamuon", "normuon"], use_nesterov=[True, False], ) def test_smoke(self, shape, second_moment_method, use_nesterov) -> None: - """Smoke test AdaptiveOrthogonalizedOptimizer with both second moment methods.""" + """Smoke test AdaptiveMuon with both second moment methods.""" test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=device)) test_param.grad = torch.randint_like(test_param, -5, 5) - adaptive_opt = AdaptiveOrthogonalizedOptimizer( + adaptive_opt = AdaptiveMuon( [test_param], lr=0.01, momentum_beta=0.9, @@ -50,7 +50,7 @@ def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=device)) test_param.grad = torch.randint_like(test_param, -5, 5) - adaptive_opt = AdaptiveOrthogonalizedOptimizer( + adaptive_opt = AdaptiveMuon( [test_param], lr=0.01, momentum_beta=0.9, @@ -84,11 +84,11 @@ def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None self.assertEqual(list(second_moment.shape), expected_shape) def test_unknown_moment2_method_raise_type_error(self) -> None: - """Test that AdaptiveOrthogonalizedOptimizer raises TypeError for unknown moment2_method.""" + """Test that AdaptiveMuon raises TypeError for unknown moment2_method.""" test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device=device)) with self.assertRaises(TypeError): - AdaptiveOrthogonalizedOptimizer( + AdaptiveMuon( [test_param], lr=0.01, momentum_beta=0.9, From 6c88ac72d6e3b16d92a818ee307ac66eea189966 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 17 Nov 2025 17:32:00 -0800 Subject: [PATCH 16/32] updated test to import from adaptive_muon Signed-off-by: mikail --- tests/test_adaptive_orthogonalized_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_adaptive_orthogonalized_optimizer.py b/tests/test_adaptive_orthogonalized_optimizer.py index 82778fc..9453f9f 100644 --- a/tests/test_adaptive_orthogonalized_optimizer.py +++ b/tests/test_adaptive_orthogonalized_optimizer.py @@ -3,7 +3,7 @@ from absl import flags from absl.testing import absltest, parameterized -from 
emerging_optimizers.orthogonalized_optimizers.adaptive_orthogonalized_optimizer import ( +from emerging_optimizers.orthogonalized_optimizers.adaptive_muon import ( AdaptiveMuon, ) From 530bff10e0ffb7b9b6c0336b28c58e35c398a5db Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 17 Nov 2025 17:33:32 -0800 Subject: [PATCH 17/32] added missing copyright Signed-off-by: mikail --- tests/test_adaptive_orthogonalized_optimizer.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_adaptive_orthogonalized_optimizer.py b/tests/test_adaptive_orthogonalized_optimizer.py index 9453f9f..a4a7bdc 100644 --- a/tests/test_adaptive_orthogonalized_optimizer.py +++ b/tests/test_adaptive_orthogonalized_optimizer.py @@ -1,3 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import torch import torch.nn as nn from absl import flags From 8f46a91dd02fccab4638cb1121e98b4b7c718228 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 17 Nov 2025 17:45:21 -0800 Subject: [PATCH 18/32] changed scale mode to 1.0, added it as a scale mode Signed-off-by: mikail --- .../orthogonalized_optimizers/adaptive_muon.py | 4 ++-- emerging_optimizers/orthogonalized_optimizers/muon.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index fcdf058..82b482b 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -49,7 +49,7 @@ class AdaptiveMuon(Muon): fp32_matmul_prec: Precision for FP32 matrix multiplication. coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. - scale_mode: The type of scale factor to use for the update. + scale_mode: The type of scale factor to use for the update. Defaults to "none" to avoid scaling the update. extra_scale_factor: The additional scale factor to use for the update. use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. """ @@ -69,7 +69,7 @@ def __init__( fp32_matmul_prec: str, coefficient_type: str = "quintic", num_ns_steps: int = 5, - scale_mode: str = "spectral", + scale_mode: str = "none", extra_scale_factor: float = 1.0, use_syrk: bool = False, ): diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index 8200fb1..ea8e737 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -147,5 +147,7 @@ def get_muon_scale_factor(size_out: int, size_in: int, mode: str = "spectral") - # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al. 
# (https://jeremybernste.in/writing/deriving-muon) return (size_out / size_in) ** 0.5 + elif mode == "none": + return 1.0 else: raise ValueError(f"Invalid mode for Muon update scale factor: {mode}") From f4da4f16832fb8980554a0a6e85d68bd1fc5a1f5 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 17 Nov 2025 17:54:30 -0800 Subject: [PATCH 19/32] Revert "changed scale mode to 1.0, added it as a scale mode" This reverts commit 9d9ddf257083d077f92f3bafa6336c7e80beb2ad. Signed-off-by: mikail --- .../orthogonalized_optimizers/adaptive_muon.py | 4 ++-- emerging_optimizers/orthogonalized_optimizers/muon.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index 82b482b..fcdf058 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -49,7 +49,7 @@ class AdaptiveMuon(Muon): fp32_matmul_prec: Precision for FP32 matrix multiplication. coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. - scale_mode: The type of scale factor to use for the update. Defaults to "none" to avoid scaling the update. + scale_mode: The type of scale factor to use for the update. extra_scale_factor: The additional scale factor to use for the update. use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. """ @@ -69,7 +69,7 @@ def __init__( fp32_matmul_prec: str, coefficient_type: str = "quintic", num_ns_steps: int = 5, - scale_mode: str = "none", + scale_mode: str = "spectral", extra_scale_factor: float = 1.0, use_syrk: bool = False, ): diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index ea8e737..8200fb1 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -147,7 +147,5 @@ def get_muon_scale_factor(size_out: int, size_in: int, mode: str = "spectral") - # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al. # (https://jeremybernste.in/writing/deriving-muon) return (size_out / size_in) ** 0.5 - elif mode == "none": - return 1.0 else: raise ValueError(f"Invalid mode for Muon update scale factor: {mode}") From 939d9bf53e5f8f421f799125695c64a25b2c3816 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 17 Nov 2025 20:52:59 -0800 Subject: [PATCH 20/32] addressed MR comments Signed-off-by: mikail --- .../orthogonalized_optimizers/adaptive_muon.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index fcdf058..2c7ba66 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -90,7 +90,6 @@ def __init__( use_syrk=use_syrk, ) - # Add beta2 and eps to param_groups defaults for group in self.param_groups: group.setdefault("beta2", beta2) group.setdefault("eps", eps) @@ -182,7 +181,7 @@ def _apply_moment2_normalization( @torch.no_grad() # type: ignore[misc] @override def step(self, closure: Callable[[], float] | None = None) -> float | None: - """Performs a single optimization step with second moment normalization. + """Single optimization step. 
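The step being documented here first folds the raw gradient into the momentum buffer and, when Nesterov is enabled, blends the gradient with the updated buffer. A distilled, runnable version of just that momentum logic (tensor names are hypothetical; this mirrors the lerp calls in the surrounding diff rather than adding behavior):

import torch

def momentum_direction(grad: torch.Tensor, exp_avg: torch.Tensor, beta: float, use_nesterov: bool) -> torch.Tensor:
    # exp_avg <- beta * exp_avg + (1 - beta) * grad
    exp_avg.lerp_(grad, 1 - beta)
    if use_nesterov:
        # Nesterov-style blend with the *updated* buffer:
        # (1 - beta) * grad + beta * exp_avg
        return grad.lerp(exp_avg, beta)
    return exp_avg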
Args: closure: A closure that reevaluates the model and returns the loss. @@ -201,7 +200,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: continue state = self.state[p] - # initialize momentum buffer and second moment buffer if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(grad) self._initialize_moment2(state, grad) @@ -218,7 +216,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # update momentum buffer with EMA of gradient exp_avg.lerp_(grad, 1 - group["momentum_beta"]) - # include nesterov momentum if self.use_nesterov: grad = grad.lerp(exp_avg, group["momentum_beta"]) else: @@ -227,7 +224,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: with utils.fp32_matmul_precision(self.fp32_matmul_prec): grad = self.scaled_orthogonalize_fn(grad) - # Apply second moment normalization grad = self._apply_moment2_normalization( orth_grad=grad, moment2=state["moment2_buffer"], From 6f71c8663ed7df19bad3f79066cf52b04878e782 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 17 Nov 2025 21:05:51 -0800 Subject: [PATCH 21/32] use consistent orth_grad naming Signed-off-by: mikail --- .../orthogonalized_optimizers/adaptive_muon.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index 2c7ba66..17a4c33 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -222,10 +222,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: grad = exp_avg with utils.fp32_matmul_precision(self.fp32_matmul_prec): - grad = self.scaled_orthogonalize_fn(grad) + orth_grad = self.scaled_orthogonalize_fn(grad) - grad = self._apply_moment2_normalization( - orth_grad=grad, + update = self._apply_moment2_normalization( + orth_grad=orth_grad, moment2=state["moment2_buffer"], beta2=group["beta2"], eps=group["eps"], @@ -233,6 +233,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # perform weight update # scale is applied to have update RMS == 1 - p.add_(grad, alpha=-group["lr"]) + p.add_(update, alpha=-group["lr"]) return loss From ed365d9a6cbbc0d297b70e78b07f43c7db30d8c7 Mon Sep 17 00:00:00 2001 From: mikail Date: Mon, 17 Nov 2025 21:06:48 -0800 Subject: [PATCH 22/32] addressed MR comment Signed-off-by: mikail --- emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index 17a4c33..ad3451d 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -73,8 +73,6 @@ def __init__( extra_scale_factor: float = 1.0, use_syrk: bool = False, ): - self.moment2_method = moment2_method - super().__init__( params, lr=lr, @@ -89,6 +87,7 @@ def __init__( extra_scale_factor=extra_scale_factor, use_syrk=use_syrk, ) + self.moment2_method = moment2_method for group in self.param_groups: group.setdefault("beta2", beta2) From 60c690310f193ba75671a4063950d03d599e8c33 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 12:19:49 -0800 Subject: [PATCH 23/32] addressed MR comments Signed-off-by: mikail --- 
.../orthogonalized_optimizers/adaptive_muon.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index ad3451d..08cf175 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -42,9 +42,6 @@ class AdaptiveMuon(Muon): momentum_beta: The exponential decay rate for momentum. weight_decay: Weight decay coefficient. use_nesterov: Whether to use Nesterov momentum. - moment2_method: Method for second moment accumulation ("adamuon" or "normuon"). - beta2: The exponential decay rate for second moment. - eps: Small constant for numerical stability. weight_decay_method: The weight decay method to use. fp32_matmul_prec: Precision for FP32 matrix multiplication. coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. @@ -52,6 +49,9 @@ class AdaptiveMuon(Muon): scale_mode: The type of scale factor to use for the update. extra_scale_factor: The additional scale factor to use for the update. use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. + moment2_method: Method for second moment accumulation ("adamuon" or "normuon"). + beta2: The exponential decay rate for second moment. + eps: Small constant for numerical stability. """ def __init__( @@ -62,9 +62,6 @@ def __init__( weight_decay: float, *, use_nesterov: bool, - moment2_method: Literal["adamuon", "normuon"], - beta2: float = 0.999, - eps: float = 1e-8, weight_decay_method: opt_mixin.WeightDecayT = "decoupled", fp32_matmul_prec: str, coefficient_type: str = "quintic", @@ -72,6 +69,9 @@ def __init__( scale_mode: str = "spectral", extra_scale_factor: float = 1.0, use_syrk: bool = False, + moment2_method: Literal["adamuon", "normuon"] = "adamuon", + beta2: float = 0.999, + eps: float = 1e-8, ): super().__init__( params, @@ -192,8 +192,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: for group in self.param_groups: for p in group["params"]: - if p.dim() == 1: - raise ValueError("AdaptiveMuon does not support 1D parameters") + if p.dim() != 2: + raise ValueError("AdaptiveMuon only supports 2D parameters") grad = p.grad if grad is None: continue From b3bb0f3fe6d6c6612f9f0a183faa9e3dad952591 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 12:20:50 -0800 Subject: [PATCH 24/32] addressed MR comments Signed-off-by: mikail --- tests/test_adaptive_orthogonalized_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_adaptive_orthogonalized_optimizer.py b/tests/test_adaptive_orthogonalized_optimizer.py index a4a7bdc..9d2a104 100644 --- a/tests/test_adaptive_orthogonalized_optimizer.py +++ b/tests/test_adaptive_orthogonalized_optimizer.py @@ -22,7 +22,6 @@ ) -# Define command line flags flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") FLAGS = flags.FLAGS From c26805df8d63bb1fb9288eac6ae6fb32e9ecfccb Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 12:21:53 -0800 Subject: [PATCH 25/32] added adaptive_muon test to CI Signed-off-by: mikail --- tests/ci/L0_Tests_GPU.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh index bd28ef8..64ede36 100644 --- a/tests/ci/L0_Tests_GPU.sh +++ b/tests/ci/L0_Tests_GPU.sh @@ -17,6 +17,7 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0 error=0 coverage run -p 
--source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_adaptive_orthogonalized_optimizer.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1 From 4f5d01b99f91ef94fff07cd625fe45e35ee7c3d6 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 12:22:48 -0800 Subject: [PATCH 26/32] changed name to adaptive_muon Signed-off-by: mikail --- tests/ci/L0_Tests_GPU.sh | 2 +- tests/ci/L1_Tests_GPU.sh | 1 + ...aptive_orthogonalized_optimizer.py => test_adaptive_muon.py} | 0 3 files changed, 2 insertions(+), 1 deletion(-) rename tests/{test_adaptive_orthogonalized_optimizer.py => test_adaptive_muon.py} (100%) diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh index 64ede36..c131692 100644 --- a/tests/ci/L0_Tests_GPU.sh +++ b/tests/ci/L0_Tests_GPU.sh @@ -17,7 +17,7 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0 error=0 coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1 -coverage run -p --source=emerging_optimizers tests/test_adaptive_orthogonalized_optimizer.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_adaptive_muon.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1 coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1 diff --git a/tests/ci/L1_Tests_GPU.sh b/tests/ci/L1_Tests_GPU.sh index c07ac75..19a75cf 100644 --- a/tests/ci/L1_Tests_GPU.sh +++ b/tests/ci/L1_Tests_GPU.sh @@ -16,6 +16,7 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0 error=0 python tests/test_muon_utils.py || error=1 +python tests/test_adaptive_muon.py || error=1 python tests/test_orthogonalized_optimizer.py || error=1 python tests/test_soap_utils.py || error=1 python tests/test_soap.py || error=1 diff --git a/tests/test_adaptive_orthogonalized_optimizer.py b/tests/test_adaptive_muon.py similarity index 100% rename from tests/test_adaptive_orthogonalized_optimizer.py rename to tests/test_adaptive_muon.py From dd0f28225c01267c46911c6f87aecf9ca18f5a43 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 12:59:00 -0800 Subject: [PATCH 27/32] changed b2 default to 0.95 Signed-off-by: mikail --- emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index 08cf175..eb9e04b 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -70,7 +70,7 @@ def __init__( extra_scale_factor: float = 1.0, use_syrk: bool = False, moment2_method: Literal["adamuon", "normuon"] = "adamuon", - beta2: float = 0.999, + beta2: float = 0.95, eps: float = 1e-8, ): super().__init__( From e4d70c9439369260675b152e6b71dce9fd77fc3e Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 14:14:39 -0800 Subject: [PATCH 28/32] added ref Signed-off-by: mikail --- .../orthogonalized_optimizers/adaptive_muon.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index eb9e04b..82ddd09 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -34,7 +34,9 @@ class AdaptiveMuon(Muon): This class extends Muon by adding AdamW-style or NorMuon-style second moment accumulation after orthogonalization. The step() method is overridden to include second moment - normalization logic. + normalization logic. This idea was first explored in D.E. Carlson, E. Collins, Ya-Ping Hsieh, L. Carin, + and V. Cevher. *Preconditioned spectral descent for deep learning.* In Advances in + neural information processing systems 28 (2015). Args: params: Iterable of parameters to optimize or dicts defining parameter groups. From 6d963456bec0a52d47c5070ae19b088b5b7cc37e Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 14:24:27 -0800 Subject: [PATCH 29/32] Revert "added ref" This reverts commit 66f9196e1857e8de4e74091e201283d8b573807c. Signed-off-by: mikail --- .../orthogonalized_optimizers/adaptive_muon.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index 82ddd09..eb9e04b 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -34,9 +34,7 @@ class AdaptiveMuon(Muon): This class extends Muon by adding AdamW-style or NorMuon-style second moment accumulation after orthogonalization. The step() method is overridden to include second moment - normalization logic. This idea was first explored in D.E. Carlson, E. Collins, Ya-Ping Hsieh, L. Carin, - and V. Cevher. *Preconditioned spectral descent for deep learning.* In Advances in - neural information processing systems 28 (2015). + normalization logic. Args: params: Iterable of parameters to optimize or dicts defining parameter groups. From cc72e34dea3f44a5f4db8f56afafb526db559890 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 14:27:37 -0800 Subject: [PATCH 30/32] added ref Signed-off-by: mikail --- .../orthogonalized_optimizers/adaptive_muon.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py index eb9e04b..f861213 100644 --- a/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py @@ -33,8 +33,10 @@ class AdaptiveMuon(Muon): """Adaptive Muon optimizer with adaptive second moment (AdaMuon/NorMuon variants). This class extends Muon by adding AdamW-style or NorMuon-style second moment - accumulation after orthogonalization. The step() method is overridden to include second moment - normalization logic. + accumulation after orthogonalization. This idea was first explored in D.E. Carlson, + E. Collins, Ya-Ping Hsieh, L. Carin, and V. Cevher. *Preconditioned spectral + descent for deep learning.* In Advances in neural information processing systems 28 (2015). + The step() method is overridden to include second moment normalization logic. Args: params: Iterable of parameters to optimize or dicts defining parameter groups. 
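For orientation, the "adamuon" and "normuon" behaviours described in the docstring above reduce to the following standalone sketch of the post-orthogonalization scaling. This is plain PyTorch for illustration only: the helper name, its signature, and the buffer handling are assumptions, not the library's `_apply_moment2_normalization` API.

import torch


def moment2_scale(
    orth_grad: torch.Tensor,
    buf: torch.Tensor,
    method: str = "adamuon",
    beta2: float = 0.95,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Scale an orthogonalized gradient by an EMA of its squared values; buf is updated in place."""
    if method == "adamuon":
        # Elementwise EMA of the squared orthogonalized gradient, then AdamW-style division.
        buf.lerp_(orth_grad.square(), 1 - beta2)
        return orth_grad / (buf.sqrt() + eps)
    if method == "normuon":
        # Average the squared entries along one matrix dimension (mirroring the shape rule
        # in the diffs) so the buffer stores a row- or column-wise scale.
        avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2
        buf.lerp_(orth_grad.square().mean(dim=avg_dim, keepdim=True), 1 - beta2)
        return orth_grad * buf.clamp_min(eps).rsqrt()
    raise ValueError(f"Invalid second moment method: {method}")


# For an (8, 16) gradient, "adamuon" keeps a full (8, 16) buffer while "normuon" keeps a (1, 16) one.
g = torch.randn(8, 16)
update_adamuon = moment2_scale(g, torch.zeros_like(g), "adamuon")
update_normuon = moment2_scale(g, torch.zeros(1, 16), "normuon")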
From cf0a2751dfbd1a06839a8467ab3b24c6faaede21 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 14:50:44 -0800 Subject: [PATCH 31/32] fixed test Signed-off-by: mikail --- tests/test_adaptive_muon.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_adaptive_muon.py b/tests/test_adaptive_muon.py index 9d2a104..8378362 100644 --- a/tests/test_adaptive_muon.py +++ b/tests/test_adaptive_muon.py @@ -26,8 +26,6 @@ FLAGS = flags.FLAGS -device = FLAGS.device - class AdaptiveMuonTest(parameterized.TestCase): @parameterized.product( @@ -37,7 +35,7 @@ class AdaptiveMuonTest(parameterized.TestCase): ) def test_smoke(self, shape, second_moment_method, use_nesterov) -> None: """Smoke test AdaptiveMuon with both second moment methods.""" - test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=device)) + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=FLAGS.device)) test_param.grad = torch.randint_like(test_param, -5, 5) adaptive_opt = AdaptiveMuon( @@ -60,7+58,7 @@ def test_smoke(self, shape, second_moment_method, use_nesterov) -> None: ) def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None: """Test that second moment buffers are properly initialized.""" - test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=device)) + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=FLAGS.device)) test_param.grad = torch.randint_like(test_param, -5, 5) adaptive_opt = AdaptiveMuon( @@ -98,7 +96,7 @@ def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None def test_unknown_moment2_method_raise_type_error(self) -> None: """Test that AdaptiveMuon raises TypeError for unknown moment2_method.""" - test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device=device)) + test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device=FLAGS.device)) with self.assertRaises(TypeError): AdaptiveMuon( From a677a3adadbed4d568dbc656704e6627fb860217 Mon Sep 17 00:00:00 2001 From: mikail Date: Tue, 18 Nov 2025 14:52:45 -0800 Subject: [PATCH 32/32] raised TypeError in the correct place Signed-off-by: mikail --- tests/test_adaptive_muon.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/test_adaptive_muon.py b/tests/test_adaptive_muon.py index 8378362..5134b31 100644 --- a/tests/test_adaptive_muon.py +++ b/tests/test_adaptive_muon.py @@ -97,20 +97,24 @@ def test_second_moment_matches_shapes(self, shape, second_moment_method) -> None def test_unknown_moment2_method_raise_type_error(self) -> None: """Test that AdaptiveMuon raises TypeError for unknown moment2_method.""" test_param = nn.Parameter(torch.randint(-5, 5, (8, 16), dtype=torch.float32, device=FLAGS.device)) + test_param.grad = torch.randint_like(test_param, -5, 5) + + adaptive_opt = AdaptiveMuon( + [test_param], + lr=0.01, + momentum_beta=0.9, + weight_decay=0.0, + use_nesterov=False, + moment2_method=None, + beta2=0.999, + eps=1e-8, + weight_decay_method="decoupled", + fp32_matmul_prec="highest", + ) + # TypeError is raised during step() when initializing moment2_buffer with self.assertRaises(TypeError): - AdaptiveMuon( - [test_param], - lr=0.01, - momentum_beta=0.9, - weight_decay=0.0, - use_nesterov=False, - moment2_method=None, - beta2=0.999, - eps=1e-8, - weight_decay_method="decoupled", - fp32_matmul_prec="highest", - ) + adaptive_opt.step() if __name__ == "__main__":