|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +from typing import Callable, Literal |
| 16 | + |
| 17 | + |
| 18 | +# TODO(@boxiangw): remove this once bump to python 3.12 |
| 19 | +try: |
| 20 | + from typing import override |
| 21 | +except ImportError: |
| 22 | + from typing_extensions import override |
| 23 | + |
| 24 | +import torch |
| 25 | +from torch.optim.optimizer import ParamsT |
| 26 | + |
| 27 | +from emerging_optimizers import mixin as opt_mixin |
| 28 | +from emerging_optimizers import utils |
| 29 | +from emerging_optimizers.orthogonalized_optimizers.muon import Muon |
| 30 | + |
| 31 | + |
| 32 | +class AdaptiveMuon(Muon): |
| 33 | + """Adaptive Muon optimizer with adaptive second moment (AdaMuon/NorMuon variants). |
| 34 | +
|
| 35 | + This class extends Muon by adding AdamW-style or NorMuon-style second moment |
| 36 | + accumulation after orthogonalization. This idea was first explored in D.E. Carlson, |
| 37 | + E. Collins, Ya-Ping Hsieh, L. Carin, and V. Cevher. *Preconditioned spectral |
| 38 | + descent for deep learning.* In Advances in neural information processing systems 28 (2015). |
| 39 | + The step() method is overridden to include second moment normalization logic. |
| 40 | +
|
| 41 | + Args: |
| 42 | + params: Iterable of parameters to optimize or dicts defining parameter groups. |
| 43 | + lr: Learning rate. |
| 44 | + momentum_beta: The exponential decay rate for momentum. |
| 45 | + weight_decay: Weight decay coefficient. |
| 46 | + use_nesterov: Whether to use Nesterov momentum. |
| 47 | + weight_decay_method: The weight decay method to use. |
| 48 | + fp32_matmul_prec: Precision for FP32 matrix multiplication. |
| 49 | + coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. |
| 50 | + num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. |
| 51 | + scale_mode: The type of scale factor to use for the update. |
| 52 | + extra_scale_factor: The additional scale factor to use for the update. |
| 53 | + use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. |
| 54 | + moment2_method: Method for second moment accumulation ("adamuon" or "normuon"). |
| 55 | + beta2: The exponential decay rate for second moment. |
| 56 | + eps: Small constant for numerical stability. |
| 57 | + """ |
| 58 | + |
| 59 | + def __init__( |
| 60 | + self, |
| 61 | + params: ParamsT, |
| 62 | + lr: float, |
| 63 | + momentum_beta: float, |
| 64 | + weight_decay: float, |
| 65 | + *, |
| 66 | + use_nesterov: bool, |
| 67 | + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", |
| 68 | + fp32_matmul_prec: str, |
| 69 | + coefficient_type: str = "quintic", |
| 70 | + num_ns_steps: int = 5, |
| 71 | + scale_mode: str = "spectral", |
| 72 | + extra_scale_factor: float = 1.0, |
| 73 | + use_syrk: bool = False, |
| 74 | + moment2_method: Literal["adamuon", "normuon"] = "adamuon", |
| 75 | + beta2: float = 0.95, |
| 76 | + eps: float = 1e-8, |
| 77 | + ): |
| 78 | + super().__init__( |
| 79 | + params, |
| 80 | + lr=lr, |
| 81 | + momentum_beta=momentum_beta, |
| 82 | + weight_decay=weight_decay, |
| 83 | + use_nesterov=use_nesterov, |
| 84 | + weight_decay_method=weight_decay_method, |
| 85 | + fp32_matmul_prec=fp32_matmul_prec, |
| 86 | + coefficient_type=coefficient_type, |
| 87 | + num_ns_steps=num_ns_steps, |
| 88 | + scale_mode=scale_mode, |
| 89 | + extra_scale_factor=extra_scale_factor, |
| 90 | + use_syrk=use_syrk, |
| 91 | + ) |
| 92 | + self.moment2_method = moment2_method |
| 93 | + |
| 94 | + for group in self.param_groups: |
| 95 | + group.setdefault("beta2", beta2) |
| 96 | + group.setdefault("eps", eps) |
| 97 | + |
| 98 | + def _initialize_moment2( |
| 99 | + self, |
| 100 | + state: dict[str, torch.Tensor], |
| 101 | + grad: torch.Tensor, |
| 102 | + ) -> None: |
| 103 | + """Initialize the second moment buffer if it doesn't exist. |
| 104 | +
|
| 105 | + The shape of the buffer depends on the moment2_method: |
| 106 | + - "adamuon": Full elementwise buffer with same shape as grad |
| 107 | + - "normuon": Reduced shape buffer (averaged along -1 if shape[-2] >= shape[-1], else -2) |
| 108 | +
|
| 109 | + Args: |
| 110 | + state: The optimizer state dict for a parameter. |
| 111 | + grad: The gradient tensor (used for shape/dtype). |
| 112 | + """ |
| 113 | + if "moment2_buffer" not in state: |
| 114 | + if self.moment2_method == "adamuon": |
| 115 | + # Full elementwise second moment |
| 116 | + moment2 = torch.zeros_like(grad) |
| 117 | + elif self.moment2_method == "normuon": |
| 118 | + # Row/column-wise second moment - reduced along one dimension |
| 119 | + # Determine which dimension to reduce based on parameter shape |
| 120 | + avg_dim = -1 if grad.shape[-2] >= grad.shape[-1] else -2 |
| 121 | + # Specify the shape with reduced dimension |
| 122 | + moment2_shape = list(grad.shape) |
| 123 | + moment2_shape[avg_dim] = 1 |
| 124 | + moment2 = torch.zeros(moment2_shape, dtype=grad.dtype, device=grad.device) |
| 125 | + else: |
| 126 | + raise TypeError(f"Invalid second moment method: {self.moment2_method}") |
| 127 | + |
| 128 | + state["moment2_buffer"] = moment2 |
| 129 | + |
| 130 | + def _apply_moment2_normalization( |
| 131 | + self, |
| 132 | + orth_grad: torch.Tensor, |
| 133 | + moment2: torch.Tensor, |
| 134 | + beta2: float, |
| 135 | + eps: float, |
| 136 | + ) -> torch.Tensor: |
| 137 | + """Apply AdamW-style second moment accumulation and normalization. |
| 138 | +
|
| 139 | + This method supports two variants: |
| 140 | + - "adamuon": Full elementwise second moment (like AdamW, https://arxiv.org/abs/2507.11005) |
| 141 | + - "normuon": Row or column-wise second moment (https://arxiv.org/abs/2510.05491) |
| 142 | +
|
| 143 | + For both methods: |
| 144 | + 1. Updates the second moment as an EMA of squared gradients |
| 145 | + 2. Returns the adaptively scaled gradient |
| 146 | +
|
| 147 | + Args: |
| 148 | + orth_grad: The orthogonalized gradient tensor. |
| 149 | + moment2: The second moment buffer from state. |
| 150 | + beta2: The exponential decay rate for second moment. |
| 151 | + eps: Small constant for numerical stability. |
| 152 | +
|
| 153 | + Returns: |
| 154 | + The adaptively scaled weight update tensor. |
| 155 | + """ |
| 156 | + if self.moment2_method == "adamuon": |
| 157 | + # AdamMuon: Full elementwise second moment like AdamW |
| 158 | + # Update second moment with EMA of squared orthogonalized gradient |
| 159 | + moment2.lerp_(orth_grad.square(), 1 - beta2) |
| 160 | + |
| 161 | + # AdamW-style division: grad / (sqrt(moment2) + eps) |
| 162 | + denom = moment2.sqrt() + eps |
| 163 | + return orth_grad / denom |
| 164 | + |
| 165 | + elif self.moment2_method == "normuon": |
| 166 | + # NorMuon: Row or column-wise second moment |
| 167 | + # Compute mean of squared gradients along one dimension based on shape |
| 168 | + # Average along the longer dimension to preserve structure along shorter dim |
| 169 | + avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2 |
| 170 | + v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True) |
| 171 | + |
| 172 | + # Update second moment with EMA |
| 173 | + moment2.lerp_(v_mean, 1 - beta2) |
| 174 | + |
| 175 | + # NorMuon uses reciprocal square root with clamping |
| 176 | + step_size = moment2.clamp_min(eps).rsqrt_() |
| 177 | + return orth_grad * step_size |
| 178 | + |
| 179 | + else: |
| 180 | + raise TypeError(f"Invalid second moment method: {self.moment2_method}") |
| 181 | + |
| 182 | + @torch.no_grad() # type: ignore[misc] |
| 183 | + @override |
| 184 | + def step(self, closure: Callable[[], float] | None = None) -> float | None: |
| 185 | + """Single optimization step. |
| 186 | +
|
| 187 | + Args: |
| 188 | + closure: A closure that reevaluates the model and returns the loss. |
| 189 | + """ |
| 190 | + if closure is None: |
| 191 | + loss = None |
| 192 | + else: |
| 193 | + loss = closure() |
| 194 | + |
| 195 | + for group in self.param_groups: |
| 196 | + for p in group["params"]: |
| 197 | + if p.dim() != 2: |
| 198 | + raise ValueError("AdaptiveMuon only supports 2D parameters") |
| 199 | + grad = p.grad |
| 200 | + if grad is None: |
| 201 | + continue |
| 202 | + state = self.state[p] |
| 203 | + |
| 204 | + if "momentum_buffer" not in state: |
| 205 | + state["momentum_buffer"] = torch.zeros_like(grad) |
| 206 | + self._initialize_moment2(state, grad) |
| 207 | + |
| 208 | + exp_avg = state["momentum_buffer"] |
| 209 | + |
| 210 | + self._apply_weight_decay_inplace( |
| 211 | + p, |
| 212 | + grad, |
| 213 | + group["lr"], |
| 214 | + group["weight_decay"], |
| 215 | + ) |
| 216 | + |
| 217 | + # update momentum buffer with EMA of gradient |
| 218 | + exp_avg.lerp_(grad, 1 - group["momentum_beta"]) |
| 219 | + |
| 220 | + if self.use_nesterov: |
| 221 | + grad = grad.lerp(exp_avg, group["momentum_beta"]) |
| 222 | + else: |
| 223 | + grad = exp_avg |
| 224 | + |
| 225 | + with utils.fp32_matmul_precision(self.fp32_matmul_prec): |
| 226 | + orth_grad = self.scaled_orthogonalize_fn(grad) |
| 227 | + |
| 228 | + update = self._apply_moment2_normalization( |
| 229 | + orth_grad=orth_grad, |
| 230 | + moment2=state["moment2_buffer"], |
| 231 | + beta2=group["beta2"], |
| 232 | + eps=group["eps"], |
| 233 | + ) |
| 234 | + |
| 235 | + # perform weight update |
| 236 | + # scale is applied to have update RMS == 1 |
| 237 | + p.add_(update, alpha=-group["lr"]) |
| 238 | + |
| 239 | + return loss |
0 commit comments