Skip to content
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/apidocs/orthogonalized-optimizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ emerging_optimizers.orthogonalized_optimizers
.. autoclass:: MuonHyperball
:members:

:hidden:`Spectron`
~~~~~~~~~~~~~~~~~~~

.. autoclass:: Spectron
:members:


:hidden:`Newton-Schulz`
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions emerging_optimizers/orthogonalized_optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
from emerging_optimizers.orthogonalized_optimizers.scion import *
from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *
from emerging_optimizers.orthogonalized_optimizers.spectron import *
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing trailing newline

The file is missing a trailing newline after the new import line. This is flagged by most linters and POSIX standards, and the previous version of the file had one.

Suggested change
from emerging_optimizers.orthogonalized_optimizers.spectron import *
from emerging_optimizers.orthogonalized_optimizers.spectron import *

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

251 changes: 251 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/spectron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, overload, override

import torch
import torch.optim as optim
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import registry, utils
from emerging_optimizers.orthogonalized_optimizers import muon_utils
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.utils import FP32MatmulPrecT
from emerging_optimizers.utils.eig import power_iteration


__all__ = ["Spectron"]


@registry.register_optimizer("spectron")
class Spectron(opt_mixin.WeightDecayMixin, optim.Optimizer):
"""Spectron: Low-rank spectral optimizer with orthogonalized momentum.

Spectron maintains each 2D weight matrix W as a low-rank factorization W = A @ B^T,
where A ∈ R^(m×r) and B ∈ R^(n×r). It applies momentum, orthogonalizes the updates
using Newton-Schulz iteration, and scales the learning rate by the spectral radii
of both factors.

The algorithm:
1. Compute gradients with respect to A and B from parameter gradients
2. Apply momentum to both factors
3. Orthogonalize momentum buffers using Newton-Schulz iteration
4. Estimate spectral radius of A and B using power iteration
5. Update with scaled learning rate: η / (σ_A + σ_B + 1)
6. Reconstruct full weight matrix W = A @ B^T

References:
- Algorithm 1 (Spectron) and Algorithm 3 (PowerIter) from the Spectron paper (https://arxiv.org/abs/2602.12429).
Low-rank spectral optimization with orthogonalized momentum.

Warning:
- This optimizer requires that all parameters passed in are 2D.
- Low-rank factorization may not be suitable for all parameter types.

Args:
params: Iterable of parameters to optimize or dicts defining parameter groups
lr: The learning rate (η in the algorithm). Default: 3e-4
rank: The rank of the low-rank factorization. Default: 64
momentum_beta: The momentum decay coefficient (β). Default: 0.9
weight_decay: The weight decay coefficient. Default: 0.01
weight_decay_method: Method to apply weight decay. Default: "decoupled"
fp32_matmul_prec: Precision of matmul operations. Default: "medium"
num_ns_steps: Number of Newton-Schulz iteration steps. Default: 5
num_power_iter: Number of power iteration steps for spectral radius. Default: 1
coefficient_type: Type of coefficient set for Newton-Schulz. Default: "quintic"
"""

def __init__(
    self,
    params: ParamsT,
    lr: float = 3e-4,
    rank: int = 64,
    momentum_beta: float = 0.9,
    weight_decay: float = 0.01,
    *,
    weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
    fp32_matmul_prec: FP32MatmulPrecT = "medium",
    num_ns_steps: int = 5,
    num_power_iter: int = 1,
    coefficient_type: NSCoeffT = "quintic",
) -> None:
    """Validate hyperparameters and initialize the Spectron optimizer.

    See the class docstring for the meaning of each argument.

    Raises:
        ValueError: If any hyperparameter is outside its valid range.
    """
    # Guard clauses: fail fast on out-of-range hyperparameters.
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if rank < 1:
        raise ValueError(f"Invalid rank: {rank}")
    if not (0.0 <= momentum_beta < 1.0):
        raise ValueError(f"Invalid momentum_beta: {momentum_beta}")
    if weight_decay < 0.0:
        raise ValueError(f"Invalid weight_decay: {weight_decay}")
    if num_ns_steps < 1:
        raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
    if num_power_iter < 1:
        raise ValueError(f"num_power_iter must be at least 1, got {num_power_iter}")

    self.fp32_matmul_prec = fp32_matmul_prec
    self.weight_decay_method = weight_decay_method
    self.rank = rank
    self.num_power_iter = num_power_iter

    def scaled_orthogonalize_fn(update: torch.Tensor) -> torch.Tensor:
        # Orthogonalize with Newton-Schulz, mirroring the
        # OrthogonalizedOptimizer pattern used elsewhere in the package.
        logging.debug(f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient")
        return muon_utils.newton_schulz(
            update,
            steps=num_ns_steps,
            coefficient_type=coefficient_type,
        )

    self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

    # Per-group hyperparameters handled by torch.optim.Optimizer.
    super().__init__(params, {"lr": lr, "momentum_beta": momentum_beta, "weight_decay": weight_decay})

@overload
def step(self, closure: None = ...) -> None: ...

@overload
def step(self, closure: Callable[[], float]) -> float: ...

@torch.no_grad()  # type: ignore[misc]
@override
def step(self, closure: Callable[[], float] | None = None) -> float | None:
    """Performs a single optimization step.

    Args:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        The loss returned by ``closure``, or ``None`` if no closure was given.

    Raises:
        ValueError: If any parameter with a gradient is not 2D.
    """
    loss = None if closure is None else closure()

    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue

            if p.ndim != 2:
                raise ValueError(f"Spectron only supports 2D parameters, got shape {p.shape}")

            grad = p.grad
            state = self.state[p]

            # Lazily initialize low-rank factors and momentum buffers on first step.
            if "factor_A" not in state:
                self._initialize_state(p, state)

            factor_A = state["factor_A"]
            factor_B = state["factor_B"]
            momentum_A = state["momentum_A"]
            momentum_B = state["momentum_B"]
            u_A = state["u_A"]
            u_B = state["u_B"]

            # Chain rule through W = A @ B^T: dL/dA = dL/dW @ B, dL/dB = (dL/dW)^T @ A.
            # Explicitly cast the gradient to float32: p.grad inherits p's dtype
            # (e.g. bfloat16 in mixed-precision training), and a mixed-dtype matmul
            # against float32-derived factors (or the in-place lerp_ below) would
            # raise at runtime. NOTE(review): factors are stored via `.to(p.dtype)`
            # in _initialize_state — confirm buffer dtypes are kept consistent.
            with utils.fp32_matmul_precision("highest"):
                grad_A = grad.float() @ factor_B  # shape: (m, r)
                grad_B = grad.float().mT @ factor_A  # shape: (n, r)

            # Apply weight decay to each factor (method chosen at construction).
            self._apply_weight_decay_inplace(factor_A, grad_A, group["lr"], group["weight_decay"])
            self._apply_weight_decay_inplace(factor_B, grad_B, group["lr"], group["weight_decay"])

            # Momentum buffers are an EMA of the factor gradients.
            beta = group["momentum_beta"]
            momentum_A.lerp_(grad_A, 1 - beta)
            momentum_B.lerp_(grad_B, 1 - beta)

            # Orthogonalize the momentum using Newton-Schulz iteration.
            with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                orth_momentum_A = self.scaled_orthogonalize_fn(momentum_A)
                orth_momentum_B = self.scaled_orthogonalize_fn(momentum_B)

            # Estimate each factor's spectral radius via power iteration
            # (Algorithm 3) and persist the warm-start vectors for the next step.
            sigma_A, u_A = self._power_iteration(factor_A, u_A, self.num_power_iter)
            sigma_B, u_B = self._power_iteration(factor_B, u_B, self.num_power_iter)
            state["u_A"] = u_A
            state["u_B"] = u_B

            # η / (σ_A + σ_B + 1): damp the update when either factor's spectrum is large.
            scaled_lr = group["lr"] / (sigma_A + sigma_B + 1.0)

            factor_A.add_(orth_momentum_A, alpha=-scaled_lr)
            factor_B.add_(orth_momentum_B, alpha=-scaled_lr)

            # The model keeps a dense weight; the factorization lives only in
            # optimizer state, so reconstruct W = A @ B^T after each update.
            p.copy_(factor_A @ factor_B.mT)

    return loss

def _initialize_state(self, p: torch.Tensor, state: dict[str, torch.Tensor]) -> None:
    """Set up low-rank factors, momentum buffers, and power-iteration vectors.

    Args:
        p: The parameter tensor (shape: m × n)
        state: The state dictionary for this parameter
    """
    num_rows, num_cols = p.shape
    # The effective rank can never exceed either matrix dimension.
    effective_rank = min(self.rank, num_rows, num_cols)

    with torch.no_grad():
        # Seed A and B from a truncated SVD of p so that A @ B^T starts close
        # to the original weights, splitting sqrt(S) evenly between factors.
        U, S, Vh = torch.linalg.svd(p.float(), full_matrices=False)
        root_singular = S[:effective_rank].sqrt()
        factor_A = (U[:, :effective_rank] * root_singular).to(p.dtype)
        factor_B = (Vh[:effective_rank, :].mT * root_singular).to(p.dtype)

    state["factor_A"] = factor_A.clone()
    state["factor_B"] = factor_B.clone()
    state["momentum_A"] = torch.zeros_like(factor_A)
    state["momentum_B"] = torch.zeros_like(factor_B)

    # Random unit vectors to warm-start power iteration for each factor.
    for key, dim in (("u_A", num_rows), ("u_B", num_cols)):
        vec = torch.randn(dim, dtype=p.dtype, device=p.device)
        state[key] = vec / vec.norm()

def _power_iteration(
    self, matrix: torch.Tensor, u: torch.Tensor, num_iters: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Estimate the largest singular value of ``matrix`` via power iteration.

    Args:
        matrix: The matrix to estimate largest singular value for (shape: p × q)
        u: The current approximation of the dominant left singular vector
        num_iters: Number of power iteration steps

    Returns:
        Tuple of (largest singular value, updated_u)
    """
    # The shared helper also returns the right singular vector, which
    # Spectron does not track; discard it.
    sigma, left_vec, _ = power_iteration(matrix, u, k=num_iters)
    return sigma, left_vec
51 changes: 51 additions & 0 deletions emerging_optimizers/utils/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,60 @@
"met_approx_eigvals_criteria",
"conjugate",
"orthogonal_iteration",
"power_iteration",
]


def power_iteration(
    W: torch.Tensor,
    u: torch.Tensor,
    k: int = 1,
    eps: float = 1e-8,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Approximate largest singular value and left/right singular vectors using power iteration.

    Implements Algorithm 3 from the Spectron paper (https://arxiv.org/abs/2602.12429). This method iteratively refines
    estimates of the dominant singular value and corresponding left and right singular vectors
    of a matrix W.

    Args:
        W: Matrix of shape (p, q) to analyze
        u: Initial left singular vector of shape (p,), should be normalized
        k: Number of power iteration steps. Default: 1
        eps: Small constant for numerical stability. Default: 1e-8
    Returns:
        Tuple of (sigma, u, v) where:
            - sigma: Approximation of the largest singular value (scalar tensor)
            - u: Updated left singular vector of shape (p,)
            - v: Updated right singular vector of shape (q,)
    """

    def _unit(x: torch.Tensor) -> torch.Tensor:
        # Normalize to unit length, clamping the norm so an (almost) zero
        # vector does not produce a division by zero.
        return x / x.norm(p=2).clamp_min(eps)

    # Guarantee the starting left vector is normalized.
    u = _unit(u)

    # Alternate v ← W^T u and u ← W v, renormalizing after each product.
    for _ in range(k):
        v = _unit(W.mT @ u)
        u = _unit(W @ v)

    # Recompute the right vector from the final u, then take the Rayleigh
    # quotient σ ← u^T W v as the singular-value estimate.
    v = _unit(W.mT @ u)
    sigma = u @ (W @ v)

    return sigma, u, v


def eigh_with_fallback(
x: Tensor,
force_double: bool = False,
Expand Down
1 change: 1 addition & 0 deletions tests/ci/L0_Tests_CPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ error=0
torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1
torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_spectron.py --device=cpu -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu -v -2 || error=1

exit "${error}"
1 change: 1 addition & 0 deletions tests/ci/L0_Tests_GPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ error=0
coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_adaptive_muon.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_spectron.py --device=cuda -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py -v -2 || error=1
Expand Down
Loading
Loading