-
Notifications
You must be signed in to change notification settings - Fork 10
Normalized Riemannian optimizer #36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
0a68eb3
added normalized optimizers and fixed docstrings and formatting
mkhona-nvidia 81d2a64
added tests for normalized optimizer
mkhona-nvidia b32a71e
separated out riemannian grad as a private function
mkhona-nvidia b13a403
added a simpek convergence test
mkhona-nvidia c7c48a5
improved accuracy of simple models
mkhona-nvidia c706f27
improved helper function by reducing code, docs
mkhona-nvidia 9af8a3a
changed to in-place updates
mkhona-nvidia 97db541
moved weight decay before update, previous case was a no-op after norm
mkhona-nvidia f9ce54c
removed overwrite msg from test
mkhona-nvidia 18f5a64
address PR comments, moved convergence to separate test
mkhona-nvidia 9a0b6c6
added convergence test as a separate test
mkhona-nvidia 64e74be
renamed from key for row/column to numerical dim in optimizer and all…
mkhona-nvidia ba80445
tensors are now not torch.Parameter, cleaned up test
mkhona-nvidia 0170de8
added test to L1
mkhona-nvidia 134e08b
added tests to L0
mkhona-nvidia 3ae3147
changed docstring
mkhona-nvidia f130aa6
changed to one torch.add for memory pressure
mkhona-nvidia 6859351
added missing type hints
mkhona-nvidia 017cc39
added missing types for optimizer args
mkhona-nvidia 1adfebb
added some more missing type hints
mkhona-nvidia ed90ff6
cleaned up test cases
mkhona-nvidia 2b4cb4e
fixed testing as per PR
mkhona-nvidia 5bc9a49
fixed inplace init
mkhona-nvidia 07c52bd
fixed formatting
mkhona-nvidia 344689f
added missing device movement
mkhona-nvidia 6eab498
moved model to device appropriately
mkhona-nvidia 5d7dcf2
added edm2 paper in docstring
mkhona-nvidia b8a9bdc
added explicit device flag to control device
mkhona-nvidia f97ef77
added flags to test
mkhona-nvidia File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
242 changes: 242 additions & 0 deletions
242
emerging_optimizers/riemannian_optimizers/normalized_optimizer.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,242 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from typing import Callable | ||
|
|
||
| import torch | ||
| from torch.optim.optimizer import Optimizer | ||
|
|
||
|
|
||
| class ObliqueSGD(Optimizer): | ||
| """SGD optimizer for row- or column-normalized 2D parameters on oblique manifolds. | ||
|
|
||
| This optimizer performs SGD on oblique manifolds, where parameters are constrained | ||
| to have unit-norm rows or columns. It implements Riemannian SGD with manifold-aware | ||
| gradient updates and retraction operations. | ||
|
|
||
| References: | ||
| - An Introduction to Optimization on Smooth Manifolds (Nicolas Boumal) | ||
| - EDM2: https://arxiv.org/abs/2312.02696 | ||
| - Jianlin Su: https://kexue.fm/archives/11196 | ||
| - Raman et al.: https://arxiv.org/abs/1909.06463 | ||
| - Franz Cesista: https://leloykun.github.io/ponder/steepest-descent-stiefel/#6-bonus-a-muon-like-optimizer-for-the-embedding-and-unembedding-layers | ||
|
|
||
| Args: | ||
| lr: learning rate | ||
| momentum: momentum coefficient | ||
| weight_decay: weight decay coefficient | ||
| dim: The dimension to normalize over | ||
| eps: epsilon for numerical stability | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| params: list[torch.nn.Parameter], | ||
| lr: float = 1e-3, | ||
| momentum: float = 0.9, | ||
| weight_decay: float = 0.0, | ||
| dim: int = 0, | ||
| eps: float = 1e-8, | ||
| ) -> None: | ||
| if lr < 0.0: | ||
| raise ValueError(f"Invalid learning rate: {lr}") | ||
| if momentum < 0.0 or momentum >= 1.0: | ||
| raise ValueError(f"Invalid momentum value: {momentum}") | ||
| if weight_decay < 0.0: | ||
| raise ValueError(f"Invalid weight_decay value: {weight_decay}") | ||
|
|
||
| defaults = dict( | ||
| lr=lr, | ||
| momentum=momentum, | ||
| weight_decay=weight_decay, | ||
| dim=dim, | ||
| eps=eps, | ||
| ) | ||
| super().__init__(params, defaults) | ||
|
|
||
| @torch.no_grad() # type: ignore[misc] | ||
| def step(self, closure: Callable[[], float] | None = None) -> float | None: | ||
| """Performs a single optimization step. | ||
| Args: | ||
| closure (callable, optional): A closure that reevaluates the model | ||
| and returns the loss. | ||
| """ | ||
| loss = closure() if closure is not None else None | ||
|
|
||
| for group in self.param_groups: | ||
| lr = group["lr"] | ||
| mom = group["momentum"] | ||
| wd = group["weight_decay"] | ||
| dim = group["dim"] | ||
| eps = group["eps"] | ||
|
|
||
| for param in group["params"]: | ||
| if param.grad is None: | ||
| continue | ||
| if param.ndim != 2: | ||
| raise ValueError("ObliqueSGD only supports 2D parameters") | ||
| grad = param.grad | ||
|
|
||
| # Initialize momentum buffer if needed | ||
| state = self.state[param] | ||
| if "momentum_buffer" not in state: | ||
| state["momentum_buffer"] = torch.zeros_like(param) | ||
|
|
||
| buf = state["momentum_buffer"] | ||
|
|
||
| # theory style momentum | ||
| buf = torch.add(grad, buf, alpha=mom) | ||
|
|
||
| # Apply Riemannian gradient update | ||
| _compute_riemannian_grad_and_update(param, buf, dim, lr, wd) | ||
|
|
||
| # Retraction back to the manifold, the hyper-sphere | ||
| torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param) | ||
|
|
||
| return loss | ||
|
|
||
|
|
||
| class ObliqueAdam(Optimizer): | ||
| """Adam optimizer for row- or column-normalized 2D parameters on oblique manifolds. | ||
|
|
||
| This optimizer adapts an Adam-like algorithm to work on oblique manifolds, where | ||
| parameters are constrained to have unit-norm rows or columns. It combines | ||
| adaptive momentum estimation with Riemannian gradient computation and manifold retraction. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
mkhona-nvidia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| params: list[torch.nn.Parameter], | ||
| lr: float = 1e-3, | ||
| betas: tuple[float, float] = (0.9, 0.99), | ||
| weight_decay: float = 0.0, | ||
| dim: int = 0, | ||
| eps: float = 1e-8, | ||
| correct_bias: bool = True, | ||
| ) -> None: | ||
| """An Adam-like optimizer for Normalized 2d Parameters | ||
|
|
||
| Args: | ||
| lr: The learning rate. | ||
| betas: The coefficients used for computing running averages of gradient and its square. | ||
| weight_decay: The weight decay coefficient. | ||
| dim: The dimension to normalize over. | ||
| eps: The epsilon for numerical stability. | ||
| correct_bias: Whether to correct bias in Adam-like computation. | ||
| """ | ||
| if lr < 0.0: | ||
| raise ValueError(f"Invalid learning rate: {lr}") | ||
| if betas[0] < 0.0 or betas[0] >= 1.0: | ||
| raise ValueError(f"Invalid beta1 value: {betas[0]}") | ||
| if betas[1] < 0.0 or betas[1] >= 1.0: | ||
| raise ValueError(f"Invalid beta2 value: {betas[1]}") | ||
| if weight_decay < 0.0: | ||
| raise ValueError(f"Invalid weight_decay value: {weight_decay}") | ||
|
|
||
| defaults = dict( | ||
| lr=lr, | ||
| betas=betas, | ||
| weight_decay=weight_decay, | ||
| dim=dim, | ||
| eps=eps, | ||
| correct_bias=correct_bias, | ||
| ) | ||
| super().__init__(params, defaults) | ||
|
|
||
| @torch.no_grad() # type: ignore[misc] | ||
| def step(self, closure: Callable[[], float] | None = None) -> float | None: | ||
| """Performs a single optimization step. | ||
| Args: | ||
| closure (callable, optional): A closure that reevaluates the model | ||
| and returns the loss. | ||
| """ | ||
| loss = closure() if closure is not None else None | ||
|
|
||
| for group in self.param_groups: | ||
| lr = group["lr"] | ||
| betas = group["betas"] | ||
| wd = group["weight_decay"] | ||
| dim = group["dim"] | ||
| eps = group["eps"] | ||
| correct_bias = group["correct_bias"] | ||
|
|
||
| for param in group["params"]: | ||
| if param.grad is None: | ||
| continue | ||
| if param.ndim != 2: | ||
| raise ValueError("ObliqueAdam only supports 2D parameters") | ||
|
|
||
| state = self.state[param] | ||
| if "step" not in state: | ||
| state["step"] = 0 | ||
|
|
||
| grad = param.grad | ||
|
|
||
| # Initialize momentum buffer if needed | ||
| if "exp_avg" not in state: | ||
| state["exp_avg"] = torch.zeros_like(param) | ||
| if "exp_avg_sq" not in state: | ||
| state["exp_avg_sq"] = torch.zeros_like(param) | ||
|
|
||
| exp_avg = state["exp_avg"] | ||
| exp_avg_sq = state["exp_avg_sq"] | ||
|
|
||
| # Increment step counter | ||
| state["step"] += 1 | ||
| step = state["step"] | ||
|
|
||
| # Update biased first and second moment estimates | ||
| exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0]) | ||
| exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1]) | ||
|
|
||
| if correct_bias: | ||
| # step size correction for ADAM moments EMA | ||
| bias_correction1 = 1.0 - betas[0] ** step | ||
| bias_correction2 = 1.0 - betas[1] ** step | ||
| else: | ||
| bias_correction1 = 1.0 | ||
| bias_correction2 = 1.0 | ||
|
|
||
| norm_grad = (exp_avg / bias_correction1) / (exp_avg_sq.sqrt() / bias_correction2 + eps) | ||
mkhona-nvidia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Apply Riemannian gradient update | ||
| _compute_riemannian_grad_and_update(param, norm_grad, dim, lr, wd) | ||
|
|
||
| # Retraction back to the manifold, i.e. the hyper-sphere | ||
| torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param) | ||
|
|
||
| return loss | ||
|
|
||
|
|
||
| def _compute_riemannian_grad_and_update( | ||
| param: torch.Tensor, grad_like: torch.Tensor, dim: int, lr: float, wd: float | ||
| ) -> None: | ||
| """Compute Riemannian gradient for oblique manifold and update parameter in-place. | ||
|
|
||
| Args: | ||
| param: Parameter tensor (2D) | ||
| grad_like: Gradient-like tensor (momentum buffer or normalized gradient) | ||
| dim: The dimension to normalize over | ||
| lr: Learning rate | ||
| wd: Weight decay coefficient | ||
| """ | ||
|
|
||
| inner = (param * grad_like).sum(dim=dim, keepdim=True) | ||
| riem_grad = torch.add(grad_like, param * inner, alpha=-1) | ||
|
|
||
| # Add decoupled weight decay | ||
| param.mul_(1 - lr * wd) | ||
|
|
||
| # Apply update in-place | ||
| param.add_(riem_grad, alpha=-lr) | ||
skyw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.