Commit fb1add8

Add normalized Riemannian optimizer (#36)
* added normalized optimizers and fixed docstrings and formatting

Signed-off-by: mikail <[email protected]>
1 parent 0f4e1ee commit fb1add8

File tree

5 files changed: +748 -1 lines changed
Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable

import torch
from torch.optim.optimizer import Optimizer

class ObliqueSGD(Optimizer):
    """SGD optimizer for row- or column-normalized 2D parameters on oblique manifolds.

    This optimizer performs SGD on oblique manifolds, where parameters are constrained
    to have unit-norm rows or columns. It implements Riemannian SGD with manifold-aware
    gradient updates and retraction operations.

    References:
        - An Introduction to Optimization on Smooth Manifolds (Nicolas Boumal)
        - EDM2: https://arxiv.org/abs/2312.02696
        - Jianlin Su: https://kexue.fm/archives/11196
        - Raman et al.: https://arxiv.org/abs/1909.06463
        - Franz Cesista: https://leloykun.github.io/ponder/steepest-descent-stiefel/#6-bonus-a-muon-like-optimizer-for-the-embedding-and-unembedding-layers

    Args:
        lr: The learning rate.
        momentum: The momentum coefficient.
        weight_decay: The weight decay coefficient.
        dim: The dimension to normalize over.
        eps: The epsilon for numerical stability.
    """

    def __init__(
        self,
        params: list[torch.nn.Parameter],
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        dim: int = 0,
        eps: float = 1e-8,
    ) -> None:
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0 or momentum >= 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            dim=dim,
            eps=eps,
        )
        super().__init__(params, defaults)

    @torch.no_grad()  # type: ignore[misc]
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            lr = group["lr"]
            mom = group["momentum"]
            wd = group["weight_decay"]
            dim = group["dim"]
            eps = group["eps"]

            for param in group["params"]:
                if param.grad is None:
                    continue
                if param.ndim != 2:
                    raise ValueError("ObliqueSGD only supports 2D parameters")
                grad = param.grad

                # Initialize momentum buffer if needed
                state = self.state[param]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(param)

                buf = state["momentum_buffer"]

                # Theory-style momentum: buf <- grad + mom * buf, updated in place so the
                # accumulated history is carried across steps in the optimizer state.
                buf.mul_(mom).add_(grad)

                # Apply Riemannian gradient update
                _compute_riemannian_grad_and_update(param, buf, dim, lr, wd)

                # Retraction back to the manifold, i.e. the hyper-sphere
                torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param)

        return loss

class ObliqueAdam(Optimizer):
    """Adam optimizer for row- or column-normalized 2D parameters on oblique manifolds.

    This optimizer adapts an Adam-like algorithm to work on oblique manifolds, where
    parameters are constrained to have unit-norm rows or columns. It combines adaptive
    moment estimation with Riemannian gradient computation and manifold retraction.
    """

    def __init__(
        self,
        params: list[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        dim: int = 0,
        eps: float = 1e-8,
        correct_bias: bool = True,
    ) -> None:
        """An Adam-like optimizer for normalized 2D parameters.

        Args:
            lr: The learning rate.
            betas: The coefficients used for computing running averages of the gradient and its square.
            weight_decay: The weight decay coefficient.
            dim: The dimension to normalize over.
            eps: The epsilon for numerical stability.
            correct_bias: Whether to apply bias correction to the Adam moment estimates.
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if betas[0] < 0.0 or betas[0] >= 1.0:
            raise ValueError(f"Invalid beta1 value: {betas[0]}")
        if betas[1] < 0.0 or betas[1] >= 1.0:
            raise ValueError(f"Invalid beta2 value: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            dim=dim,
            eps=eps,
            correct_bias=correct_bias,
        )
        super().__init__(params, defaults)

    @torch.no_grad()  # type: ignore[misc]
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            lr = group["lr"]
            betas = group["betas"]
            wd = group["weight_decay"]
            dim = group["dim"]
            eps = group["eps"]
            correct_bias = group["correct_bias"]

            for param in group["params"]:
                if param.grad is None:
                    continue
                if param.ndim != 2:
                    raise ValueError("ObliqueAdam only supports 2D parameters")

                state = self.state[param]
                if "step" not in state:
                    state["step"] = 0

                grad = param.grad

                # Initialize moment buffers if needed
                if "exp_avg" not in state:
                    state["exp_avg"] = torch.zeros_like(param)
                if "exp_avg_sq" not in state:
                    state["exp_avg_sq"] = torch.zeros_like(param)

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                # Increment step counter
                state["step"] += 1
                step = state["step"]

                # Update biased first and second moment estimates
                exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
                exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])

                if correct_bias:
                    # Step-size correction for the exponential moving averages of the moments
                    bias_correction1 = 1.0 - betas[0] ** step
                    bias_correction2 = 1.0 - betas[1] ** step
                else:
                    bias_correction1 = 1.0
                    bias_correction2 = 1.0

                norm_grad = (exp_avg / bias_correction1) / (exp_avg_sq.sqrt() / bias_correction2 + eps)

                # Apply Riemannian gradient update
                _compute_riemannian_grad_and_update(param, norm_grad, dim, lr, wd)

                # Retraction back to the manifold, i.e. the hyper-sphere
                torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param)

        return loss

def _compute_riemannian_grad_and_update(
    param: torch.Tensor, grad_like: torch.Tensor, dim: int, lr: float, wd: float
) -> None:
    """Compute the Riemannian gradient for the oblique manifold and update the parameter in-place.

    Args:
        param: Parameter tensor (2D).
        grad_like: Gradient-like tensor (momentum buffer or normalized gradient).
        dim: The dimension to normalize over.
        lr: Learning rate.
        wd: Weight decay coefficient.
    """
    # Project the Euclidean gradient onto the tangent space of the oblique manifold by
    # removing the component radial to each unit-norm row/column of param.
    inner = (param * grad_like).sum(dim=dim, keepdim=True)
    riem_grad = torch.add(grad_like, param * inner, alpha=-1)

    # Add decoupled weight decay
    param.mul_(1 - lr * wd)

    # Apply update in-place
    param.add_(riem_grad, alpha=-lr)
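
For intuition on the helper above: for a unit-norm row or column x, the projected direction g - (x . g) x is orthogonal to x, and the subsequent normalize call retracts the update back onto the unit sphere. A minimal standalone sketch (not part of this commit, plain PyTorch only) that checks both properties:

import torch

torch.manual_seed(0)
dim = 0  # normalize over columns, matching the optimizers' default
param = torch.nn.functional.normalize(torch.randn(4, 3), p=2.0, dim=dim)
grad_like = torch.randn(4, 3)

# Tangent-space projection used by _compute_riemannian_grad_and_update
inner = (param * grad_like).sum(dim=dim, keepdim=True)
riem_grad = grad_like - param * inner

# Each column of the projected gradient is orthogonal to the matching column of param
print((param * riem_grad).sum(dim=dim))  # ~0 for every column

# After a small step plus retraction, the columns are unit-norm again
updated = torch.nn.functional.normalize(param - 0.1 * riem_grad, p=2.0, dim=dim)
print(updated.norm(dim=dim))  # all ones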

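As a usage illustration, both optimizers follow the standard torch.optim interface and keep the chosen dimension of each 2D parameter unit-norm after every step. This is a hedged sketch: the new module's file name is not visible in the diff above, so the import path below is hypothetical; ObliqueAdam is constructed analogously, with betas in place of momentum.

import torch

# Hypothetical import path -- the location of the new file inside the
# emerging_optimizers package is not shown in this diff.
from emerging_optimizers import ObliqueSGD

weight = torch.nn.Parameter(
    torch.nn.functional.normalize(torch.randn(128, 64), p=2.0, dim=0)
)
opt = ObliqueSGD([weight], lr=1e-2, momentum=0.9, dim=0)

x = torch.randn(32, 128)
loss = (x @ weight).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()

print(weight.norm(dim=0))  # columns remain unit-norm after the step
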
tests/ci/L0_Tests_GPU.sh

Lines changed: 3 additions & 1 deletion
@@ -21,4 +21,6 @@ coverage run -p --source=emerging_optimizers tests/test_soap_utils.py
 coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
 coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda
-coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
+coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
+coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py --device=cuda
+coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda

tests/ci/L1_Tests_GPU.sh

Lines changed: 2 additions & 0 deletions
@@ -20,3 +20,5 @@ python tests/test_soap_utils.py
 python tests/soap_smoke_test.py
 python tests/test_scalar_optimizers.py --device=cuda
 python tests/test_spectral_clipping_utils.py
+python tests/test_normalized_optimizer.py
+python tests/normalized_optimizer_convergence_test.py
