Merged
29 commits
0a68eb3
added normalized optimizers and fixed docstrings and formatting
mkhona-nvidia Oct 2, 2025
81d2a64
added tests for normalized optimizer
mkhona-nvidia Oct 2, 2025
b32a71e
separated out riemannian grad as a private function
mkhona-nvidia Oct 2, 2025
b13a403
added a simple convergence test
mkhona-nvidia Oct 2, 2025
c7c48a5
improved accuracy of simple models
mkhona-nvidia Oct 2, 2025
c706f27
improved helper function by reducing code, docs
mkhona-nvidia Oct 2, 2025
9af8a3a
changed to in-place updates
mkhona-nvidia Oct 2, 2025
97db541
moved weight decay before update, previous case was a no-op after norm
mkhona-nvidia Oct 2, 2025
f9ce54c
removed overwrite msg from test
mkhona-nvidia Oct 3, 2025
18f5a64
address PR comments, moved convergence to separate test
mkhona-nvidia Oct 3, 2025
9a0b6c6
added convergence test as a separate test
mkhona-nvidia Oct 3, 2025
64e74be
renamed from key for row/column to numerical dim in optimizer and all…
mkhona-nvidia Oct 3, 2025
ba80445
tensors are now not torch.Parameter, cleaned up test
mkhona-nvidia Oct 3, 2025
0170de8
added test to L1
mkhona-nvidia Oct 3, 2025
134e08b
added tests to L0
mkhona-nvidia Oct 3, 2025
3ae3147
changed docstring
mkhona-nvidia Oct 3, 2025
f130aa6
changed to one torch.add for memory pressure
mkhona-nvidia Oct 3, 2025
6859351
added missing type hints
mkhona-nvidia Oct 3, 2025
017cc39
added missing types for optimizer args
mkhona-nvidia Oct 3, 2025
1adfebb
added some more missing type hints
mkhona-nvidia Oct 3, 2025
ed90ff6
cleaned up test cases
mkhona-nvidia Oct 3, 2025
2b4cb4e
fixed testing as per PR
mkhona-nvidia Oct 4, 2025
5bc9a49
fixed inplace init
mkhona-nvidia Oct 4, 2025
07c52bd
fixed formatting
mkhona-nvidia Oct 6, 2025
344689f
added missing device movement
mkhona-nvidia Oct 7, 2025
6eab498
moved model to device appropriately
mkhona-nvidia Oct 7, 2025
5d7dcf2
added edm2 paper in docstring
mkhona-nvidia Oct 7, 2025
b8a9bdc
added explicit device flag to control device
mkhona-nvidia Oct 7, 2025
f97ef77
added flags to test
mkhona-nvidia Oct 7, 2025
242 changes: 242 additions & 0 deletions emerging_optimizers/riemannian_optimizers/normalized_optimizer.py
@@ -0,0 +1,242 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable

import torch
from torch.optim.optimizer import Optimizer


class ObliqueSGD(Optimizer):
"""SGD optimizer for row- or column-normalized 2D parameters on oblique manifolds.

This optimizer performs SGD on oblique manifolds, where parameters are constrained
to have unit-norm rows or columns. It implements Riemannian SGD with manifold-aware
gradient updates and retraction operations.

References:
- An Introduction to Optimization on Smooth Manifolds (Nicolas Boumal)
- EDM2: https://arxiv.org/abs/2312.02696
- Jianlin Su: https://kexue.fm/archives/11196
- Raman et al.: https://arxiv.org/abs/1909.06463
- Franz Cesista: https://leloykun.github.io/ponder/steepest-descent-stiefel/#6-bonus-a-muon-like-optimizer-for-the-embedding-and-unembedding-layers

Args:
lr: learning rate
momentum: momentum coefficient
weight_decay: weight decay coefficient
dim: The dimension to normalize over
eps: epsilon for numerical stability
"""

def __init__(
self,
params: list[torch.nn.Parameter],
lr: float = 1e-3,
momentum: float = 0.9,
weight_decay: float = 0.0,
dim: int = 0,
eps: float = 1e-8,
) -> None:
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0 or momentum >= 1.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")

defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
dim=dim,
eps=eps,
)
super().__init__(params, defaults)

@torch.no_grad() # type: ignore[misc]
def step(self, closure: Callable[[], float] | None = None) -> float | None:
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = closure() if closure is not None else None

for group in self.param_groups:
lr = group["lr"]
mom = group["momentum"]
wd = group["weight_decay"]
dim = group["dim"]
eps = group["eps"]

for param in group["params"]:
if param.grad is None:
continue
if param.ndim != 2:
raise ValueError("ObliqueSGD only supports 2D parameters")
grad = param.grad

# Initialize momentum buffer if needed
state = self.state[param]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(param)

buf = state["momentum_buffer"]

# theory style momentum: buf_t = grad + mom * buf_{t-1}
buf = torch.add(grad, buf, alpha=mom)
# Store the updated buffer back so momentum accumulates across steps
state["momentum_buffer"] = buf

# Apply Riemannian gradient update
_compute_riemannian_grad_and_update(param, buf, dim, lr, wd)

# Retraction back to the manifold, the hyper-sphere
torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param)

return loss


class ObliqueAdam(Optimizer):
"""Adam optimizer for row- or column-normalized 2D parameters on oblique manifolds.

This optimizer adapts an Adam-like algorithm to work on oblique manifolds, where
parameters are constrained to have unit-norm rows or columns. It combines
adaptive momentum estimation with Riemannian gradient computation and manifold retraction.
"""

def __init__(
self,
params: list[torch.nn.Parameter],
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
dim: int = 0,
eps: float = 1e-8,
correct_bias: bool = True,
) -> None:
"""An Adam-like optimizer for Normalized 2d Parameters

Args:
lr: The learning rate.
betas: The coefficients used for computing running averages of gradient and its square.
weight_decay: The weight decay coefficient.
dim: The dimension to normalize over.
eps: The epsilon for numerical stability.
correct_bias: Whether to correct bias in Adam-like computation.
"""
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if betas[0] < 0.0 or betas[0] >= 1.0:
raise ValueError(f"Invalid beta1 value: {betas[0]}")
if betas[1] < 0.0 or betas[1] >= 1.0:
raise ValueError(f"Invalid beta2 value: {betas[1]}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")

defaults = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay,
dim=dim,
eps=eps,
correct_bias=correct_bias,
)
super().__init__(params, defaults)

@torch.no_grad() # type: ignore[misc]
def step(self, closure: Callable[[], float] | None = None) -> float | None:
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = closure() if closure is not None else None

for group in self.param_groups:
lr = group["lr"]
betas = group["betas"]
wd = group["weight_decay"]
dim = group["dim"]
eps = group["eps"]
correct_bias = group["correct_bias"]

for param in group["params"]:
if param.grad is None:
continue
if param.ndim != 2:
raise ValueError("ObliqueAdam only supports 2D parameters")

state = self.state[param]
if "step" not in state:
state["step"] = 0

grad = param.grad

# Initialize momentum buffer if needed
if "exp_avg" not in state:
state["exp_avg"] = torch.zeros_like(param)
if "exp_avg_sq" not in state:
state["exp_avg_sq"] = torch.zeros_like(param)

exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]

# Increment step counter
state["step"] += 1
step = state["step"]

# Update biased first and second moment estimates
exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])

if correct_bias:
# step size correction for ADAM moments EMA
bias_correction1 = 1.0 - betas[0] ** step
bias_correction2 = 1.0 - betas[1] ** step
else:
bias_correction1 = 1.0
bias_correction2 = 1.0

# Adam-style direction; the second-moment bias correction enters under the square root
norm_grad = (exp_avg / bias_correction1) / (exp_avg_sq.sqrt() / bias_correction2**0.5 + eps)

# Apply Riemannian gradient update
_compute_riemannian_grad_and_update(param, norm_grad, dim, lr, wd)

# Retraction back to the manifold, i.e. the hyper-sphere
torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param)

return loss


def _compute_riemannian_grad_and_update(
param: torch.Tensor, grad_like: torch.Tensor, dim: int, lr: float, wd: float
) -> None:
"""Compute Riemannian gradient for oblique manifold and update parameter in-place.

Args:
param: Parameter tensor (2D)
grad_like: Gradient-like tensor (momentum buffer or normalized gradient)
dim: The dimension to normalize over
lr: Learning rate
wd: Weight decay coefficient
"""

inner = (param * grad_like).sum(dim=dim, keepdim=True)
riem_grad = torch.add(grad_like, param * inner, alpha=-1)

# Add decoupled weight decay
param.mul_(1 - lr * wd)

# Apply update in-place
param.add_(riem_grad, alpha=-lr)
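
For orientation, a minimal usage sketch of the new optimizers (not part of this diff). The import path is inferred from the file location above, and the sizes and hyperparameters are illustrative. Each step projects the raw update onto the tangent space of the oblique manifold (the gradient minus its component along each unit-norm column) and then retracts by renormalizing, so the columns of the weight stay on the unit sphere throughout training.

import torch

from emerging_optimizers.riemannian_optimizers.normalized_optimizer import ObliqueSGD

torch.manual_seed(0)

# A 2D weight whose columns (dim=0) are constrained to unit norm.
weight = torch.nn.functional.normalize(torch.randn(16, 8), dim=0)
weight.requires_grad_(True)

target = torch.nn.functional.normalize(torch.randn(16, 8), dim=0)
opt = ObliqueSGD([weight], lr=1e-2, momentum=0.9, dim=0)

for _ in range(200):
    opt.zero_grad()
    loss = (weight - target).pow(2).sum()
    loss.backward()
    opt.step()

# The retraction keeps every column on the unit sphere after each step.
print(weight.norm(dim=0))  # all entries ~1.0

ObliqueAdam follows the same pattern, with betas and correct_bias in place of momentum.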
4 changes: 3 additions & 1 deletion tests/ci/L0_Tests_GPU.sh
@@ -21,4 +21,6 @@ coverage run -p --source=emerging_optimizers tests/test_soap_utils.py
coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda
coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py --device=cuda
coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda
2 changes: 2 additions & 0 deletions tests/ci/L1_Tests_GPU.sh
@@ -20,3 +20,5 @@ python tests/test_soap_utils.py
python tests/soap_smoke_test.py
python tests/test_scalar_optimizers.py --device=cuda
python tests/test_spectral_clipping_utils.py
python tests/test_normalized_optimizer.py
python tests/normalized_optimizer_convergence_test.py