
Commit d1e462c

Add PSGD-Kron's helper functions (#37)
* cleaned up lower bound function for spectral norm based on Xi-Lin's latest code

Signed-off-by: mikail <[email protected]>
1 parent fb1add8 commit d1e462c

File tree

8 files changed: +941 −2 lines changed

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import emerging_optimizers.utils as utils
from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_skew


__all__ = [
    "procrustes_step",
]


@torch.compile  # type: ignore[misc]
def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor:
    r"""One step of an online solver for the orthogonal Procrustes problem.

    The orthogonal Procrustes problem is :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I`.
    Each step approximates the solution by rotating Q as :math:`\exp(a R) Q`, where
    :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`.

    `max_step_size` should be less than :math:`1/4` because :math:`\exp(a R)` is only expanded
    to its 2nd-order term.

    This method is a second-order expansion of a Lie-algebra-parametrized rotation that
    uses a simple approximate line search to find the optimal step size, from Xi-Lin Li.

    Args:
        Q: Tensor of shape (n, n), general square matrix to orthogonalize.
        max_step_size: Maximum step size for the line search. Default is 1/8 (0.125).
        eps: Small number for numerical stability.
    """
    # Note: this function is written in fp32 to avoid numerical instability while computing
    # the Taylor expansion of the exponential map.
    with utils.fp32_matmul_precision("highest"):
        R = Q.T - Q
        R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
        RQ = R @ Q
        # The trace of RQ is always non-negative, since
        # tr(RQ) = tr((Q^T - Q) Q) = ||Q||_F^2 - ⟨Q^T, Q⟩_F ≥ 0
        # by Cauchy-Schwarz (⟨Q^T, Q⟩_F ≤ ||Q^T||_F ||Q||_F = ||Q||_F^2).
        tr_RQ = torch.trace(RQ)
        RRQ = R @ RQ
        tr_RRQ = torch.trace(RRQ)
        # Clip the step size to max_step_size, based on a 2nd-order expansion.
        _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
        # If tr_RRQ >= 0, the quadratic approximation is not concave; fall back to max_step_size.
        step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
        # Rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2 / 2) Q with the step size from the line search;
        # for the 2nd-order expansion, exp(a R) is only expanded to its 2nd term.
        # Equivalent to: Q += step_size * (RQ + 0.5 * step_size * RRQ)
        Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)

    return Q
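
Not part of the diff: a minimal sketch of how `procrustes_step` behaves, assuming the function above is in scope. Minimizing ||U Q - I||_F over orthogonal U sends U Q to the symmetric polar factor of Q, so repeated steps should shrink the skew part Q^T - Q while (to second order) leaving Q^T Q unchanged:

import torch

torch.manual_seed(0)
n = 16
Q = torch.eye(n) + 0.1 * torch.randn(n, n)

P_before = Q.T @ Q
for _ in range(50):
    Q = procrustes_step(Q)  # assumes procrustes_step from the file above is in scope

# exp(a R) is orthogonal up to the 2nd-order truncation, so Q^T Q is roughly
# preserved while Q is driven toward symmetry (R = Q^T - Q -> 0).
print(torch.linalg.matrix_norm(Q.T - Q))             # should shrink toward 0
print(torch.linalg.matrix_norm(Q.T @ Q - P_before))  # should stay small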
Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch


__all__ = [
    "partial_contraction",
    "apply_kronecker_factors",
    "apply_preconditioner",
]


@torch.compile  # type: ignore[misc]
def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.Tensor:
    """Compute the partial contraction of G1 and G2 along axis `axis`.

    This contracts the two tensors over all axes except `axis`.

    Args:
        G1: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        G2: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        axis: The axis to contract along.

    Returns:
        Tensor of shape (d_{axis}, d_{axis})
    """
    # dims_to_contract = all dims except `axis`
    dims = list(range(G1.dim()))
    dims.pop(axis)
    # The contraction has shape (d_{axis}, d_{axis}) and is symmetric when G1 == G2.
    return torch.tensordot(G1, G2, dims=(dims, dims))


@torch.compile  # type: ignore[misc]
def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
    """Apply all Kronecker factors once to tensor :math:`X`, each to its corresponding dimension.

    This applies each :math:`Q` factor once; for example, in the 2D case: :math:`Q_1 X Q_2^T`.

    Args:
        Q_list: List of :math:`Q` (the upper-triangular Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`.
        X: Tensor of shape `(d_0, d_1, ..., d_N)`.

    Returns:
        Tensor of shape `(d_0, d_1, ..., d_N)`.
    """
    if len(Q_list) != X.dim():
        raise ValueError(
            f"Number of Kronecker factors {len(Q_list)} must match the number of dimensions of X {X.dim()}"
        )

    Y = X
    for i in range(len(Q_list)):
        Y = _apply_single_kronecker_factor(Q_list, Y, i)
    return Y


@torch.compile  # type: ignore[misc]
def apply_preconditioner(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
    """Apply the full PSGD preconditioner to X.

    The full preconditioner is the Kronecker product of PSGD's per-dimension terms :math:`Q_i^T Q_i`.
    In the 2D case:

    :math:`P X = (Q_1^T Q_1) X (Q_2^T Q_2)`

    This applies each factor followed by its transpose for the full preconditioner effect.

    Args:
        Q_list: List of :math:`Q` (the Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`.
        X: Tensor of shape `(d_0, d_1, ..., d_N)`.

    Returns:
        Tensor of shape `(d_0, d_1, ..., d_N)`.
    """
    # Apply Q first, then Q^T, to get the effect of Q^T @ Q
    Px = apply_kronecker_factors(Q_list, X)
    Px = apply_kronecker_factors([q if q.dim() == 1 else q.T for q in Q_list], Px)
    return Px


def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int) -> torch.Tensor:
    """Multiply tensor X along axis `contract_dim` by 2D matrix M.

    Helper function for `_apply_single_kronecker_factor`.
    If M is (d_out, d_in), we contract M's second index with X's `contract_dim` index.
    `torch.tensordot` contracts the two tensors, and the result is then permuted to move
    the new axis 0 back to position `contract_dim`.
    Returns a new tensor of the same rank, but with size[contract_dim] replaced by d_out.
    Note that d_{contract_dim} == d_in.

    Args:
        X: Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_{contract_dim}, d_{contract_dim+1}, ..., d_N)
        M: Tensor of shape (d_out, d_in)
        contract_dim: The dimension to contract with M, with d_{contract_dim} == d_in.

    Returns:
        Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_out, d_{contract_dim+1}, ..., d_N)

    Examples
    --------
    >>> X = torch.randn(2, 3, 6)
    >>> M = torch.randn(5, 6)
    >>> contract_dim = 2
    >>> result = _dim_n_mul_and_permute(X, M, contract_dim)
    >>> print(result.shape)
    torch.Size([2, 3, 5])

    """
    if X.shape[contract_dim] != M.shape[1]:
        raise ValueError(
            f"Shape mismatch: X.shape[{contract_dim}] = {X.shape[contract_dim]}, M.shape[1] = {M.shape[1]}"
        )
    # Contract M's 2nd dim (idx=1) with X's `contract_dim` dim
    Y = torch.tensordot(M, X, dims=([1], [contract_dim]))
    # Y now has shape (d_out, d_0, ..., d_{contract_dim-1}, d_{contract_dim+1}, ...),
    # because `torch.tensordot` places the uncontracted dims of M first.
    # Move that new axis 0 back to position `contract_dim`.
    nd = X.dim()
    perm = list(range(1, contract_dim + 1)) + [0] + list(range(contract_dim + 1, nd))
    return Y.permute(perm)


@torch.compile  # type: ignore[misc]
def _apply_single_kronecker_factor(Q_list: List[torch.Tensor], X: torch.Tensor, axis: int) -> torch.Tensor:
    """Apply a single Kronecker factor Q to X at dimension `axis`. Helper function for `apply_kronecker_factors`.

    If Q is a vector, we multiply X elementwise by Q broadcast along `axis`.
    If Q is a matrix, we contract Q's second index with X's `axis` index.

    Args:
        Q_list: List of Q (e.g. the Kronecker factors).
        X: Tensor of shape (d_0, d_1, ..., d_N)
        axis: The dimension of X to apply Q_list[axis] to.
    """
    Q = Q_list[axis]
    if Q.dim() == 1:
        shape = [1] * X.dim()
        shape[axis] = Q.size(0)
        return X * Q.view(shape)

    return _dim_n_mul_and_permute(X, Q, contract_dim=axis)
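
Not part of the diff: a quick consistency check, assuming `apply_preconditioner` from the file above is in scope (the names Q1, Q2, X below are illustrative). In the 2D case the docstring's formula can be verified against explicit matrix products:

import torch

torch.manual_seed(0)
Q1 = torch.randn(3, 3).triu()  # upper-triangular factors, per the docstring
Q2 = torch.randn(4, 4).triu()
X = torch.randn(3, 4)

Px = apply_preconditioner([Q1, Q2], X)

# Explicit 2D form: P X = (Q1^T Q1) X (Q2^T Q2)
ref = (Q1.T @ Q1) @ X @ (Q2.T @ Q2)
print(torch.allclose(Px, ref, atol=1e-5))  # expected: True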
emerging_optimizers/psgd/psgd_utils.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch


__all__ = [
    "uniformize_q_in_place",
    "norm_lower_bound_spd",
    "norm_lower_bound_skew",
]


@torch.compile  # type: ignore[misc]
def uniformize_q_in_place(Q_list: List[torch.Tensor]) -> None:
    r"""Balance the dynamic ranges of Kronecker factors in place to prevent numerical underflow or overflow.

    Each tensor in `Q_list` is rescaled so that its maximum absolute entry
    becomes the geometric mean of all factors' original maxima. This preserves
    the overall product of norms (and thus the scale of the Kronecker product)
    while avoiding numerical underflow or overflow when factors have widely
    differing magnitudes.

    Given tensors :math:`Q_1, Q_2, \ldots, Q_n`:

    1. Compute max-absolute norms: :math:`\|Q_i\|_\infty = \max(|Q_i|)` for :math:`i = 1, \ldots, n`
    2. Compute geometric mean: :math:`g = \left(\prod_{i=1}^{n} \|Q_i\|_\infty \right)^{1/n}`
    3. Rescale each tensor: :math:`Q_i \leftarrow Q_i \cdot \frac{g}{\|Q_i\|_\infty}`

    This ensures :math:`\|Q_i\|_\infty = g` for all :math:`i`, while preserving the norm of
    the Kronecker product :math:`Q_1 \otimes Q_2 \otimes \cdots \otimes Q_n`.

    Args:
        Q_list: List of Q (e.g. the Kronecker factors); each tensor will be modified in place.

    Returns:
        None

    """
    if not Q_list:
        raise TypeError("Q_list cannot be empty.")

    order = len(Q_list)
    if order == 1:
        # with a single factor, no balancing is needed
        return

    # Compute max-abs norm of each factor
    norms = [torch.max(torch.abs(Q)) for Q in Q_list]

    # Compute geometric mean of those norms
    gmean = torch.prod(torch.stack(norms)) ** (1.0 / order)

    # Rescale each factor so its max-abs entry == geometric mean
    for Q, norm in zip(Q_list, norms, strict=True):
        Q.mul_(gmean / norm)


@torch.compile  # type: ignore[misc]
def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
    r"""A cheap lower bound for the spectral norm of a symmetric positive definite matrix.

    Args:
        A: Tensor of shape :math:`(n, n)`, symmetric positive definite.
        k: Dimension of the subspace.
        half_iters: Half of the number of subspace iterations.
        eps: Small number for numerical stability.

    Returns:
        A scalar Tensor giving a lower bound on :math:`\|A\|_2`.
    """
    # Compute scaling factor from the largest diagonal entry to prevent overflow/underflow
    scale = torch.clamp(A.diagonal().amax(), min=eps)
    A = A / scale

    bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps)

    return scale * bound_unnormalized


@torch.compile  # type: ignore[misc]
def norm_lower_bound_skew(A: torch.Tensor, k: int = 32, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
    r"""A cheap lower bound on the spectral norm (largest singular value) of a skew-symmetric matrix.

    Note: For skew-symmetric matrices, all diagonal entries are zero and :math:`A^T = -A`.
    From Xi-Lin Li.

    Args:
        A: Tensor of shape :math:`(n, n)`, skew-symmetric.
        k: Dimension of the subspace. Suggested values: 128 for bfloat16, 32 for float32, 4 for float64.
        half_iters: Half of the number of subspace iterations.
        eps: Small number for numerical stability.

    Returns:
        A scalar Tensor giving a lower bound on :math:`\|A\|_2`.

    """
    # Compute scaling factor from the max absolute value to prevent overflow/underflow
    scale = torch.clamp(A.abs().amax(), min=eps)
    A = A / scale

    bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps)

    return scale * bound_unnormalized


@torch.compile  # type: ignore[misc]
def _subspace_iteration_bound(
    A: torch.Tensor,
    k: int = 32,
    half_iters: int = 2,
    eps: float = 1e-8,
) -> torch.Tensor:
    r"""A helper function for subspace iteration to estimate spectral norm bounds.

    Uses numerically stable subspace iteration with a random initialization that aligns with the
    largest row of A to approximate the dominant eigenspace. This is more robust than simple
    power iteration, especially for large matrices with very low rank. From Xi-Lin Li.

    The algorithm (A is assumed to be pre-scaled by the caller to avoid overflow):

    1. Find the row :math:`j` of :math:`A` with the largest 2-norm.
    2. Initialize a :math:`k \times n` subspace matrix :math:`V` with random vectors aligned to :math:`A[j]`.
    3. Perform subspace iteration for `half_iters` steps, each applying :math:`A` twice with a
       row normalization in between.
    4. Estimate the norm as the maximum 2-norm among the k rows of :math:`V`.

    Args:
        A: Input matrix, already normalized by caller.
        k: Dimension of the subspace (number of random vectors).
        half_iters: Number of half-iterations (each applies A twice).
        eps: Small number for numerical stability.

    Returns:
        Maximum vector norm from the final subspace iteration (unnormalized).
    """
    # Initialize random subspace matrix V of shape (k, n)
    V = torch.randn(k, A.shape[1], dtype=A.dtype, device=A.device)

    # Find the row index with the largest 2-norm to initialize our subspace.
    # This helps the algorithm converge faster to the dominant eigenspace.
    dominant_row_idx = torch.argmax(torch.linalg.vector_norm(A, dim=1))
    # Rotate the random vectors to align with the dominant row A[dominant_row_idx].
    # This initialization trick makes the subspace iteration more robust for low-rank matrices.
    dominant_row = A[dominant_row_idx]
    alignment = torch.sign(torch.sum(dominant_row * V, dim=1, keepdim=True))

    V = dominant_row + alignment * V

    # Perform subspace iteration
    for _ in range(half_iters):
        V = V @ A
        # Normalize each row of V to prevent exponential growth/decay
        V /= torch.linalg.vector_norm(V, dim=1, keepdim=True) + eps
        # Apply A again (V approximates the dominant eigenspace of A^2)
        V = V @ A

    # Return the maximum 2-norm among the k rows
    return torch.amax(torch.linalg.vector_norm(V, dim=1))
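
Not part of the diff: a sanity sketch for `norm_lower_bound_skew`, assuming the function above is in scope. Because each row of V is unit-normalized before the final multiply by A, the returned value cannot exceed the true spectral norm (up to floating-point error):

import torch

torch.manual_seed(0)
B = torch.randn(64, 64)
A = B - B.T  # skew-symmetric: A^T == -A

lb = norm_lower_bound_skew(A)
exact = torch.linalg.matrix_norm(A, ord=2)  # largest singular value
print(f"lower bound {lb:.4f} <= exact {exact:.4f}")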

tests/ci/L0_Tests_CPU.sh

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ set -o pipefail
 torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
 torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
 coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu
-
+coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu
