diff --git a/.github/workflows/cherry-pick-release-commit.yml b/.github/workflows/cherry-pick-release-commit.yml index 8812ad3..d891018 100644 --- a/.github/workflows/cherry-pick-release-commit.yml +++ b/.github/workflows/cherry-pick-release-commit.yml @@ -20,7 +20,7 @@ on: jobs: cherry-pick: - uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_cherry_pick.yml@v0.57.0 + uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_cherry_pick.yml@v0.63.0 secrets: PAT: ${{ secrets.PAT }} SLACK_WEBHOOK_ADMIN: ${{ secrets.SLACK_WEBHOOK_ADMIN }} diff --git a/docs/apidocs/soap.md b/docs/apidocs/soap.md index 6e285d7..6dcf3bc 100644 --- a/docs/apidocs/soap.md +++ b/docs/apidocs/soap.md @@ -21,4 +21,10 @@ emerging_optimizers.soap .. autofunction:: update_kronecker_factors .. autofunction:: update_eigenbasis_and_momentum + +emerging_optimizers.soap.soap_utils +===================================== + +.. automodule:: emerging_optimizers.soap.soap_utils + :members: ``` diff --git a/docs/conf.py b/docs/conf.py index feb50d5..820c891 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -72,6 +72,7 @@ "numpy": ("https://numpy.org/doc/stable", None), "torch": ("https://pytorch.org/docs/2.5", None), } +autodoc_typehints = "description" def linkcode_resolve(domain, info): diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py index 0afbeb2..8b8f9a4 100644 --- a/emerging_optimizers/orthogonalized_optimizers/__init__.py +++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py @@ -14,3 +14,4 @@ # limitations under the License. from emerging_optimizers.orthogonalized_optimizers.muon import * from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import * +from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import * diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index 682f439..717ecb2 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -37,9 +37,12 @@ class Muon(OrthogonalizedOptimizer): optimization via Frank-Wolfe. References: - - Jordan, K. *Muon Optimizer Implementation.* [`GitHub `_] - - *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024). [`arXiv:2410.21265 `_] - - *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025). [`arXiv:2502.07529 `_] + - Jordan, K. *Muon Optimizer Implementation.* + [`GitHub `_] + - *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024). + [`arXiv:2410.21265 `_] + - *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025). + [`arXiv:2502.07529 `_] Warning: - This optimizer requires that all parameters passed in are 2D. @@ -122,7 +125,8 @@ def get_muon_scale_factor( # Suggested by K. Jordan and Kimi (https://arxiv.org/abs/2502.16982) return extra_scale_factor * max(size_out, size_in) ** 0.5 elif mode == "unit_rms_norm": - # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al. (https://jeremybernste.in/writing/deriving-muon) + # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al. 
+ # (https://jeremybernste.in/writing/deriving-muon) return extra_scale_factor * (size_out / size_in) ** 0.5 else: raise ValueError(f"Invalid mode for Muon update scale factor: {mode}") diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py index db44c4e..b56a668 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py @@ -36,14 +36,14 @@ ], "polar_express": [ # Polar Express iteration from: https://arxiv.org/abs/2505.16932 - (7.2086, -15.5131, 9.0178), - (3.9623, -2.5813, 0.4542), - (3.9466, -2.5765, 0.4544), - (3.8991, -2.5671, 0.4566), - (3.7186, -2.5308, 0.4653), - (3.1390, -2.3073, 0.4733), - (2.1715, -1.5246, 0.3885), - (1.8648, -1.2224, 0.3577), + (8.205160, -22.90193, 16.46073), + (4.066395, -2.861154, 0.5183995), + (3.909595, -2.823351, 0.5250370), + (3.285564, -2.415302, 0.4852941), + (2.277873, -1.619822, 0.3984808), + (1.872576, -1.230704, 0.3585162), + (1.856437, -1.213239, 0.3567998), + (1.875, -1.25, 0.375), ], "aol": [ # from https://github.com/thib-s/flash-newton-schulz/blob/main/newton_schulz_triton.py#L511 diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index 0c643c7..cbb7d57 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -45,9 +45,11 @@ class OrthogonalizedOptimizer(optim.Optimizer): - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.* In International Conference on Artificial Intelligence and Statistics (2015a). - - Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. *Stochastic Spectral Descent for Discrete Graphical Models.* + - Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. + *Stochastic Spectral Descent for Discrete Graphical Models.* In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016). - - Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. *Preconditioned spectral descent for deep learning.* + - Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. + *Preconditioned spectral descent for deep learning.* In Neural Information Processing Systems (2015b). - Flynn, T. *The duality structure gradient descent algorithm: analysis and applications to neural networks.* arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 `_] diff --git a/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py new file mode 100644 index 0000000..bebbfbe --- /dev/null +++ b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz + + +__all__ = ["spectral_hardcap", "spectral_clip"] + + +def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1.0) -> torch.Tensor: + r"""Applies spectral clipping to the input tensor. + + From the idea that clipping can be written using the sign function. This idea can be extended to singular values of matrices + using the matrix sign function, computed using Newton-Schulz iteration for efficiency. + + Based on https://leloykun.github.io/ponder/spectral-clipping/. + + Args: + X: The input tensor. + sigma_min: The minimum singular value. + sigma_max: The maximum singular value. + + Returns: + The spectral clipped tensor. + """ + if needs_transpose := X.shape[0] > X.shape[1]: + X = X.T + OX = newton_schulz(X, steps=8, coefficient_type="polar_express") + result = (sigma_min + sigma_max) * OX + identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype) + for s, sign in zip([sigma_min, sigma_max], [1, -1]): + A = torch.add(s * identity_matrix, OX @ X.T, alpha=-1) + B = torch.add(s * OX, X, alpha=-1) + result = torch.add(result, sign * newton_schulz(A, steps=8, coefficient_type="polar_express") @ B) + result = result * 0.5 + + if needs_transpose: + result = result.T + return result + + +def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor: + r"""Spectral hardcap function clips singular values from above to be less than beta. + + Simplifies the spectral clipping function to just an upper bound, resulting in a hardcap. + Based on https://leloykun.github.io/ponder/spectral-clipping/. + + Args: + X: The input tensor. + beta: The upper bound on the singular values. + + Returns: + The spectral hardcapped tensor. + + """ + if needs_transpose := X.shape[0] > X.shape[1]: + X = X.T + OX = newton_schulz(X, steps=8, coefficient_type="polar_express") + aX = torch.add(beta * OX, X, alpha=-1) + result = torch.add(beta * OX, X) + result = torch.add(result, aX @ newton_schulz(aX, steps=8, coefficient_type="polar_express").T @ OX, alpha=-1) + result = result * 0.5 + if needs_transpose: + result = result.T + return result + + +def spectral_clipped_weight_decay(X: torch.Tensor, beta: float = 1.0, c: float = 0.5) -> torch.Tensor: + r"""Applies weight decay to the input tensor while applying spectral hardcapping. + + This is the spectral version of Euclidean decoupled weight decay (Hanson & Pratt, 1988). + + Based on https://leloykun.github.io/ponder/spectral-clipping/. + + Args: + X: The input tensor. + beta: The upper bound on the singular values. + c: The coefficient parameter. + + Returns: + The spectral clipped weight decay tensor. + """ + return torch.add((1 - c) * X, spectral_hardcap(X, beta), alpha=c) diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index ebefedd..1e6aaf2 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -61,7 +61,6 @@ class SOAP(optim.Optimizer): precondition_warmup_steps: How many steps to warm up the preconditioner (i.e. update every step) adam_warmup_steps: How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates) precondition_1d: Whether to precondition 1D gradients (like biases). - max_precond_dim: Maximum dimension of the preconditioner matrices. Skips preconditioning if any tensor dimension exceeds. 
trace_normalization: Whether to normalize update by the trace of the kronecker factor matrix normalize_preconditioned_grads: Whether to normalize preconditioned gradients per layer correct_bias: Whether to use bias correction in Inner Adam and Kronecker factor matrices EMA @@ -91,7 +90,6 @@ def __init__( precondition_warmup_steps: int = 0, adam_warmup_steps: int = 1, precondition_1d: bool = False, - max_precond_dim: int = 8192, trace_normalization: bool = False, normalize_preconditioned_grads: bool = False, correct_bias: bool = True, @@ -127,7 +125,8 @@ def __init__( original_adam_warmup_steps = adam_warmup_steps adam_warmup_steps = max(1, precondition_warmup_steps - 1) logging.info( - f"adam_warmup_steps ({original_adam_warmup_steps}) should be less than precondition_warmup_steps ({precondition_warmup_steps}). " + f"adam_warmup_steps ({original_adam_warmup_steps}) should be less " + f"than precondition_warmup_steps ({precondition_warmup_steps}). " f"Setting adam_warmup_steps to {adam_warmup_steps} by default." ) @@ -141,7 +140,6 @@ def __init__( "precondition_warmup_steps": precondition_warmup_steps, "adam_warmup_steps": adam_warmup_steps, "precondition_1d": precondition_1d, - "max_precond_dim": max_precond_dim, "trace_normalization": trace_normalization, "normalize_preconditioned_grads": normalize_preconditioned_grads, "use_nesterov": use_nesterov, @@ -194,17 +192,16 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: state["GG"] = init_kronecker_factors( grad, precondition_1d=group["precondition_1d"], - max_precond_dim=group["max_precond_dim"], ) - # Update preconditioner matrices with gradient statistics, do not use shampoo_beta for EMA at first step + # Update preconditioner matrices with gradient statistics, + # do not use shampoo_beta for EMA at first step with utils.fp32_matmul_precision(group["fp32_matmul_prec"]): update_kronecker_factors( kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=0.0, precondition_1d=group["precondition_1d"], - max_precond_dim=group["max_precond_dim"], ) # Increment step counter @@ -284,11 +281,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: grad=grad, shampoo_beta=shampoo_beta, precondition_1d=group["precondition_1d"], - max_precond_dim=group["max_precond_dim"], ) torch.cuda.nvtx.range_pop() - # If current step is the last step to skip preconditioning, initialize eigenbases and end first order warmup + # If current step is the last step to skip preconditioning, initialize eigenbases and + # end first order warmup if state["step"] == group["adam_warmup_steps"]: # Obtain kronecker factor eigenbases from kronecker factor matrices using eigendecomposition state["Q"] = get_eigenbasis_eigh(state["GG"]) @@ -330,7 +327,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: def init_kronecker_factors( grad: torch.Tensor, precondition_1d: bool = False, - max_precond_dim: int = 8192, ) -> List[torch.Tensor]: """Initializes the kronecker factor matrices for the SOAP optimizer. @@ -354,8 +350,6 @@ def init_kronecker_factors( The shape of this tensor determines the size of the kronecker factor matrices. precondition_1d: Whether to create kronecker factor matrices for 1D tensors (like biases). If False, 1D tensors will skip preconditioning. - max_precond_dim: Maximum dimension of the preconditioner matrices. - Skips preconditioning if any tensor dimension exceeds. Returns: List[torch.Tensor]: List of kronecker factor matrices (L and R in paper). 
@@ -387,21 +381,11 @@ def init_kronecker_factors( else: # Create a square preconditioner matrix for 1D tensors size = grad.shape[0] - if size > max_precond_dim: - # if tensor dimension is larger than max_precond_dim, skip preconditioning this dimension - # append empty tensor to kronecker_factor_list so that subsequent check that use numel() to check if preconditioner is initialized will not fail - kronecker_factor_list.append(torch.empty(0, device=grad.device)) - else: - kronecker_factor_list.append(torch.zeros(size, size, device=grad.device)) + kronecker_factor_list.append(torch.zeros(size, size, device=grad.device)) else: # Create a square kronecker factor matrix for each dimension for size in grad.shape: - if size > max_precond_dim: - # append empty tensor to kronecker_factor_list so that subsequent check that use numel() to check if preconditioner is initialized will not fail - # skip preconditioning this dimension - kronecker_factor_list.append(torch.empty(0, device=grad.device)) - else: - kronecker_factor_list.append(torch.zeros(size, size, device=grad.device)) + kronecker_factor_list.append(torch.zeros(size, size, device=grad.device)) return kronecker_factor_list @@ -412,7 +396,6 @@ def update_kronecker_factors( grad: torch.Tensor, shampoo_beta: float, precondition_1d: bool = False, - max_precond_dim: int = 8192, ) -> None: """Updates the preconditioner matrices using gradient outer products. @@ -429,8 +412,6 @@ def update_kronecker_factors( Controls how much weight to give to new vs old gradient statistics. precondition_1d: Whether to apply preconditioning to 1D tensors (like biases). If False, 1D tensors will skip preconditioning. - max_precond_dim: Maximum dimension of the preconditioner matrices. - Skips preconditioning if any tensor dimension exceeds. Example: >>> grad = torch.randn(10, 20) @@ -446,20 +427,23 @@ def update_kronecker_factors( kronecker_factor_list[0].lerp_(outer_product, 1 - shampoo_beta) else: # For 1D tensors, skip preconditioning + logging.error( + "1D tensor is passed to update_kronecker_factors, " + "but precondition_1d is not set to True, skipping preconditioning." + ) return else: # For higher dimensional tensors, compute outer products for each dimension for idx, dim_size in enumerate(grad.shape): - if dim_size <= max_precond_dim: - # Compute outer product by contracting all dimensions except idx - contract_dims = [*chain(range(idx), range(idx + 1, grad.dim()))] - outer_product = torch.tensordot( - grad, - grad, - dims=[contract_dims] * 2, - ) - # Update the corresponding Kronecker factor - kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta) + # Compute outer product by contracting all dimensions except idx + contract_dims = [*chain(range(idx), range(idx + 1, grad.dim()))] + outer_product = torch.tensordot( + grad, + grad, + dims=[contract_dims] * 2, + ) + # Update the corresponding Kronecker factor + kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta) @torch.no_grad() # type: ignore[misc] @@ -606,7 +590,8 @@ def precondition( ) else: # Permute gradient dimensions to process the next dimension in the following iteration - # when preconditioning for the current dimension is skipped (Q is empty), in the case of one-sided preconditioning. + # when preconditioning for the current dimension is skipped (Q is empty), in the case of + # one-sided preconditioning. 
permute_order = list(range(1, grad.dim())) + [0] grad = grad.permute(permute_order) diff --git a/emerging_optimizers/soap/soap_utils.py b/emerging_optimizers/soap/soap_utils.py index 136c089..d63b9d2 100644 --- a/emerging_optimizers/soap/soap_utils.py +++ b/emerging_optimizers/soap/soap_utils.py @@ -16,7 +16,7 @@ import torch -from emerging_optimizers import utils +from emerging_optimizers.utils import eig as eig_utils __all__ = [ @@ -85,16 +85,10 @@ def get_eigenbasis_eigh( # We use an empty tensor so that the `precondition` function will skip this factor. updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device)) continue - # Construct approximated eigenvalues using QL^T@L@QL or QR^T@R@QR. - # The approximated eigenvalues should be close to diagonal if the eigenbasis is close to the true eigenbasis of the kronecker factor - # (i.e. the approximated eigenvectors diagonalize the kronecker factor) - approx_eigenvalue_matrix = eigenbasis.T @ kronecker_factor @ eigenbasis - # Update eigenbasis when necessary. Update is skipped only when adaptive update criteria is met. - if _adaptive_criteria_met( - approx_eigenvalue_matrix=approx_eigenvalue_matrix, - tolerance=adaptive_update_tolerance, - ): - _, Q = utils.eig.eigh_with_fallback( + + approx_eigvals = eig_utils.conjugate(kronecker_factor, eigenbasis, diag=True) + if not eig_utils.met_approx_eigvals_criteria(kronecker_factor, approx_eigvals, adaptive_update_tolerance): + _, Q = eig_utils.eigh_with_fallback( kronecker_factor, force_double=False, eps=eps, @@ -109,7 +103,7 @@ def get_eigenbasis_eigh( if kronecker_factor.numel() == 0: updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device)) continue - _, Q = utils.eig.eigh_with_fallback( + _, Q = eig_utils.eigh_with_fallback( kronecker_factor, force_double=False, eps=eps, output_dtype=torch.float if convert_to_float else None ) updated_eigenbasis_list.append(Q) @@ -131,8 +125,8 @@ def get_eigenbasis_qr( Computes using multiple rounds of power iteration followed by QR decomposition (orthogonal iteration). Args: - kronecker_factor_list: List containing preconditioner (GGT and GTG) - eigenbasis_list: List containing eigenbases (QL and QR) + kronecker_factor_list: List containing preconditioner (:math:`GG^T` and :math:`G^TG`) + eigenbasis_list: List containing eigenbases (:math:`Q_L` and :math:`Q_R`) exp_avg_sq: inner adam second moment (exp_avg_sq). This tensor is modified in-place. convert_to_float: If True, preconditioner matrices and their corresponding orthonormal matrices will be cast to float. Otherwise, they are left in @@ -206,21 +200,21 @@ def get_eigenbasis_qr( if kronecker_factor.numel() == 0: updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device)) continue - # construct approximated eigenvalues using QL^T@L@QL or QR^T@R@QR, which should be close to diagonal - # if the eigenbasis is close to the true eigenbasis of the kronecker factor (i.e. diagonalizes it) - approx_eigenvalue_matrix = eigenbasis.T @ kronecker_factor @ eigenbasis - # Update eigenbasis when necessary. Update is skipped only when use_adaptive_criteria is True - # but criteria is not met. + # Update eigenbasis when necessary. Update is skipped only when use_adaptive_criteria is True while + # criteria is not met. 
if_update = True - if use_adaptive_criteria and not _adaptive_criteria_met( - approx_eigenvalue_matrix=approx_eigenvalue_matrix, - tolerance=adaptive_update_tolerance, - ): - if_update = False + # construct approximated eigenvalues using Q_L^T L Q_L or Q_R^T R Q_R, which should be close to + # diagonal if the eigenbasis is close to the true eigenbasis of the kronecker factor (i.e. diagonalizes it) + approx_eigvals = eig_utils.conjugate(kronecker_factor, eigenbasis, diag=True) + if use_adaptive_criteria: + if_update = not eig_utils.met_approx_eigvals_criteria( + kronecker_factor, approx_eigvals, adaptive_update_tolerance + ) + if if_update: - Q, exp_avg_sq = _orthogonal_iteration( - approx_eigenvalue_matrix=approx_eigenvalue_matrix, + Q, exp_avg_sq = eig_utils.orthogonal_iteration( + approx_eigvals=approx_eigvals, kronecker_factor=kronecker_factor, eigenbasis=eigenbasis, ind=ind, @@ -234,85 +228,3 @@ def get_eigenbasis_qr( updated_eigenbasis_list.append(eigenbasis) return updated_eigenbasis_list, exp_avg_sq - - -def _orthogonal_iteration( - approx_eigenvalue_matrix: torch.Tensor, - kronecker_factor: torch.Tensor, - eigenbasis: torch.Tensor, - ind: int, - exp_avg_sq: torch.Tensor, - convert_to_float: bool = True, - power_iter_steps: int = 1, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes the eigenbases of the preconditioner using power iteration and QR decomposition. - - This function performs multiple rounds of power iteration followed by QR decomposition - to recompute the eigenbases of the preconditioner kronecker factor. Generalizes Vyas et al.'s (SOAP) algorithm of 1 step of power iteration for updating the eigenbasis. - - Args: - approx_eigenvalue_matrix : Projection of kronecker factor onto the eigenbasis, should be close to diagonal - kronecker_factor : Kronecker factor matrix. - eigenbasis : Kronecker factor eigenbasis matrix. - ind : Index for selecting dimension in the exp_avg_sq matrix to apply the sorting order over. - exp_avg_sq : inner Adam second moment (exp_avg_sq). - convert_to_float : If True, preconditioner matrices and their corresponding - orthonormal matrices will be cast to float. Otherwise, they are left in - their original type. Defaults to False. - power_iter_steps: Number of power iteration steps to perform before QR decomposition. - More steps can lead to better convergence but increased computation time. 
- - Returns: - tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Q: The updated eigenbasis - - exp_avg_sq: The updated (sorted) inner Adam second moment - """ - # extract approximated eigenvalues from the diagonal of the projection of kronecker factor onto eigenbases - approx_eigvals = torch.diag(approx_eigenvalue_matrix) - # Sort the approximated eigenvalues according to their magnitudes - sort_idx = torch.argsort(approx_eigvals, descending=True) - # re-order the inner adam second moment - exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx) - - # Initialize power iteration after sorting the columns of the eigenbasis matrix according to the descending eigenvalues - Q = eigenbasis[:, sort_idx] - - # By default, perform QR decomposition with power iteration with FP32 precision - # Perform multiple steps of power iteration - for _ in range(power_iter_steps): - # Project current eigenbases on kronecker factor - Q = kronecker_factor @ Q - # Perform QR to maintain orthogonality between iterations - Q = torch.linalg.qr(Q).Q - - # When not converting to float, ensure that Q is in the original dtype - if not convert_to_float: - Q = Q.to(kronecker_factor.dtype) - - return Q, exp_avg_sq - - -@torch.compile # type: ignore[misc] -def _adaptive_criteria_met( - approx_eigenvalue_matrix: torch.Tensor, - tolerance: Optional[float] = None, -) -> bool: - """Determines whether the eigenbasis for a factor matrix should be updated in the next step of the orthogonal iteration. - - Determines whether the eigenbasis for a factor matrix should be updated based on computing - the approximated eigenvalues Q^T GG Q, where Q is the approximated eigenvectors and - GG is the Kronecker factor. The approximated eigenvalues update criteria is then defined as - ||diag(Q^T GG Q)||_F >= (1 - tolerance) * (Q^T GG Q)_F. - - Args: - approx_eigenvalue_matrix: Projection of kronecker factor onto the eigenbasis, should be close to diagonal - tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix. - - Returns: - perform_update: Whether to update eigenbasis this iteration - """ - if tolerance is None: - return True - - # check if normalized diagonal component is not smaller than tolerance - return not utils.eig.adaptive_early_exit_criteria(approx_eigenvalue_matrix, tolerance) diff --git a/emerging_optimizers/utils/eig.py b/emerging_optimizers/utils/eig.py index b75c51b..9918df1 100644 --- a/emerging_optimizers/utils/eig.py +++ b/emerging_optimizers/utils/eig.py @@ -12,16 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch from absl import logging from torch import Tensor -from emerging_optimizers import utils - -__all__ = ["eigh_with_fallback", "eig_orthogonal_iteration", "adaptive_early_exit_criteria"] +__all__ = [ + "eigh_with_fallback", + "met_approx_eigvals_criteria", + "conjugate", + "orthogonal_iteration", +] def eigh_with_fallback( @@ -37,13 +40,15 @@ def eigh_with_fallback( Default 2nd argument of eigh UPLO is 'L'. Args: - x: Tensor of shape (*, n, n) where "*" is zero or more batch dimensions consisting of symmetric or Hermitian matrices. + x: Tensor of shape (*, n, n) where "*" is zero or more batch dimensions consisting of symmetric or + Hermitian matrices. force_double: Force double precision computation. Default False. 
- eps: Small offset for numerical stability. If None, uses dtype-appropriate values (1e-7 for float32, 1e-15 for float64). Default None. + eps: Small offset for numerical stability. If None, uses dtype-appropriate values (1e-7 for float32, + 1e-15 for float64). Default None. output_dtype: Desired output dtype. If None, uses input dtype. Default None. Returns: - tuple[Tensor, Tensor]: Eigenvalues and eigenvectors tuple (eigenvalues in descending order). + Eigenvalues and eigenvectors tuple (eigenvalues in descending order). """ input_dtype = x.dtype if output_dtype is None: @@ -100,25 +105,27 @@ def eig_orthogonal_iteration( max_iterations: int = 1, tolerance: float = 0.01, ) -> tuple[Tensor, Tensor]: - """Approximately compute the eigendecomposition of a symmetric matrix by performing the orthogonal iteration algorithm. + """Approximately compute the eigen decomposition + [DEPRECATED] Use `orthogonal_iteration` instead. - Orthogonal or subspace iteration uses iterative power iteration and QR decomposition to update the approximated eigenvectors. - When the initial estimate is the zero matrix, the eigendecomposition is computed using `eigh_with_fallback`. + Orthogonal or subspace iteration uses iterative power iteration and QR decomposition to update the approximated + eigenvectors. When the initial estimate is the zero matrix, the eigendecomposition is computed + using `eigh_with_fallback`. - Based on Purifying Shampoo (https://www.arxiv.org/abs/2506.03595), we use an early exit criteria to stop the QR iterations. - This generalizes SOAP's algorithm of 1 step of power iteration for updating the eigenbasis. + Based on Purifying Shampoo (https://www.arxiv.org/abs/2506.03595), we use an early exit criteria to stop the + QR iterations. This generalizes SOAP's algorithm of 1 step of power iteration for updating the eigenbasis. Args: x: tensor of shape (n, n) where x is a symmetric or Hermitian matrix. approx_eigenvectors: The current estimate of the eigenvectors of x. If None or a zero matrix, falls back to using `eigh_with_fallback`. - max_iterations: The maximum number of iterations to perform. (Default: 1) - tolerance: The tolerance for determining convergence in terms of the norm of the off-diagonal elements of the approximated eigenvalues. - (Default: 0.01) + max_iterations: The maximum number of iterations to perform. + tolerance: The tolerance for determining convergence in terms of the norm of the off-diagonal elements + of the approximated eigenvalues. Returns: - tuple[Tensor, Tensor]: A tuple containing the approximated eigenvalues and eigenvectors matrix of the input matrix A. + A tuple containing the approximated eigenvalues and eigenvectors matrix of the input matrix A. """ # Check if x is already a diagonal matrix @@ -130,44 +137,129 @@ def eig_orthogonal_iteration( return eigh_with_fallback(x, force_double=True) # Perform power iteration and QR decomposition iteratively. 
- with utils.fp32_matmul_precision("highest"): - Q = approx_eigenvectors - approx_eigenvalues_matrix = Q.T @ x @ Q - approx_eigenvalues = torch.diag(approx_eigenvalues_matrix) - iteration = 0 - while iteration < max_iterations and not adaptive_early_exit_criteria(approx_eigenvalues_matrix, tolerance): - power_iteration = x @ Q - Q = torch.linalg.qr(power_iteration).Q - approx_eigenvalues_matrix = Q.T @ x @ Q - iteration += 1 - # Sort eigenvalues in descending order and reorder eigenvectors accordingly - # Sorting can help mitigate numerical instability since QR decompositions can mix the approximated eigenvectors - approx_eigenvalues, indices = torch.diag(approx_eigenvalues_matrix).sort(stable=True, descending=True) - Q = Q[:, indices] - - return approx_eigenvalues, Q - - -def adaptive_early_exit_criteria(approx_eigenvalues_matrix: Tensor, tolerance: float) -> bool: - """Evaluates if a criteria using approximated eigenvalues is below or equal to the tolerance. - - `approx_eigenvalues_matrix` is a matrix created from the approximated eigenvectors and the symmetric matrix that is being eigendecomposed. - We check if the ratio of the diagonal norm to the matrix norm is greater than or equal to (1 - tolerance). + Q = approx_eigenvectors + approx_eigvals = conjugate(x, Q, diag=True) + iteration = 0 + while iteration < max_iterations and not met_approx_eigvals_criteria(x, approx_eigvals, tolerance): + power_iteration = x @ Q + Q = torch.linalg.qr(power_iteration).Q + approx_eigvals = conjugate(x, Q, diag=True) + iteration += 1 + # Sort eigenvalues in descending order and reorder eigenvectors accordingly + # Sorting can help mitigate numerical instability since QR decompositions can mix the approximated eigenvectors + sorted_approx_eigvals, indices = approx_eigvals.sort(stable=True, descending=True) + Q = Q[:, indices] + + return sorted_approx_eigvals, Q + + +def met_approx_eigvals_criteria( + kronecker_factor: torch.Tensor, + approx_eigvals: torch.Tensor, + tolerance: float, +) -> bool: + """Determines whether the eigenbasis for a factor matrix met the desired criteria + + The approximated eigenvalues update criteria is then defined as + :math:`||diag(Q^T K Q)||_F >= (1 - tolerance) * (Q^T K Q)_F`, where :math:`Q` is the approximated eigenvectors and + :math:`K` is the kronecker factor (L or R). + + We use the kronecker factor and approximated eigenvalues directly to save compute because Frobenius norm of + kronecker factor is the same as that of the approximated eigenvalues matrix. Args: - approx_eigenvalues_matrix: The symmetric matrix whose eigenvalues is being eigendecomposed. - tolerance: The tolerance for the early exit criteria, the min relative error between diagonal norm and matrix norm of the approximated eigenvalues and the diagonal. + kronecker_factor: Kronecker factor matrix. + approx_eigvals: Approximated eigenvalues + tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix. Returns: - bool: True if the criteria is below or equal to the tolerance, False otherwise. 
- + perform_update: Whether to update eigenbasis this iteration """ - matrix_norm = torch.linalg.norm(approx_eigenvalues_matrix) - approx_eigvals = torch.diag(approx_eigenvalues_matrix) + matrix_norm = torch.linalg.norm(kronecker_factor) diagonal_norm = torch.linalg.norm(approx_eigvals) + return diagonal_norm >= (1 - tolerance) * matrix_norm +def orthogonal_iteration( + approx_eigvals: torch.Tensor, + kronecker_factor: torch.Tensor, + eigenbasis: torch.Tensor, + ind: int, + exp_avg_sq: torch.Tensor, + convert_to_float: bool, + power_iter_steps: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes the eigenbases of the preconditioner using power iteration and QR decomposition. + + This function performs multiple rounds of power iteration followed by QR decomposition + to recompute the eigenbases of the preconditioner kronecker factor. Generalizes Vyas et al.'s (SOAP) algorithm of 1 step of power iteration for updating the eigenbasis. + + Args: + approx_eigenvalue_matrix : Projection of kronecker factor onto the eigenbasis, should be close to diagonal + kronecker_factor : Kronecker factor matrix. + eigenbasis : Kronecker factor eigenbasis matrix. + ind : Index for selecting dimension in the exp_avg_sq matrix to apply the sorting order over. + exp_avg_sq : inner Adam second moment (exp_avg_sq). + convert_to_float : If True, preconditioner matrices and their corresponding + orthonormal matrices will be cast to float. Otherwise, they are left in + their original type. Defaults to False. + power_iter_steps: Number of power iteration steps to perform before QR decomposition. + More steps can lead to better convergence but increased computation time. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Q: The updated eigenbasis + - exp_avg_sq: The updated (sorted) inner Adam second moment + """ + # Sort the approximated eigenvalues according to their magnitudes + sort_idx = torch.argsort(approx_eigvals, descending=True) + # re-order the inner adam second moment + exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx) + + # Initialize power iteration after sorting the columns of the eigenbasis matrix according to the descending eigenvalues + Q = eigenbasis[:, sort_idx] + + # By default, perform QR decomposition with power iteration with FP32 precision + # Perform multiple steps of power iteration + for _ in range(power_iter_steps): + # Project current eigenbases on kronecker factor + Q = kronecker_factor @ Q + # Perform QR to maintain orthogonality between iterations + Q = torch.linalg.qr(Q).Q + + # When not converting to float, ensure that Q is in the original dtype + if not convert_to_float: + Q = Q.to(kronecker_factor.dtype) + + return Q, exp_avg_sq + + +def conjugate(a: torch.Tensor, p: torch.Tensor, diag: bool = False) -> torch.Tensor: + """Calculate similarity transformation + + This function calculates :math:`B = P^T A P`. It assumes P is orthogonal so that :math:`P^{-1} = P^T` and + the similarity transformation exists. + + Args: + a: matrix to be transformed + p: An orthogonal matrix. + diag: If True, only return the diagonal of the similarity transformation + + Returns: + b + """ + if a.dim() != 2 or p.dim() != 2: + raise TypeError("a and p must be 2D matrices") + pta = p.T @ a + if not diag: + b = pta @ p + else: + # return the diagonal of the similarity transformation + b = (pta * p.T).sum(dim=1) + return b + + def _is_diagonal(x: Tensor) -> bool: r"""Checks if symmetric matrix is diagonal. 
Raises an error if the input is not a square matrix.""" @@ -189,7 +281,7 @@ def _try_handle_diagonal_matrix(x: Tensor) -> Optional[tuple[Tensor, Tensor]]: x: Tensor of shape (n, n) where x is a symmetric or Hermitian matrix. Returns: - Optional[tuple[Tensor, Tensor]]: Sorted eigenvalues and eigenvectors if A is diagonal, None otherwise. + Sorted eigenvalues and eigenvectors if A is diagonal, None otherwise. """ input_dtype = x.dtype if _is_diagonal(x): diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh index 53b1fd1..f775f55 100644 --- a/tests/ci/L0_Tests_GPU.sh +++ b/tests/ci/L0_Tests_GPU.sh @@ -21,3 +21,4 @@ coverage run -p --source=emerging_optimizers tests/test_soap_utils.py coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py +coverage run -p --source=emerging_optimizers tests/test_spectral_clip_utils.py \ No newline at end of file diff --git a/tests/ci/L1_Tests_GPU.sh b/tests/ci/L1_Tests_GPU.sh index 0e7570a..681ea79 100644 --- a/tests/ci/L1_Tests_GPU.sh +++ b/tests/ci/L1_Tests_GPU.sh @@ -19,3 +19,4 @@ python tests/test_soap_functions.py python tests/test_soap_utils.py python tests/soap_smoke_test.py python tests/test_scalar_optimizers.py +python tests/test_spectral_clip_utils.py diff --git a/tests/test_soap_functions.py b/tests/test_soap_functions.py index 937d0f0..81016c7 100644 --- a/tests/test_soap_functions.py +++ b/tests/test_soap_functions.py @@ -35,28 +35,12 @@ def test_init_preconditioner_multidim_tensor_shapes(self) -> None: """Tests init_preconditioner with a multi-dimensional tensor.""" grad = torch.randn(3, 4, 5) state: dict[str, Any] = {} - # No merge_dims: each dimension gets its own preconditioner unless dimension > max_precond_dim. 
- state["GG"] = init_kronecker_factors(grad, precondition_1d=False, max_precond_dim=8192) + state["GG"] = init_kronecker_factors(grad, precondition_1d=False) self.assertEqual(len(state["GG"]), grad.dim()) self.assertEqual(state["GG"][0].shape, (3, 3)) self.assertEqual(state["GG"][1].shape, (4, 4)) self.assertEqual(state["GG"][2].shape, (5, 5)) - def test_init_kronecker_factors_max_precond_dim(self) -> None: - """Tests init_kronecker_factors respects max_precond_dim.""" - max_dim = 8 - grad = torch.randn(3, max_dim + 2, 5) # Second dimension exceeds max_dim - kronecker_factors = init_kronecker_factors(grad, precondition_1d=False, max_precond_dim=max_dim) - - self.assertEqual(len(kronecker_factors), grad.dim()) - # Dimension 0 (size 3) <= max_dim - self.assertEqual(kronecker_factors[0].shape, (3, 3)) - # Dimension 1 (size max_dim + 2) > max_dim -> Should be empty - self.assertEqual(kronecker_factors[1].shape, (0,)) - self.assertEqual(kronecker_factors[1].numel(), 0) - # Dimension 2 (size 5) <= max_dim - self.assertEqual(kronecker_factors[2].shape, (5, 5)) - @parameterized.parameters( (1,), (2,), @@ -97,14 +81,13 @@ def test_adam_warmup_steps(self, adam_warmup_steps: int) -> None: self.assertEqual(state["Q"][1].shape, (3, 3)) def test_update_kronecker_factors(self) -> None: - """Tests update_kronecker_factors, including max_precond_dim handling.""" max_dim = 8 shampoo_beta = 0.9 dim0, dim1, dim2 = 3, max_dim + 2, 5 grad = torch.randn(dim0, dim1, dim2) # Initialize factors - initial_factors = init_kronecker_factors(grad, precondition_1d=False, max_precond_dim=max_dim) + initial_factors = init_kronecker_factors(grad, precondition_1d=False) kronecker_factors = [f.clone() for f in initial_factors] @@ -113,25 +96,15 @@ def test_update_kronecker_factors(self) -> None: grad=grad, shampoo_beta=shampoo_beta, precondition_1d=False, - max_precond_dim=max_dim, ) self.assertEqual(len(kronecker_factors), grad.dim()) - # Dimension 0 (size 3) <= max_dim: Should be updated contract_dims_0 = [1, 2] outer_product_0 = torch.tensordot(grad, grad, dims=[contract_dims_0] * 2) expected_factor_0 = initial_factors[0] * shampoo_beta + outer_product_0 * (1 - shampoo_beta) torch.testing.assert_close(kronecker_factors[0], expected_factor_0, atol=1e-6, rtol=1e-6) - # Dimension 1 (size 10) > max_dim: Should NOT be updated (still empty) - self.assertEqual(kronecker_factors[1].shape, (0,)) - self.assertEqual(kronecker_factors[1].numel(), 0) - - # Check it's the same object or has same properties as initial empty tensor - self.assertTrue(torch.equal(kronecker_factors[1], initial_factors[1])) - - # Dimension 2 (size 5) <= max_dim: Should be updated contract_dims_2 = [0, 1] outer_product_2 = torch.tensordot(grad, grad, dims=[contract_dims_2] * 2) expected_factor_2 = initial_factors[2] * shampoo_beta + outer_product_2 * (1 - shampoo_beta) @@ -176,7 +149,7 @@ def test_tensordot_vs_matmul(self, m, n): {"N": 32, "M": 8}, ) def test_project_and_project_back(self, N: int, M: int) -> None: - """Tests that projecting a tensor to eigenbasis of QL and QR and then projecting it back results in the original tensor. + """Tests that projecting a tensor to eigenbasis of QL and QR and back The projected tensor should approximately recover the original tensor. 
""" diff --git a/tests/test_soap_utils.py b/tests/test_soap_utils.py index 32e2211..2200732 100644 --- a/tests/test_soap_utils.py +++ b/tests/test_soap_utils.py @@ -16,12 +16,8 @@ from absl.testing import absltest, parameterized from emerging_optimizers import utils -from emerging_optimizers.soap.soap_utils import ( - _adaptive_criteria_met, - _orthogonal_iteration, - get_eigenbasis_eigh, - get_eigenbasis_qr, -) +from emerging_optimizers.soap import soap_utils +from emerging_optimizers.utils import eig as eig_utils # Base class for tests requiring seeding for determinism @@ -43,9 +39,10 @@ def test_adaptive_criteria_met(self) -> None: diagonal_matrix = torch.eye(n) # Test with small tolerance - should not update since matrix is diagonal - self.assertFalse( - _adaptive_criteria_met( - approx_eigenvalue_matrix=diagonal_matrix, + self.assertTrue( + eig_utils.met_approx_eigvals_criteria( + diagonal_matrix, + diagonal_matrix.diag(), tolerance=0.1, ), msg="Should not update for diagonal matrix with small tolerance", @@ -62,18 +59,20 @@ def test_adaptive_criteria_met(self) -> None: ) # Test with small tolerance - should update since matrix has significant off-diagonal elements - self.assertTrue( - _adaptive_criteria_met( - approx_eigenvalue_matrix=off_diagonal_matrix, + self.assertFalse( + eig_utils.met_approx_eigvals_criteria( + off_diagonal_matrix, + off_diagonal_matrix.diag(), tolerance=0.1, ), msg="Should update for matrix with significant off-diagonal elements and small tolerance", ) # Test with large tolerance - should not update even with off-diagonal elements - self.assertFalse( - _adaptive_criteria_met( - approx_eigenvalue_matrix=off_diagonal_matrix, + self.assertTrue( + eig_utils.met_approx_eigvals_criteria( + off_diagonal_matrix, + off_diagonal_matrix.diag(), tolerance=10.0, ), msg="Should not update for any matrix with large tolerance", @@ -102,7 +101,7 @@ def test_get_eigenbasis_qr(self, N: int, M: int) -> None: } # We'll call get_eigenbasis_qr - Q_new_list, exp_avg_sq_new = get_eigenbasis_qr( + Q_new_list, exp_avg_sq_new = soap_utils.get_eigenbasis_qr( kronecker_factor_list=state["GG"], eigenbasis_list=state["Q"], exp_avg_sq=state["exp_avg_sq"], @@ -171,11 +170,11 @@ def test_update_eigenbasis_with_QR(self, N: int, power_iter_steps: int) -> None: # Create estimated eigenvalue matrix by projecting kronecker_factor onto eigenbasis's basis approx_eigenvalue_matrix = eigenbasis.T.mm(kronecker_factor).mm(eigenbasis) # Extract eigenvalues from the diagonal of the estimated eigenvalue matrix - est_eigvals = torch.diag(approx_eigenvalue_matrix) + approx_eigvals = torch.diag(approx_eigenvalue_matrix) # Call the QR function to update the eigenbases and re-order the inner adam second moment - Q_new, exp_avg_sq_new = _orthogonal_iteration( - approx_eigenvalue_matrix=approx_eigenvalue_matrix, + Q_new, exp_avg_sq_new = eig_utils.orthogonal_iteration( + approx_eigvals=approx_eigvals, kronecker_factor=kronecker_factor, eigenbasis=eigenbasis, ind=0, # Test with first dimension @@ -200,7 +199,7 @@ def test_update_eigenbasis_with_QR(self, N: int, power_iter_steps: int) -> None: # Test 3: Check that exp_avg_sq is properly sorted based on eigenvalues # The sorting should be based on the diagonal elements of estimated_eigenvalue_matrix - sort_idx = torch.argsort(est_eigvals, descending=True) + sort_idx = torch.argsort(approx_eigvals, descending=True) expected_exp_avg_sq = exp_avg_sq.index_select(0, sort_idx) torch.testing.assert_close( exp_avg_sq_new, @@ -230,7 +229,7 @@ def 
test_get_eigenbasis_eigh(self, dims: list[int]) -> None: k_factor = k_factor @ k_factor.T + torch.eye(dim, device="cuda") * 1e-5 kronecker_factor_list.append(k_factor) - Q_list = get_eigenbasis_eigh(kronecker_factor_list, convert_to_float=True) + Q_list = soap_utils.get_eigenbasis_eigh(kronecker_factor_list, convert_to_float=True) self.assertEqual(len(Q_list), len(kronecker_factor_list)) @@ -265,6 +264,20 @@ def test_get_eigenbasis_eigh(self, dims: list[int]) -> None: msg=f"Matrix {i} was not properly diagonalized. Off-diagonal norm: {off_diagonal_norm}", ) + def test_conjugate_assert_2d_input(self) -> None: + """Tests the conjugate function.""" + a = torch.randn(2, 3, 4, device="cuda") + with self.assertRaises(TypeError): + eig_utils.conjugate(a, a) + + def test_conjugate_match_reference(self) -> None: + x = torch.randn(15, 17, device="cuda") + a = x @ x.T + _, p = torch.linalg.eigh(a) + + ref = p.T @ a @ p + torch.testing.assert_close(eig_utils.conjugate(a, p), ref, atol=0, rtol=0) + if __name__ == "__main__": absltest.main() diff --git a/tests/test_spectral_clip_utils.py b/tests/test_spectral_clip_utils.py new file mode 100644 index 0000000..2be4557 --- /dev/null +++ b/tests/test_spectral_clip_utils.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from absl import logging +from absl.testing import absltest, parameterized + +import emerging_optimizers.orthogonalized_optimizers as orthogonalized_optimizers + + +class TestSpectralClipping(parameterized.TestCase): + def setUp(self): + self.prev_precision = torch.get_float32_matmul_precision() + torch.set_float32_matmul_precision("highest") + self.device = "cuda" if torch.cuda.is_available() else "cpu" + logging.info(f"Using device: {self.device}") + torch.manual_seed(1234) + + def tearDown(self): + torch.set_float32_matmul_precision(self.prev_precision) + + @parameterized.product( + dims=[(256, 128), (128, 256), (512, 512), (2048, 2048)], + sigma_range=[(0.2, 0.8), (0.1, 20)], + ) + def test_spectral_clipping(self, dims, sigma_range): + """Test that spectral clipping properly clips singular values to the specified range.""" + + sigma_min, sigma_max = sigma_range + x = torch.randn(dims, device=self.device, dtype=torch.float32) + + _, original_singular_values, _ = torch.linalg.svd(x, full_matrices=False) + original_min_sv = original_singular_values.min().item() + original_max_sv = original_singular_values.max().item() + + clipped_x = orthogonalized_optimizers.spectral_clip(x, sigma_min=sigma_min, sigma_max=sigma_max) + + _, singular_values, _ = torch.linalg.svd(clipped_x, full_matrices=False) + + min_sv = singular_values.min().item() + max_sv = singular_values.max().item() + + logging.debug(f"Original matrix shape: {x.shape}") + logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]") + logging.debug(f"Clipped singular values range: [{min_sv:.6f}, {max_sv:.6f}]") + logging.debug(f"Target range: [{sigma_min:.6f}, {sigma_max:.6f}]") + logging.debug(f"Shape preservation: input {x.shape} -> output {clipped_x.shape}") + + # use higher tolerance for lower singular values + # typically, this algorithm introduces more error for lower singular values + tolerance_upper = 1e-1 + tolerance_lower = 5e-1 + self.assertGreaterEqual( + min_sv + tolerance_lower, + sigma_min, + ) + self.assertLessEqual( + max_sv - tolerance_upper, + sigma_max, + ) + + self.assertEqual(clipped_x.shape, x.shape) + + @parameterized.product( + dims=[(256, 128), (128, 256), (512, 512), (100, 200)], + beta=[0.5, 1.0, 0.8, 2.0], + ) + def test_spectral_hardcap(self, dims, beta): + """Test that spectral hardcap properly clips singular values from above to be less than beta.""" + x = torch.randn(dims, device=self.device, dtype=torch.float32) + + U_orig, original_singular_values, Vt_orig = torch.linalg.svd(x, full_matrices=False) + original_min_sv = original_singular_values.min().item() + original_max_sv = original_singular_values.max().item() + logging.debug(f"Original matrix shape: {x.shape}") + logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]") + + hardcapped_x = orthogonalized_optimizers.spectral_hardcap(x, beta=beta) + + U_hard, singular_values, Vt_hard = torch.linalg.svd(hardcapped_x, full_matrices=False) + + tolerance_upper = 1e-1 + + max_sv = singular_values.max().item() + + logging.debug(f"Hardcapped max singular value: {max_sv:.6f}") + logging.debug(f"Beta (upper bound): {beta:.6f}") + logging.debug(f"Shape preservation: input {x.shape} -> output {hardcapped_x.shape}") + + self.assertLessEqual( + max_sv - tolerance_upper, + beta, + ) + + self.assertEqual(hardcapped_x.shape, x.shape) + + # Test that singular vectors are preserved (polar factor UV^T should be similar) + polar_orig = U_orig @ Vt_orig + polar_hard = U_hard @ 
Vt_hard + + # The polar factors should be very similar since hardcap only changes the singular values; compute their relative difference + relative_polar_frobenius_diff = torch.norm(polar_orig - polar_hard, "fro") / torch.norm(polar_orig, "fro") + polar_tolerance = 1e-4 + + logging.debug(f"Polar factor Frobenius norm difference: {relative_polar_frobenius_diff:.6f}") + + self.assertLessEqual( + relative_polar_frobenius_diff, + polar_tolerance, + ) + + +if __name__ == "__main__": + absltest.main()
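
Reviewer note: a minimal usage sketch of the new spectral clipping utilities exported from `emerging_optimizers.orthogonalized_optimizers` in this patch. The matrix shapes, bounds, and the SVD check below are illustrative only (they mirror the new tests rather than any documented guarantee), and the sketch assumes a standard PyTorch install.

```python
import torch

# Exported via spectral_clipping_utils.__all__ and the package __init__ in this PR.
from emerging_optimizers.orthogonalized_optimizers import spectral_clip, spectral_hardcap

# Illustrative 2D weight/update matrix; any (m, n) shape works, tall matrices are transposed internally.
W = torch.randn(256, 128)

# Clip singular values into [0.2, 0.8] using the matrix-sign (Newton-Schulz) construction.
W_clipped = spectral_clip(W, sigma_min=0.2, sigma_max=0.8)

# Cap singular values from above at 1.0 (spectral hardcap).
W_capped = spectral_hardcap(W, beta=1.0)

# The resulting spectra should approximately respect the requested bounds.
print(torch.linalg.svdvals(W_clipped).max(), torch.linalg.svdvals(W_capped).max())
```

`spectral_hardcap` is also what `spectral_clipped_weight_decay` applies internally before blending with the original weights.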
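
Likewise, a small sketch of how the eigenbasis helpers relocated into `emerging_optimizers.utils.eig` compose (`conjugate` and `met_approx_eigvals_criteria` replace the former private SOAP helpers, and `soap_utils` now calls them this way). The matrix size and tolerance are illustrative assumptions.

```python
import torch

from emerging_optimizers.utils import eig as eig_utils

# Symmetric PSD stand-in for a Kronecker factor.
G = torch.randn(64, 32)
K = G @ G.T

# Eigendecomposition with the numerically guarded fallback (eigenvalues in descending order).
eigvals, Q = eig_utils.eigh_with_fallback(K, force_double=False)

# Approximate eigenvalues via diag(Q^T K Q) without forming the full similarity transform.
approx_eigvals = eig_utils.conjugate(K, Q, diag=True)

# With an (almost) exact eigenbasis the adaptive criteria is met, so the eigenbasis refresh can be skipped.
print(eig_utils.met_approx_eigvals_criteria(K, approx_eigvals, tolerance=0.01))
```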