|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +from typing import Callable |
| 16 | + |
| 17 | +import torch |
| 18 | +from torch.optim.optimizer import Optimizer |
| 19 | + |
| 20 | + |
| 21 | +class ObliqueSGD(Optimizer): |
| 22 | + """SGD optimizer for row- or column-normalized 2D parameters on oblique manifolds. |
| 23 | +
|
| 24 | + This optimizer performs SGD on oblique manifolds, where parameters are constrained |
| 25 | + to have unit-norm rows or columns. It implements Riemannian SGD with manifold-aware |
| 26 | + gradient updates and retraction operations. |
| 27 | +
|
| 28 | + References: |
| 29 | + - An Introduction to Optimization on Smooth Manifolds (Nicolas Boumal) |
| 30 | + - EDM2: https://arxiv.org/abs/2312.02696 |
| 31 | + - Jianlin Su: https://kexue.fm/archives/11196 |
| 32 | + - Raman et al.: https://arxiv.org/abs/1909.06463 |
| 33 | + - Franz Cesista: https://leloykun.github.io/ponder/steepest-descent-stiefel/#6-bonus-a-muon-like-optimizer-for-the-embedding-and-unembedding-layers |
| 34 | +
|
| 35 | + Args: |
| 36 | + lr: learning rate |
| 37 | + momentum: momentum coefficient |
| 38 | + weight_decay: weight decay coefficient |
| 39 | + dim: The dimension to normalize over |
| 40 | + eps: epsilon for numerical stability |
| 41 | + """ |
| 42 | + |
| 43 | + def __init__( |
| 44 | + self, |
| 45 | + params: list[torch.nn.Parameter], |
| 46 | + lr: float = 1e-3, |
| 47 | + momentum: float = 0.9, |
| 48 | + weight_decay: float = 0.0, |
| 49 | + dim: int = 0, |
| 50 | + eps: float = 1e-8, |
| 51 | + ) -> None: |
| 52 | + if lr < 0.0: |
| 53 | + raise ValueError(f"Invalid learning rate: {lr}") |
| 54 | + if momentum < 0.0 or momentum >= 1.0: |
| 55 | + raise ValueError(f"Invalid momentum value: {momentum}") |
| 56 | + if weight_decay < 0.0: |
| 57 | + raise ValueError(f"Invalid weight_decay value: {weight_decay}") |
| 58 | + |
| 59 | + defaults = dict( |
| 60 | + lr=lr, |
| 61 | + momentum=momentum, |
| 62 | + weight_decay=weight_decay, |
| 63 | + dim=dim, |
| 64 | + eps=eps, |
| 65 | + ) |
| 66 | + super().__init__(params, defaults) |
| 67 | + |
| 68 | + @torch.no_grad() # type: ignore[misc] |
| 69 | + def step(self, closure: Callable[[], float] | None = None) -> float | None: |
| 70 | + """Performs a single optimization step. |
| 71 | + Args: |
| 72 | + closure (callable, optional): A closure that reevaluates the model |
| 73 | + and returns the loss. |
| 74 | + """ |
| 75 | + loss = closure() if closure is not None else None |
| 76 | + |
| 77 | + for group in self.param_groups: |
| 78 | + lr = group["lr"] |
| 79 | + mom = group["momentum"] |
| 80 | + wd = group["weight_decay"] |
| 81 | + dim = group["dim"] |
| 82 | + eps = group["eps"] |
| 83 | + |
| 84 | + for param in group["params"]: |
| 85 | + if param.grad is None: |
| 86 | + continue |
| 87 | + if param.ndim != 2: |
| 88 | + raise ValueError("ObliqueSGD only supports 2D parameters") |
| 89 | + grad = param.grad |
| 90 | + |
| 91 | + # Initialize momentum buffer if needed |
| 92 | + state = self.state[param] |
| 93 | + if "momentum_buffer" not in state: |
| 94 | + state["momentum_buffer"] = torch.zeros_like(param) |
| 95 | + |
| 96 | + buf = state["momentum_buffer"] |
| 97 | + |
| 98 | + # theory style momentum |
| 99 | + buf = torch.add(grad, buf, alpha=mom) |
| 100 | + |
| 101 | + # Apply Riemannian gradient update |
| 102 | + _compute_riemannian_grad_and_update(param, buf, dim, lr, wd) |
| 103 | + |
| 104 | + # Retraction back to the manifold, the hyper-sphere |
| 105 | + torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param) |
| 106 | + |
| 107 | + return loss |
| 108 | + |
| 109 | + |
| 110 | +class ObliqueAdam(Optimizer): |
| 111 | + """Adam optimizer for row- or column-normalized 2D parameters on oblique manifolds. |
| 112 | +
|
| 113 | + This optimizer adapts an Adam-like algorithm to work on oblique manifolds, where |
| 114 | + parameters are constrained to have unit-norm rows or columns. It combines |
| 115 | + adaptive momentum estimation with Riemannian gradient computation and manifold retraction. |
| 116 | + """ |
| 117 | + |
| 118 | + def __init__( |
| 119 | + self, |
| 120 | + params: list[torch.nn.Parameter], |
| 121 | + lr: float = 1e-3, |
| 122 | + betas: tuple[float, float] = (0.9, 0.99), |
| 123 | + weight_decay: float = 0.0, |
| 124 | + dim: int = 0, |
| 125 | + eps: float = 1e-8, |
| 126 | + correct_bias: bool = True, |
| 127 | + ) -> None: |
| 128 | + """An Adam-like optimizer for Normalized 2d Parameters |
| 129 | +
|
| 130 | + Args: |
| 131 | + lr: The learning rate. |
| 132 | + betas: The coefficients used for computing running averages of gradient and its square. |
| 133 | + weight_decay: The weight decay coefficient. |
| 134 | + dim: The dimension to normalize over. |
| 135 | + eps: The epsilon for numerical stability. |
| 136 | + correct_bias: Whether to correct bias in Adam-like computation. |
| 137 | + """ |
| 138 | + if lr < 0.0: |
| 139 | + raise ValueError(f"Invalid learning rate: {lr}") |
| 140 | + if betas[0] < 0.0 or betas[0] >= 1.0: |
| 141 | + raise ValueError(f"Invalid beta1 value: {betas[0]}") |
| 142 | + if betas[1] < 0.0 or betas[1] >= 1.0: |
| 143 | + raise ValueError(f"Invalid beta2 value: {betas[1]}") |
| 144 | + if weight_decay < 0.0: |
| 145 | + raise ValueError(f"Invalid weight_decay value: {weight_decay}") |
| 146 | + |
| 147 | + defaults = dict( |
| 148 | + lr=lr, |
| 149 | + betas=betas, |
| 150 | + weight_decay=weight_decay, |
| 151 | + dim=dim, |
| 152 | + eps=eps, |
| 153 | + correct_bias=correct_bias, |
| 154 | + ) |
| 155 | + super().__init__(params, defaults) |
| 156 | + |
| 157 | + @torch.no_grad() # type: ignore[misc] |
| 158 | + def step(self, closure: Callable[[], float] | None = None) -> float | None: |
| 159 | + """Performs a single optimization step. |
| 160 | + Args: |
| 161 | + closure (callable, optional): A closure that reevaluates the model |
| 162 | + and returns the loss. |
| 163 | + """ |
| 164 | + loss = closure() if closure is not None else None |
| 165 | + |
| 166 | + for group in self.param_groups: |
| 167 | + lr = group["lr"] |
| 168 | + betas = group["betas"] |
| 169 | + wd = group["weight_decay"] |
| 170 | + dim = group["dim"] |
| 171 | + eps = group["eps"] |
| 172 | + correct_bias = group["correct_bias"] |
| 173 | + |
| 174 | + for param in group["params"]: |
| 175 | + if param.grad is None: |
| 176 | + continue |
| 177 | + if param.ndim != 2: |
| 178 | + raise ValueError("ObliqueAdam only supports 2D parameters") |
| 179 | + |
| 180 | + state = self.state[param] |
| 181 | + if "step" not in state: |
| 182 | + state["step"] = 0 |
| 183 | + |
| 184 | + grad = param.grad |
| 185 | + |
| 186 | + # Initialize momentum buffer if needed |
| 187 | + if "exp_avg" not in state: |
| 188 | + state["exp_avg"] = torch.zeros_like(param) |
| 189 | + if "exp_avg_sq" not in state: |
| 190 | + state["exp_avg_sq"] = torch.zeros_like(param) |
| 191 | + |
| 192 | + exp_avg = state["exp_avg"] |
| 193 | + exp_avg_sq = state["exp_avg_sq"] |
| 194 | + |
| 195 | + # Increment step counter |
| 196 | + state["step"] += 1 |
| 197 | + step = state["step"] |
| 198 | + |
| 199 | + # Update biased first and second moment estimates |
| 200 | + exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0]) |
| 201 | + exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1]) |
| 202 | + |
| 203 | + if correct_bias: |
| 204 | + # step size correction for ADAM moments EMA |
| 205 | + bias_correction1 = 1.0 - betas[0] ** step |
| 206 | + bias_correction2 = 1.0 - betas[1] ** step |
| 207 | + else: |
| 208 | + bias_correction1 = 1.0 |
| 209 | + bias_correction2 = 1.0 |
| 210 | + |
| 211 | + norm_grad = (exp_avg / bias_correction1) / (exp_avg_sq.sqrt() / bias_correction2 + eps) |
| 212 | + |
| 213 | + # Apply Riemannian gradient update |
| 214 | + _compute_riemannian_grad_and_update(param, norm_grad, dim, lr, wd) |
| 215 | + |
| 216 | + # Retraction back to the manifold, i.e. the hyper-sphere |
| 217 | + torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param) |
| 218 | + |
| 219 | + return loss |
| 220 | + |
| 221 | + |
| 222 | +def _compute_riemannian_grad_and_update( |
| 223 | + param: torch.Tensor, grad_like: torch.Tensor, dim: int, lr: float, wd: float |
| 224 | +) -> None: |
| 225 | + """Compute Riemannian gradient for oblique manifold and update parameter in-place. |
| 226 | +
|
| 227 | + Args: |
| 228 | + param: Parameter tensor (2D) |
| 229 | + grad_like: Gradient-like tensor (momentum buffer or normalized gradient) |
| 230 | + dim: The dimension to normalize over |
| 231 | + lr: Learning rate |
| 232 | + wd: Weight decay coefficient |
| 233 | + """ |
| 234 | + |
| 235 | + inner = (param * grad_like).sum(dim=dim, keepdim=True) |
| 236 | + riem_grad = torch.add(grad_like, param * inner, alpha=-1) |
| 237 | + |
| 238 | + # Add decoupled weight decay |
| 239 | + param.mul_(1 - lr * wd) |
| 240 | + |
| 241 | + # Apply update in-place |
| 242 | + param.add_(riem_grad, alpha=-lr) |
0 commit comments