Merged

34 commits
276b68c
cleaned up lower bound function for spectral norm based on Xi-lin's l…
mkhona-nvidia Oct 1, 2025
cf134b8
added tests for contraction functions
mkhona-nvidia Oct 1, 2025
b43a123
added tests for psgd's utils
mkhona-nvidia Oct 1, 2025
7a9452e
reduced procrustes step to a single function and wrote docs
mkhona-nvidia Oct 1, 2025
b5251da
added tests for procrustes step
mkhona-nvidia Oct 1, 2025
b3b4cd1
removed iterations from procrustes step
mkhona-nvidia Oct 1, 2025
d09b447
improved test for procrustes
mkhona-nvidia Oct 1, 2025
5b2cf11
improved psgd util test
mkhona-nvidia Oct 1, 2025
51187b5
improved the subspace iteration bound
mkhona-nvidia Oct 1, 2025
af84f46
formatting
mkhona-nvidia Oct 1, 2025
01740fa
removed solve triangular right since we are not building psgd trisolve
mkhona-nvidia Oct 3, 2025
f4dafee
removed solve_triangular_right from function list
mkhona-nvidia Oct 3, 2025
b2b1f04
improved procrustes step, removed solve triangular right and formatting
mkhona-nvidia Oct 3, 2025
b10d1d5
removed extra seed
mkhona-nvidia Oct 3, 2025
8e6ac2e
improved code of procrustes step
mkhona-nvidia Oct 3, 2025
f819fbb
comment to explain why trace is always positive
mkhona-nvidia Oct 3, 2025
0acc590
added type hints for psgd
mkhona-nvidia Oct 3, 2025
79b3b00
undoing doc change
mkhona-nvidia Oct 4, 2025
159c548
changed name in apply_kronecker_factors to Px instead of Y to be cons…
mkhona-nvidia Oct 4, 2025
0f1505b
Update eig utils (#35)
skyw Oct 2, 2025
a69a97b
added license file to test
mkhona-nvidia Oct 4, 2025
fe0d8ef
removed trailing whitespace
mkhona-nvidia Oct 7, 2025
1e9ba21
addressed PR comments
mkhona-nvidia Oct 7, 2025
852012b
added torch compile
mkhona-nvidia Oct 7, 2025
33584b7
added psgd tests to ci
mkhona-nvidia Oct 7, 2025
da88d06
moved procrustes test from cuda to cpu
mkhona-nvidia Oct 7, 2025
9348d02
revert docs change
mkhona-nvidia Oct 7, 2025
397c1af
removed extra doc from soap
mkhona-nvidia Oct 7, 2025
15112ce
fixed tests and relaxed some tolerances
mkhona-nvidia Oct 8, 2025
848fa9b
changed from normalization to a max(norm, eps) re:PR discussions
mkhona-nvidia Oct 8, 2025
e5b5ba5
replaced max with clamp
mkhona-nvidia Oct 8, 2025
2fe3ce3
reduced tolerance for identity matrix test
mkhona-nvidia Oct 8, 2025
e0388ff
improved test for random matrices for skew bound
mkhona-nvidia Oct 8, 2025
5006bb9
reduced tolerance for making test pass on GPU
mkhona-nvidia Oct 8, 2025
62 changes: 62 additions & 0 deletions emerging_optimizers/psgd/procrustes_step.py
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import emerging_optimizers.utils as utils
from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_skew


__all__ = [
"procrustes_step",
]


@torch.compile # type: ignore[misc]
def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor:
r"""One step of an online solver for the orthogonal Procrustes problem.

The orthogonal Procrustes problem is :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I`.
This step approximately solves it by rotating Q as :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`.

`max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)` to its 2nd order term.

This method applies a second-order expansion of a Lie-algebra-parametrized rotation and
uses a simple approximate line search to find the step size, following Xi-Lin Li.

Args:
Q: Tensor of shape (n, n), general square matrix to orthogonalize.
max_step_size: Maximum step size for the line search. Default is 1/8 (0.125).
eps: Small number for numerical stability.
"""
# Note: this function is written in fp32 to avoid numerical instability while computing the Taylor expansion of the exponential map
with utils.fp32_matmul_precision("highest"):
R = Q.T - Q
R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
RQ = R @ Q
# tr(RQ) is always nonnegative:
# tr(RQ) = tr((Q^T - Q) Q) = ||Q||_F^2 - tr(Q^2), and tr(Q^2) = ⟨Q^T, Q⟩_F ≤ ||Q^T||_F ||Q||_F = ||Q||_F^2 by Cauchy-Schwarz
tr_RQ = torch.trace(RQ)
RRQ = R @ RQ
tr_RRQ = torch.trace(RRQ)
# step size of the quadratic model, a* = -tr(RQ)/tr(RRQ), clipped to [0, max_step_size]
_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
# If tr_RRQ >= 0, the quadratic approximation is not concave, so we fall back to max_step_size.
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
# rotate Q as exp(a R) Q ≈ (I + a R + a^2 R^2/2) Q with the step size from the line search;
# for the 2nd-order expansion, exp(a R) is only expanded to its 2nd term.
# Q += step_size * (RQ + 0.5 * step_size * RRQ)
Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)

return Q
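
A minimal usage sketch (illustrative, not part of the diff; the import path follows the file above). Each step left-multiplies Q by an orthogonal rotation, so the singular values are preserved while the skew part Q^T - Q should shrink:

import torch

from emerging_optimizers.psgd.procrustes_step import procrustes_step

torch.manual_seed(0)
Q = torch.randn(8, 8)
svals = torch.linalg.svdvals(Q)
for _ in range(100):
    Q = procrustes_step(Q)
# rotations preserve the spectrum, so the singular values should be unchanged
torch.testing.assert_close(torch.linalg.svdvals(Q), svals, rtol=1e-3, atol=1e-3)
# the asymmetry ||Q^T - Q||_F should be far smaller than at initialization
print(torch.linalg.matrix_norm(Q.T - Q))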
151 changes: 151 additions & 0 deletions emerging_optimizers/psgd/psgd_kron_contractions.py
@@ -0,0 +1,151 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch


__all__ = [
"partial_contraction",
"apply_kronecker_factors",
"apply_preconditioner",
]


@torch.compile # type: ignore[misc]
def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.Tensor:
"""Compute the partial contraction of G1 and G2 along axis `axis`.
This is the contraction of the two tensors, but with all axes except `axis` contracted.

Args:
G1: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
G2: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
axis: int, the axis to contract along
Contributor: Suggestion: add an example, tensordot is not very easy to understand.

Contributor Author: We have tensordot documentation in the docs.


Returns:
Tensor of shape (d_{axis}, d_{axis})
"""
# dims_to_contract = all dims except `axis`
dims = list(range(G1.dim()))
dims.pop(axis)
# contraction is symmetric and has shape (d_{axis}, d_{axis})
return torch.tensordot(G1, G2, dims=(dims, dims))
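
Echoing the review suggestion above, a small illustrative sketch (not part of the diff) of what this contraction computes for a rank-3 tensor:

import torch

from emerging_optimizers.psgd.psgd_kron_contractions import partial_contraction

G = torch.randn(2, 3, 4)
# preserving axis 1 contracts axes 0 and 2, equivalent to an explicit einsum
out = partial_contraction(G, G, axis=1)
ref = torch.einsum("iaj,ibj->ab", G, G)
torch.testing.assert_close(out, ref)  # shape (3, 3), symmetric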


@torch.compile # type: ignore[misc]
def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
"""Apply all Kronecker factors once to tensor :math:`X`, each to its corresponding dimension.

This applies each :math:`Q` factor once, for example in 2D case: :math:`Q_1 X Q_2^T`.

Args:
Q_list: List of :math:`Q` (the upper-triangular Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`.
X: Tensor of shape `(d_0, d_1, ..., d_N)`.

Returns:
Tensor of shape `(d_0, d_1, ..., d_N)`.
"""
if len(Q_list) != X.dim():
raise ValueError(
f"Number of Kronecker factors {len(Q_list)} must match the number of dimensions of X {X.dim()}"
)

Y = X
for i in range(len(Q_list)):
Y = _apply_single_kronecker_factor(Q_list, Y, i)
return Y
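
An illustrative check (assuming a 2-D input with two square factors) that this matches the documented :math:`Q_1 X Q_2^T`:

import torch

from emerging_optimizers.psgd.psgd_kron_contractions import apply_kronecker_factors

Q1, Q2 = torch.randn(3, 3), torch.randn(4, 4)
X = torch.randn(3, 4)
torch.testing.assert_close(apply_kronecker_factors([Q1, Q2], X), Q1 @ X @ Q2.T)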


@torch.compile # type: ignore[misc]
def apply_preconditioner(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
"""Apply the full PSGD preconditioner to X.

This applies the full preconditioner, the Kronecker product of PSGD's factor products :math:`Q_i^T Q_i`, to X. For example, in the 2D case:

:math:`P X = (Q_1^T Q_1) X (Q_2^T Q_2)`

Each factor is applied once, followed by its transpose, to produce the full preconditioner effect.

Args:
Q_list: List of :math:`Q` (the Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`.
X: Tensor of shape `(d_0, d_1, ..., d_N)`.

Returns:
Tensor of shape `(d_0, d_1, ..., d_N)`.
"""
# Apply Q first, then Q.T to get Q^T @ Q
Px = apply_kronecker_factors(Q_list, X)
Px = apply_kronecker_factors([q if q.dim() == 1 else q.T for q in Q_list], Px)
return Px
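
And the corresponding 2-D sketch for the full preconditioner (illustrative, not from the diff):

import torch

from emerging_optimizers.psgd.psgd_kron_contractions import apply_preconditioner

Q1, Q2 = torch.randn(3, 3), torch.randn(4, 4)
X = torch.randn(3, 4)
torch.testing.assert_close(
    apply_preconditioner([Q1, Q2], X),
    (Q1.T @ Q1) @ X @ (Q2.T @ Q2),
)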


def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int) -> torch.Tensor:
"""Multiply tensor X along axis `contract_dim` by 2D matrix M.

Helper function for `_apply_single_kronecker_factor`.
If M is (d_out, d_in) we contract M’s second index with X’s `contract_dim` index.
`torch.tensordot` is used to contract the two tensors, and then the result is permuted to move the new axis 0 to position `contract_dim`.
Returns a new tensor of the same rank, but with size[contract_dim] replaced by d_out.
Note that d_{contract_dim} == d_in.

Args:
X: Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_{contract_dim}, d_{contract_dim+1}, ..., d_N)
M: Tensor of shape (d_out, d_in)
contract_dim: int, the dimension to contract with M, with d_{contract_dim} == d_in

Returns:
Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_out, d_{contract_dim+1}, ..., d_N)

Examples
--------
>>> X = torch.randn(2, 3, 6)
>>> M = torch.randn(5, 6)
>>> contract_dim = 2
>>> result = _dim_n_mul_and_permute(X, M, contract_dim)
>>> print(result.shape)
torch.Size([2, 3, 5])

"""
if X.shape[contract_dim] != M.shape[1]:
raise ValueError(
f"Shape mismatch: X.shape[{contract_dim}] = {X.shape[contract_dim]}, M.shape[1] = {M.shape[1]}"
)
# Contract M's 2nd dim (idx=1) with X's `contract_dim` dim
Y = torch.tensordot(M, X, dims=([1], [contract_dim]))
# Y now has shape (d_out, d_0, …, d_{contract_dim-1}, d_{contract_dim+1}, …).
# `torch.tensordot` places the new axis at position 0; move it back to position `contract_dim`.
nd = X.dim()
perm = list(range(1, contract_dim + 1)) + [0] + list(range(contract_dim + 1, nd))
return Y.permute(perm)


@torch.compile # type: ignore[misc]
def _apply_single_kronecker_factor(Q_list: List[torch.Tensor], X: torch.Tensor, axis: int) -> torch.Tensor:
"""Apply a single Kronecker factor Q to X at dimension `axis`. Helper function for apply_kronecker_factors.

If Q is a vector, we multiply X by Q.
If Q is a matrix, we contract Q's second index with X's `axis` index.

Args:
Q_list: List of Q (e.g. the Kronecker factors).
X: Tensor of shape (d_0, d_1, ..., d_N)
axis: The dimension of X to which Q_list[axis] is applied.
"""
Q = Q_list[axis]
if Q.dim() == 1:
shape = [1] * X.dim()
Contributor: Nit: using squeeze and unsqueeze is more PyTorch.

shape[axis] = Q.size(0)
return X * Q.view(shape)

return _dim_n_mul_and_permute(X, Q, contract_dim=axis)
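
For the vector branch, a short sketch (illustrative) of the intended broadcasting: a 1-D factor acts as a diagonal matrix on its axis:

import torch

X = torch.randn(2, 3)
q = torch.randn(3)  # 1-D factor for axis 1
out = X * q.view(1, 3)  # what the vector branch computes for axis=1
torch.testing.assert_close(out, X @ torch.diag(q))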
176 changes: 176 additions & 0 deletions emerging_optimizers/psgd/psgd_utils.py
@@ -0,0 +1,176 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch


__all__ = [
"uniformize_q_in_place",
"norm_lower_bound_spd",
"norm_lower_bound_skew",
]


@torch.compile # type: ignore[misc]
def uniformize_q_in_place(Q_list: List[torch.Tensor]) -> None:
"""Balance the dynamic ranges of kronecker factors in place to prevent numerical underflow or overflow.

Each tensor in `Q_list` is rescaled so that its maximum absolute entry
becomes the geometric mean of all factors' original maxima. This preserves
the overall product of norms (and thus the scale of the Kronecker product)
while avoiding numerical underflow or overflow when factors have widely
differing magnitudes.

Given tensors :math:`Q_1, Q_2, \\ldots, Q_n`:

1. Compute max-absolute norms: :math:`\\|Q_i\\|_\\infty = \\max(|Q_i|)` for :math:`i = 1, \\ldots, n`
2. Compute geometric mean: :math:`g = \\left(\\prod_{i=1}^{n} \\|Q_i\\|_\\infty \\right)^{1/n}`
3. Rescale each tensor: :math:`Q_i \\leftarrow Q_i \\cdot \\frac{g}{\\|Q_i\\|_\\infty}`

This ensures :math:`\\|Q_i\\|_\\infty = g` for all :math:`i`, while preserving the norm of
the Kronecker product :math:`Q_1 \\otimes Q_2 \\otimes \\cdots \\otimes Q_n`.

Args:
Q_list: List of Q (e.g. the Kronecker factors), each tensor will be modified in place.

Returns:
None

"""
if not Q_list:
raise TypeError("Q_list cannot be empty.")

order = len(Q_list)
if order == 1:
# with a single factor, no balancing is needed
return

# Compute max-abs norm of each factor
norms = [torch.max(torch.abs(Q)) for Q in Q_list]

# Compute geometric mean of those norms
gmean = torch.prod(torch.stack(norms)) ** (1.0 / order)

# Rescale each factor so its max‐abs entry == geometric mean
for Q, norm in zip(Q_list, norms, strict=True):
Q.mul_(gmean / norm)
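
A quick illustration (not from the diff) that the rescaling equalizes the max-abs norms while leaving the Kronecker product unchanged:

import torch

from emerging_optimizers.psgd.psgd_utils import uniformize_q_in_place

Q1, Q2 = 1e3 * torch.randn(2, 2), 1e-3 * torch.randn(3, 3)
kron_before = torch.kron(Q1, Q2)
uniformize_q_in_place([Q1, Q2])
torch.testing.assert_close(torch.kron(Q1, Q2), kron_before)
assert torch.isclose(Q1.abs().max(), Q2.abs().max())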
Contributor: Future improvement: a lot of code in this PR can/should be optimized. The sequence of operations is very inefficient, although coming up with a better-performing counterpart is sometimes non-trivial.



@torch.compile # type: ignore[misc]
def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
r"""A cheap lower bound for the spectral norm of a symmetric positive definite matrix.


Args:
A: Tensor of shape :math:`(n, n)`, symmetric positive definite.
k: Dimension of the subspace.
half_iters: Half of the number of subspace iterations.
eps: Small number for numerical stability.

Returns:
A scalar Tensor giving a lower bound on :math:`\\|A\\|_2`.
"""

# Compute scaling factor from the largest diagonal entry to prevent overflow/underflow
scale = torch.clamp(A.diagonal().amax(), min=eps)
A = A / scale

bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps)

return scale * bound_unnormalized
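
A hedged sanity check (illustrative): the estimate should lower-bound, and roughly track, the true spectral norm of an SPD matrix:

import torch

from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_spd

B = torch.randn(16, 16)
A = B @ B.T + 0.1 * torch.eye(16)  # SPD by construction
lb = norm_lower_bound_spd(A)
sigma_max = torch.linalg.matrix_norm(A, ord=2)
assert lb <= sigma_max * (1 + 1e-5)  # lower bound, up to floating-point slack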


@torch.compile # type: ignore[misc]
def norm_lower_bound_skew(A: torch.Tensor, k: int = 32, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
"""A cheap lower bound on the spectral norm (largest eigenvalue) of skew-symmetric matrix.


Note: For skew-symmetric matrices, all diagonal entries are zero and :math:`A^T = -A`.
From Xi-Lin Li.

Args:
A: Tensor of shape :math:`(n, n)`, skew-symmetric.
k: Dimension of the subspace. Suggested values: 128 for bfloat16, 32 for float32, 4 for float64.
half_iters: Half of the number of subspace iterations.
eps: Small number for numerical stability.

Returns:
A scalar Tensor giving a lower bound on :math:`\\|A\\|_2`.

"""

# Compute scaling factor from the max absolute value to prevent overflow/underflow
scale = torch.clamp(A.abs().amax(), min=eps)
A = A / scale

bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps)

return scale * bound_unnormalized
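
The analogous illustrative check for the skew-symmetric case:

import torch

from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_skew

B = torch.randn(16, 16)
A = B - B.T  # skew-symmetric: A.T == -A, zero diagonal
lb = norm_lower_bound_skew(A)
assert lb <= torch.linalg.matrix_norm(A, ord=2) * (1 + 1e-5)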


@torch.compile # type: ignore[misc]
def _subspace_iteration_bound(
A: torch.Tensor,
k: int = 32,
half_iters: int = 2,
eps: float = 1e-8,
) -> torch.Tensor:
"""A helper function for subspace iteration to estimate spectral norm bounds.

Uses numerically stable subspace iteration with a random initialization that aligns with the
largest row of A to approximate the dominant eigenspace. This is more robust than simple
power iteration, especially for large matrices with very low rank. From Xi-Lin Li.

The algorithm (:math:`A` is assumed to be already normalized by the caller):
1. Find the row :math:`j` of :math:`A` with the largest 2-norm.
2. Initialize a :math:`k \\times n` subspace matrix :math:`V` of random vectors, sign-aligned with and shifted by :math:`A[j]`.
3. Perform `half_iters` subspace iterations; each applies :math:`A` twice, with a row normalization in between.
4. Estimate the norm as the maximum 2-norm among the :math:`k` rows of :math:`V`.

Args:
A: Input matrix, already normalized by caller.
k: Dimension of the subspace (number of random vectors).
half_iters: Number of half-iterations (each applies A twice).
eps: Small number for numerical stability.

Returns:
Maximum vector norm from the final subspace iteration (unnormalized).
"""

# Initialize random subspace matrix V of shape (k, n)
V = torch.randn(k, A.shape[1], dtype=A.dtype, device=A.device)

# Find the row index with the largest 2-norm to initialize our subspace
# This helps the algorithm converge faster to the dominant eigenspace
dominant_row_idx = torch.argmax(torch.linalg.vector_norm(A, dim=1))
# Rotate the random vectors to align with the dominant row A[dominant_row_idx]
# This initialization trick makes the subspace iteration more robust for low-rank matrices
dominant_row = A[dominant_row_idx]
alignment = torch.sign(torch.sum(dominant_row * V, dim=1, keepdim=True))

V = dominant_row + alignment * V

# Perform subspace iteration
for _ in range(half_iters):
V = V @ A
# Normalize each row of V to prevent exponential growth/decay
V /= torch.linalg.vector_norm(V, dim=1, keepdim=True) + eps
# Apply A again (V approximates the dominant eigenspace of A^2)
V = V @ A

# Return the maximum 2-norm among the k vectors
return torch.amax(torch.linalg.vector_norm(V, dim=1))
2 changes: 1 addition & 1 deletion tests/ci/L0_Tests_CPU.sh
@@ -16,4 +16,4 @@ set -o pipefail
torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu

coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu