# PSGD-Kron's helper functions #37
@@ -0,0 +1,62 @@

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import emerging_optimizers.utils as utils
from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_skew


__all__ = [
    "procrustes_step",
]


@torch.compile  # type: ignore[misc]
def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor:
    r"""One step of an online solver for the orthogonal Procrustes problem.

    The orthogonal Procrustes problem :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I`
    is solved by rotating Q as :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator
    and :math:`\|a R\| < 1`.

    `max_step_size` should be less than :math:`1/4` because :math:`\exp(a R)` is only expanded
    to its 2nd-order term.

    This method is a second-order expansion of a Lie-algebra-parametrized rotation that
    uses a simple approximate line search to find the optimal step size, from Xi-Lin Li.

    Args:
        Q: Tensor of shape (n, n), general square matrix to orthogonalize.
        max_step_size: Maximum step size for the line search. Defaults to 1/8 (0.125).
        eps: Small number for numerical stability.
    """
    # Note: this function is written in fp32 to avoid numerical instability while
    # computing the Taylor expansion of the exponential map.
    with utils.fp32_matmul_precision("highest"):
        R = Q.T - Q
        R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
        RQ = R @ Q
        # The trace of RQ is always non-negative:
        # tr(RQ) = tr((Q^T - Q) Q) = ||Q||_F^2 - <Q^T, Q>_F >= 0 by Cauchy-Schwarz.
        tr_RQ = torch.trace(RQ)
        RRQ = R @ RQ
        tr_RRQ = torch.trace(RRQ)
        # Clip the step size to max_step_size, based on the 2nd-order expansion.
        _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
        # If tr_RRQ >= 0, the quadratic approximation is not concave; fall back to max_step_size.
        step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
        # Rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2 / 2) Q with the step size found by the
        # line search; exp(a R) is only expanded to its 2nd-order term.
        # Equivalent to: Q += step_size * (RQ + 0.5 * step_size * RRQ)
        Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)

    return Q
```
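For intuition, here is a minimal usage sketch (hypothetical, not part of the diff, and assuming the functions from this PR are importable): rotating `Q` with `procrustes_step` leaves the preconditioner :math:`P = Q^T Q` approximately unchanged, since the truncated :math:`\exp(a R)` is nearly orthogonal, while driving `Q` toward its symmetric polar factor (the fixed points satisfy :math:`R = Q^T - Q = 0`).

```python
# Hypothetical usage sketch, not part of the PR.
import torch

Q = torch.randn(8, 8)
P_before = Q.T @ Q
for _ in range(50):
    Q = procrustes_step(Q)
# P = Q^T Q is preserved up to the truncation error of exp(a R)
print(torch.dist(Q.T @ Q, P_before))  # small
# Q itself approaches symmetry: R = Q^T - Q -> 0 at a fixed point
print(torch.dist(Q, Q.T))             # shrinks over the iterations
```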
@@ -0,0 +1,151 @@

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch


__all__ = [
    "partial_contraction",
    "apply_kronecker_factors",
    "apply_preconditioner",
]


@torch.compile  # type: ignore[misc]
def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.Tensor:
    """Compute the partial contraction of G1 and G2 along axis `axis`.

    The two tensors are contracted over all axes except `axis`, leaving a matrix
    of shape (d_{axis}, d_{axis}).

    Args:
        G1: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        G2: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        axis: int, the axis to preserve; all other axes are contracted.

    Returns:
        Tensor of shape (d_{axis}, d_{axis})
    """
    # dims_to_contract = all dims except `axis`
    dims = list(range(G1.dim()))
    dims.pop(axis)
    # contraction is symmetric and has shape (d_{axis}, d_{axis})
    return torch.tensordot(G1, G2, dims=(dims, dims))
```

> **Contributor:** Suggestion: add an example, tensordot is not very easy to understand.
>
> **Author:** We have tensordot documentation in the docs.
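For readers unfamiliar with `tensordot`, a small sketch of what the call computes (hypothetical, not part of the diff): for 3-D inputs and `axis=1`, it contracts axes 0 and 2, which can be spelled out with `einsum`.

```python
# Hypothetical sketch, not part of the PR.
import torch

G1 = torch.randn(2, 3, 4)
G2 = torch.randn(2, 3, 4)
out = partial_contraction(G1, G2, axis=1)  # contracts axes 0 and 2
ref = torch.einsum("iaj,ibj->ab", G1, G2)  # the same contraction, spelled out
print(out.shape)                           # torch.Size([3, 3])
print(torch.allclose(out, ref, atol=1e-6)) # True
```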
```python
@torch.compile  # type: ignore[misc]
def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
    """Apply all Kronecker factors once to tensor :math:`X`, each to its corresponding dimension.

    This applies each :math:`Q` factor once, for example in the 2D case: :math:`Q_1 X Q_2^T`.

    Args:
        Q_list: List of :math:`Q` (the upper-triangular Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`.
        X: Tensor of shape `(d_0, d_1, ..., d_N)`.

    Returns:
        Tensor of shape `(d_0, d_1, ..., d_N)`.
    """
    if len(Q_list) != X.dim():
        raise ValueError(
            f"Number of Kronecker factors {len(Q_list)} must match the number of dimensions of X {X.dim()}"
        )

    Y = X
    for i in range(len(Q_list)):
        Y = _apply_single_kronecker_factor(Q_list, Y, i)
    return Y
```
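A 2-D sanity check of the factor application (hypothetical, not part of the diff):

```python
# Hypothetical sketch, not part of the PR: in 2-D, applying the factors once
# is the matrix product Q1 @ X @ Q2.T.
import torch

Q1, Q2 = torch.randn(3, 3), torch.randn(4, 4)
X = torch.randn(3, 4)
Y = apply_kronecker_factors([Q1, Q2], X)
print(torch.allclose(Y, Q1 @ X @ Q2.T, atol=1e-5))  # True
```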
```python
@torch.compile  # type: ignore[misc]
def apply_preconditioner(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
    """Apply the full PSGD preconditioner to X.

    This applies the full Kronecker product of PSGD's Kronecker factors :math:`Q^T Q` to X,
    e.g. in the 2D case:

    :math:`P X = (Q_1^T Q_1) X (Q_2^T Q_2)`

    Each factor is applied and then followed by its transpose for the full preconditioner effect.

    Args:
        Q_list: List of :math:`Q` (the Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`.
        X: Tensor of shape `(d_0, d_1, ..., d_N)`.

    Returns:
        Tensor of shape `(d_0, d_1, ..., d_N)`.
    """
    # Apply Q first, then Q.T, so each dimension sees Q^T @ Q
    Px = apply_kronecker_factors(Q_list, X)
    Px = apply_kronecker_factors([q if q.dim() == 1 else q.T for q in Q_list], Px)
    return Px
```
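And a matching 2-D check for the full preconditioner (hypothetical, not part of the diff):

```python
# Hypothetical sketch, not part of the PR: the 2-D preconditioner is
# (Q1^T Q1) X (Q2^T Q2).
import torch

Q1 = torch.triu(torch.randn(3, 3))
Q2 = torch.triu(torch.randn(4, 4))
X = torch.randn(3, 4)
PX = apply_preconditioner([Q1, Q2], X)
ref = (Q1.T @ Q1) @ X @ (Q2.T @ Q2)
print(torch.allclose(PX, ref, atol=1e-5))  # True
```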
```python
def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int) -> torch.Tensor:
    """Multiply tensor X along axis `contract_dim` by 2D matrix M.

    Helper function for `_apply_single_kronecker_factor`.
    If M is (d_out, d_in), M's second index is contracted with X's `contract_dim` index.
    `torch.tensordot` performs the contraction, and the result is then permuted to move
    the new axis 0 back to position `contract_dim`.
    Returns a new tensor of the same rank, but with size[contract_dim] replaced by d_out.
    Note that d_{contract_dim} == d_in.

    Args:
        X: Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_{contract_dim}, d_{contract_dim+1}, ..., d_N)
        M: Tensor of shape (d_out, d_in)
        contract_dim: int, the dimension to contract with M, with d_{contract_dim} == d_in

    Returns:
        Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_out, d_{contract_dim+1}, ..., d_N)

    Examples
    --------
    >>> X = torch.randn(2, 3, 6)
    >>> M = torch.randn(5, 6)
    >>> contract_dim = 2
    >>> result = _dim_n_mul_and_permute(X, M, contract_dim)
    >>> print(result.shape)
    torch.Size([2, 3, 5])

    """
    if X.shape[contract_dim] != M.shape[1]:
        raise ValueError(
            f"Shape mismatch: X.shape[{contract_dim}] = {X.shape[contract_dim]}, M.shape[1] = {M.shape[1]}"
        )
    # Contract M's 2nd dim (idx=1) with X's `contract_dim` dim
    Y = torch.tensordot(M, X, dims=([1], [contract_dim]))
    # Y now has shape (d_out, d_0, ..., d_{contract_dim-1}, d_{contract_dim+1}, ..., d_N):
    # torch.tensordot places the contracted result as axis 0, so move it back to
    # position `contract_dim`.
    nd = X.dim()
    perm = list(range(1, contract_dim + 1)) + [0] + list(range(contract_dim + 1, nd))
    return Y.permute(perm)


@torch.compile  # type: ignore[misc]
def _apply_single_kronecker_factor(Q_list: List[torch.Tensor], X: torch.Tensor, axis: int) -> torch.Tensor:
    """Apply a single Kronecker factor Q to X at dimension `axis`. Helper function for `apply_kronecker_factors`.

    If Q is a vector, X is multiplied elementwise by Q broadcast along `axis`.
    If Q is a matrix, Q's second index is contracted with X's `axis` index.

    Args:
        Q_list: List of Q (e.g. the Kronecker factors).
        X: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        axis: int, the dimension of X that Q_list[axis] is applied to.
    """
    Q = Q_list[axis]
    if Q.dim() == 1:
        # Reshape Q so it broadcasts along `axis` only, e.g. (1, ..., d_axis, ..., 1)
        shape = [1] * X.dim()
        shape[axis] = Q.size(0)
        return X * Q.view(shape)

    return _dim_n_mul_and_permute(X, Q, contract_dim=axis)
```

> **Contributor:** Nit: Using squeeze and unsqueeze is more PyTorch.
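For completeness, a quick check that a 1-D factor acts as a diagonal matrix along its axis (hypothetical, not part of the diff):

```python
# Hypothetical sketch, not part of the PR: a vector factor on axis 0 is a
# diagonal scaling of the rows.
import torch

q = torch.randn(3)
X = torch.randn(3, 4)
Y = _apply_single_kronecker_factor([q, torch.eye(4)], X, axis=0)
print(torch.allclose(Y, torch.diag(q) @ X))  # True
```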
@@ -0,0 +1,176 @@

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch


__all__ = [
    "uniformize_q_in_place",
    "norm_lower_bound_spd",
    "norm_lower_bound_skew",
]


@torch.compile  # type: ignore[misc]
def uniformize_q_in_place(Q_list: List[torch.Tensor]) -> None:
    """Balance the dynamic ranges of Kronecker factors in place to prevent numerical underflow or overflow.

    Each tensor in `Q_list` is rescaled so that its maximum absolute entry
    becomes the geometric mean of all factors' original maxima. This preserves
    the overall product of norms (and thus the scale of the Kronecker product)
    while avoiding numerical underflow or overflow when factors have widely
    differing magnitudes.

    Given tensors :math:`Q_1, Q_2, \\ldots, Q_n`:

    1. Compute max-absolute norms: :math:`\\|Q_i\\|_\\infty = \\max(|Q_i|)` for :math:`i = 1, \\ldots, n`
    2. Compute geometric mean: :math:`g = \\left(\\prod_{i=1}^{n} \\|Q_i\\|_\\infty \\right)^{1/n}`
    3. Rescale each tensor: :math:`Q_i \\leftarrow Q_i \\cdot \\frac{g}{\\|Q_i\\|_\\infty}`

    This ensures :math:`\\|Q_i\\|_\\infty = g` for all :math:`i`, while preserving the norm of
    the Kronecker product :math:`Q_1 \\otimes Q_2 \\otimes \\cdots \\otimes Q_n`.

    Args:
        Q_list: List of Q (e.g. the Kronecker factors), each tensor will be modified in place.

    Returns:
        None
    """
    if not Q_list:
        raise TypeError("Q_list cannot be empty.")

    order = len(Q_list)
    if order == 1:
        # with a single factor, no balancing is needed
        return

    # Compute max-abs norm of each factor
    norms = [torch.max(torch.abs(Q)) for Q in Q_list]

    # Compute geometric mean of those norms
    gmean = torch.prod(torch.stack(norms)) ** (1.0 / order)

    # Rescale each factor so its max-abs entry == geometric mean
    for Q, norm in zip(Q_list, norms, strict=True):
        Q.mul_(gmean / norm)
```

> **Contributor:** Future improvement: a lot of the code in this PR can/should be optimized; the sequence of operations is very inefficient. Although coming up with a better-performing counterpart is sometimes non-trivial.
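A quick illustration of the balancing invariant (hypothetical, not part of the diff):

```python
# Hypothetical sketch, not part of the PR: balancing equalizes the max-abs
# norms of the factors while preserving their Kronecker product.
import torch

Q_list = [1e-4 * torch.randn(3, 3), 1e4 * torch.randn(4, 4)]
kron_before = torch.kron(Q_list[0], Q_list[1])
uniformize_q_in_place(Q_list)
print([Q.abs().max().item() for Q in Q_list])  # both equal the geometric mean
print(torch.allclose(torch.kron(Q_list[0], Q_list[1]), kron_before, rtol=1e-4))  # True
```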
```python
@torch.compile  # type: ignore[misc]
def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
    r"""A cheap lower bound for the spectral norm of a symmetric positive definite matrix.

    Args:
        A: Tensor of shape :math:`(n, n)`, symmetric positive definite.
        k: Dimension of the subspace.
        half_iters: Half of the number of subspace iterations.
        eps: Small number for numerical stability.

    Returns:
        A scalar Tensor giving a lower bound on :math:`\|A\|_2`.
    """
    # Compute scaling factor from the largest diagonal entry to prevent overflow/underflow
    scale = torch.clamp(A.diagonal().amax(), min=eps)
    A = A / scale

    bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps)

    return scale * bound_unnormalized
```
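A sanity check against the exact spectral norm (hypothetical, not part of the diff): each row of the iterate is unit-normalized before the final multiply, so the returned value can never exceed the true norm.

```python
# Hypothetical sketch, not part of the PR: the estimate lower-bounds the true
# spectral norm and is typically close to it.
import torch

B = torch.randn(64, 64)
A = B @ B.T + 1e-3 * torch.eye(64)  # symmetric positive definite
est = norm_lower_bound_spd(A)
exact = torch.linalg.matrix_norm(A, ord=2)
print(bool(est <= exact * (1 + 1e-4)), (est / exact).item())  # True, ratio near 1
```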
```python
@torch.compile  # type: ignore[misc]
def norm_lower_bound_skew(A: torch.Tensor, k: int = 32, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
    """A cheap lower bound on the spectral norm (largest singular value) of a skew-symmetric matrix.

    Note: For skew-symmetric matrices, all diagonal entries are zero and :math:`A^T = -A`.
    From Xi-Lin Li.

    Args:
        A: Tensor of shape :math:`(n, n)`, skew-symmetric.
        k: Dimension of the subspace. Suggested values: 128 for bfloat16, 32 for float32, 4 for float64.
        half_iters: Half of the number of subspace iterations.
        eps: Small number for numerical stability.

    Returns:
        A scalar Tensor giving a lower bound on :math:`\\|A\\|_2`.
    """
    # Compute scaling factor from the max absolute value to prevent overflow/underflow
    scale = torch.clamp(A.abs().amax(), min=eps)
    A = A / scale

    bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps)

    return scale * bound_unnormalized
```
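And the skew-symmetric counterpart (hypothetical, not part of the diff):

```python
# Hypothetical sketch, not part of the PR.
import torch

B = torch.randn(64, 64)
A = B - B.T  # skew-symmetric: A.T == -A
est = norm_lower_bound_skew(A)
exact = torch.linalg.matrix_norm(A, ord=2)
print(bool(est <= exact * (1 + 1e-4)), (est / exact).item())  # True, ratio near 1
```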
```python
@torch.compile  # type: ignore[misc]
def _subspace_iteration_bound(
    A: torch.Tensor,
    k: int = 32,
    half_iters: int = 2,
    eps: float = 1e-8,
) -> torch.Tensor:
    """A helper function for subspace iteration to estimate spectral norm bounds.

    Uses numerically stable subspace iteration with a random initialization that aligns with the
    largest row of A to approximate the dominant eigenspace. This is more robust than simple
    power iteration, especially for large matrices with very low rank. From Xi-Lin Li.

    The algorithm (A is assumed to be pre-normalized by the caller to avoid overflow):

    1. Find the row :math:`j` of :math:`A` with the largest 2-norm.
    2. Initialize a :math:`k \\times n` subspace matrix :math:`V` with random vectors aligned to :math:`A[j]`.
    3. Perform subspace iteration for `half_iters` steps, applying :math:`V \\leftarrow V A` twice per
       step with a row normalization in between.
    4. Estimate the norm as the maximum 2-norm among the k rows of :math:`V`.

    Args:
        A: Input matrix, already normalized by the caller.
        k: Dimension of the subspace (number of random vectors).
        half_iters: Number of half-iterations (each applies A twice).
        eps: Small number for numerical stability.

    Returns:
        Maximum vector norm from the final subspace iteration (unnormalized).
    """
    # Initialize random subspace matrix V of shape (k, n)
    V = torch.randn(k, A.shape[1], dtype=A.dtype, device=A.device)

    # Find the row index with the largest 2-norm to initialize our subspace.
    # This helps the algorithm converge faster to the dominant eigenspace.
    dominant_row_idx = torch.argmax(torch.linalg.vector_norm(A, dim=1))
    # Rotate the random vectors to align with the dominant row A[dominant_row_idx].
    # This initialization trick makes the subspace iteration more robust for low-rank matrices.
    dominant_row = A[dominant_row_idx]
    alignment = torch.sign(torch.sum(dominant_row * V, dim=1, keepdim=True))

    V = dominant_row + alignment * V

    # Perform subspace iteration
    for _ in range(half_iters):
        V = V @ A
        # Normalize each row of V to prevent exponential growth/decay
        V /= torch.linalg.vector_norm(V, dim=1, keepdim=True) + eps
        # Apply A again (V approximates the dominant eigenspace of A^2)
        V = V @ A

    # Return the maximum 2-norm among the k rows
    return torch.amax(torch.linalg.vector_norm(V, dim=1))
```