diff --git a/emerging_optimizers/psgd/procrustes_step.py b/emerging_optimizers/psgd/procrustes_step.py new file mode 100644 index 0000000..e35b1f1 --- /dev/null +++ b/emerging_optimizers/psgd/procrustes_step.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import emerging_optimizers.utils as utils +from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_skew + + +__all__ = [ + "procrustes_step", +] + + +@torch.compile # type: ignore[misc] +def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor: + r"""One step of an online solver for the orthogonal Procrustes problem. + + The orthogonal Procrustes problem is :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I`. + Each step refines the solution by rotating Q as :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`. + + `max_step_size` should be less than :math:`1/4` since we only expand :math:`\exp(a R)` to its 2nd order term. + + This method uses a second order expansion of a Lie algebra parametrized rotation together with + a simple approximate line search for the step size, following Xi-Lin Li. + + Args: + Q: Tensor of shape (n, n), general square matrix to orthogonalize. + max_step_size: Maximum step size for the line search. Default is 1/8 (0.125). + eps: Small number for numerical stability. + """ + # Note: this function is written in fp32 to avoid numerical instability while computing the Taylor expansion of the exponential map + with utils.fp32_matmul_precision("highest"): + R = Q.T - Q + R /= torch.clamp(norm_lower_bound_skew(R), min=eps) + RQ = R @ Q + # trace of RQ is always non-negative, + # since tr(RQ) = tr((Q^T - Q) Q) = tr(Q^T Q) - tr(Q Q) = ||Q||_F^2 - ⟨Q^T, Q⟩_F ≥ 0 by Cauchy-Schwarz + tr_RQ = torch.trace(RQ) + RRQ = R @ RQ + tr_RRQ = torch.trace(RRQ) + # clip step size to max_step_size, based on a 2nd order expansion. + _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size) + # If tr_RRQ >= 0, the quadratic approximation is not concave, so we fall back to max_step_size. + step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size) + # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2 / 2) Q with the step size found by the line search, + # expanding exp(a R) only to its 2nd order term. + # Q += step_size * (RQ + 0.5 * step_size * RRQ) + Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size) + + return Q diff --git a/emerging_optimizers/psgd/psgd_kron_contractions.py b/emerging_optimizers/psgd/psgd_kron_contractions.py new file mode 100644 index 0000000..b223eec --- /dev/null +++ b/emerging_optimizers/psgd/psgd_kron_contractions.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +import torch + + +__all__ = [ + "partial_contraction", + "apply_kronecker_factors", + "apply_preconditioner", +] + + +@torch.compile # type: ignore[misc] +def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.Tensor: + """Compute the partial contraction of G1 and G2 along axis `axis`. + This is the contraction of the two tensors, but with all axes except `axis` contracted. + + Args: + G1: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N) + G2: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N) + axis: int, the axis to contract along + + Returns: + Tensor of shape (d_{axis}, d_{axis}) + """ + # dims_to_contract = all dims except `axis` + dims = list(range(G1.dim())) + dims.pop(axis) + # contraction is symmetric and has shape (d_{axis}, d_{axis}) + return torch.tensordot(G1, G2, dims=(dims, dims)) + + +@torch.compile # type: ignore[misc] +def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor: + """Apply all Kronecker factors once to tensor :math:`X`, each to its corresponding dimension. + + This applies each :math:`Q` factor once, for example in 2D case: :math:`Q_1 X Q_2^T`. + + Args: + Q_list: List of :math:`Q` (the upper-triangular Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`. + X: Tensor of shape `(d_0, d_1, ..., d_N)`. + + Returns: + Tensor of shape `(d_0, d_1, ..., d_N)`. + """ + if len(Q_list) != X.dim(): + raise ValueError( + f"Number of Kronecker factors {len(Q_list)} must match the number of dimensions of X {X.dim()}" + ) + + Y = X + for i in range(len(Q_list)): + Y = _apply_single_kronecker_factor(Q_list, Y, i) + return Y + + +@torch.compile # type: ignore[misc] +def apply_preconditioner(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor: + """Apply the full PSGD preconditioner to X. + + This is the full Kronecker product of PSGD's kronecker factors Q^T Q, applied to X. + + :math:`P X = (Q_1^T Q_1) X (Q_2^T Q_2)` + + This applies each factor followed by its transpose for the full preconditioner effect. + + Args: + Q_list: List of :math:`Q` (the Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`. + X: Tensor of shape `(d_0, d_1, ..., d_N)`. + + Returns: + Tensor of shape `(d_0, d_1, ..., d_N)`. + """ + # Apply Q first, then Q.T to get Q^T @ Q + Px = apply_kronecker_factors(Q_list, X) + Px = apply_kronecker_factors([q if q.dim() == 1 else q.T for q in Q_list], Px) + return Px + + +def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int) -> torch.Tensor: + """Multiply tensor X along axis `contract_dim` by 2D matrix M. + + Helper function for `_apply_single_kronecker_factor`. + If M is (d_out, d_in) we contract M’s second index with X’s `contract_dim` index. + `torch.tensordot` is used to contract the two tensors, and then the result is permuted to move the new axis 0 to position `contract_dim`. 
+ Returns a new tensor of the same rank, but with size[contract_dim] replaced by d_out. + Note that d_{contract_dim} == d_in. + + Args: + X: Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_{contract_dim}, d_{contract_dim+1}, ..., d_N) + M: Tensor of shape (d_out, d_in) + contract_dim: int, the dimension to contract with M, with d_{contract_dim} == d_in + + Returns: + Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_out, d_{contract_dim+1}, ..., d_N) + + Examples + -------- + >>> X = torch.randn(2, 3, 6) + >>> M = torch.randn(5, 6) + >>> contract_dim = 2 + >>> result = _dim_n_mul_and_permute(X, M, contract_dim) + >>> print(result.shape) + torch.Size([2, 3, 5]) + + """ + if X.shape[contract_dim] != M.shape[1]: + raise ValueError( + f"Shape mismatch: X.shape[{contract_dim}] = {X.shape[contract_dim]}, M.shape[1] = {M.shape[1]}" + ) + # Contract M's 2nd dim (idx=1) with X's `contract_dim` dim + Y = torch.tensordot(M, X, dims=([1], [contract_dim])) + # Y now has shape (d_out, d_0, …, d_{contract_dim-1}, d_{contract_dim+1}, …). + # We want to move that new axis 0 back to position `contract_dim`, due to `torch.tensordot`. + nd = X.dim() + perm = list(range(1, contract_dim + 1)) + [0] + list(range(contract_dim + 1, nd)) + return Y.permute(perm) + + +@torch.compile # type: ignore[misc] +def _apply_single_kronecker_factor(Q_list: List[torch.Tensor], X: torch.Tensor, axis: int) -> torch.Tensor: + """Apply a single Kronecker factor Q to X at dimension `axis`. Helper function for apply_kronecker_factors. + + If Q is a vector, we multiply X by Q. + If Q is a matrix, we contract Q's second index with X's `axis` index. + + Args: + Q_list: List of Q (e.g. the Kronecker factors). + X: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis+1}, ..., d_N) + """ + Q = Q_list[axis] + if Q.dim() == 1: + shape = [1] * X.dim() + shape[axis] = Q.size(0) + return X * Q.view(shape) + + return _dim_n_mul_and_permute(X, Q, contract_dim=axis) diff --git a/emerging_optimizers/psgd/psgd_utils.py b/emerging_optimizers/psgd/psgd_utils.py new file mode 100644 index 0000000..535270a --- /dev/null +++ b/emerging_optimizers/psgd/psgd_utils.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +import torch + + +__all__ = [ + "uniformize_q_in_place", + "norm_lower_bound_spd", + "norm_lower_bound_skew", +] + + +@torch.compile # type: ignore[misc] +def uniformize_q_in_place(Q_list: List[torch.Tensor]) -> None: + """Balance the dynamic ranges of kronecker factors in place to prevent numerical underflow or overflow. + + Each tensor in `Q_list` is rescaled so that its maximum absolute entry + becomes the geometric mean of all factors original maxima. This preserves + the overall product of norms (and thus the scale of the Kronecker product) + while avoiding numerical underflow or overflow when factors have widely + differing magnitudes. 
+ + Given tensors :math:`Q_1, Q_2, \\ldots, Q_n`: + + 1. Compute max-absolute norms: :math:`\\|Q_i\\|_\\infty = \\max(|Q_i|)` for :math:`i = 1, \\ldots, n` + 2. Compute geometric mean: :math:`g = \\left(\\prod_{i=1}^{n} \\|Q_i\\|_\\infty \\right)^{1/n}` + 3. Rescale each tensor: :math:`Q_i \\leftarrow Q_i \\cdot \\frac{g}{\\|Q_i\\|_\\infty}` + + This ensures :math:`\\|Q_i\\|_\\infty = g` for all :math:`i`, while preserving the norm of + the Kronecker product :math:`Q_1 \\otimes Q_2 \\otimes \\cdots \\otimes Q_n`. + + Args: + Q_list: List of Q (e.g. the Kronecker factors); each tensor is modified in place. + + Returns: + None + + """ + if not Q_list: + raise TypeError("Q_list cannot be empty.") + + order = len(Q_list) + if order == 1: + # with a single factor, no balancing is needed + return + + # Compute max-abs norm of each factor + norms = [torch.max(torch.abs(Q)) for Q in Q_list] + + # Compute geometric mean of those norms + gmean = torch.prod(torch.stack(norms)) ** (1.0 / order) + + # Rescale each factor so its max-abs entry equals the geometric mean + for Q, norm in zip(Q_list, norms, strict=True): + Q.mul_(gmean / norm) + + +@torch.compile # type: ignore[misc] +def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor: + r"""A cheap lower bound for the spectral norm of a symmetric positive definite matrix. + + Args: + A: Tensor of shape :math:`(n, n)`, symmetric positive definite. + k: Dimension of the subspace. + half_iters: Half of the number of subspace iterations. + eps: Small number for numerical stability. + + Returns: + A scalar giving a lower bound on :math:`\\|A\\|_2`. + """ + + # Compute scaling factor from the largest diagonal entry to prevent overflow/underflow + scale = torch.clamp(A.diagonal().amax(), min=eps) + A = A / scale + + bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps) + + return scale * bound_unnormalized + + +@torch.compile # type: ignore[misc] +def norm_lower_bound_skew(A: torch.Tensor, k: int = 32, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor: + """A cheap lower bound on the spectral norm (largest singular value) of a skew-symmetric matrix. + + Note: For skew-symmetric matrices, all diagonal entries are zero, :math:`A^T = -A`, and the eigenvalues are purely imaginary. + From Xi-Lin Li. + + Args: + A: Tensor of shape :math:`(n, n)`, skew-symmetric. + k: Dimension of the subspace. Suggested values: 128 for bfloat16, 32 for float32, 4 for float64. + half_iters: Half of the number of subspace iterations. + eps: Small number for numerical stability. + + Returns: + A scalar Tensor giving a lower bound on :math:`\\|A\\|_2`. + """ + + # Compute scaling factor from the max absolute value to prevent overflow/underflow + scale = torch.clamp(A.abs().amax(), min=eps) + A = A / scale + + bound_unnormalized = _subspace_iteration_bound(A, k=k, half_iters=half_iters, eps=eps) + + return scale * bound_unnormalized + + +@torch.compile # type: ignore[misc] +def _subspace_iteration_bound( + A: torch.Tensor, + k: int = 32, + half_iters: int = 2, + eps: float = 1e-8, +) -> torch.Tensor: + """A helper function for subspace iteration to estimate spectral norm bounds. + + Uses numerically stable subspace iteration with a random initialization that aligns with the + largest row of A to approximate the dominant eigenspace. This is more robust than simple + power iteration, especially for large matrices with very low rank. From Xi-Lin Li. + + The algorithm: + 1. 
Normalize :math:`A` by its largest absolute entry to avoid overflow. + 2. Find the row :math:`j` of :math:`A_{\\text{scaled}}` with the largest 2-norm. + 3. Initialize a :math:`k \\times n` subspace matrix :math:`V` with random vectors aligned to :math:`A[j]`. + 4. Perform subspace iteration for `half_iters` steps: :math:`V \\leftarrow V \\cdot A_{\\text{scaled}}`. + 5. Estimate the norm as the maximum 2-norm among the k vectors, then rescale. + + Args: + A: Input matrix, already normalized by caller. + k: Dimension of the subspace (number of random vectors). + half_iters: Number of half-iterations (each applies A twice). + eps: Smallest number for numerical stability. + + Returns: + Maximum vector norm from the final subspace iteration (unnormalized). + """ + + # Initialize random subspace matrix V of shape (k, n) + V = torch.randn(k, A.shape[1], dtype=A.dtype, device=A.device) + + # Find the row index with the largest 2-norm to initialize our subspace + # This helps the algorithm converge faster to the dominant eigenspace + dominant_row_idx = torch.argmax(torch.linalg.vector_norm(A, dim=1)) + # Rotate the random vectors to align with the dominant row A[dominant_row_idx] + # This initialization trick makes the subspace iteration more robust for low-rank matrices + dominant_row = A[dominant_row_idx] + alignment = torch.sign(torch.sum(dominant_row * V, dim=1, keepdim=True)) + + V = dominant_row + alignment * V + + # Perform subspace iteration + for _ in range(half_iters): + V = V @ A + # Normalize each row of V to prevent exponential growth/decay + V /= torch.linalg.vector_norm(V, dim=1, keepdim=True) + eps + # Apply A again (V approximates the dominant eigenspace of A^2) + V = V @ A + + # Return the maximum 2-norm among the k vectors + return torch.amax(torch.linalg.vector_norm(V, dim=1)) diff --git a/tests/ci/L0_Tests_CPU.sh b/tests/ci/L0_Tests_CPU.sh index 594773f..c4dde44 100644 --- a/tests/ci/L0_Tests_CPU.sh +++ b/tests/ci/L0_Tests_CPU.sh @@ -16,4 +16,4 @@ set -o pipefail torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu - +coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh index 785eddd..0d06ccd 100644 --- a/tests/ci/L0_Tests_GPU.sh +++ b/tests/ci/L0_Tests_GPU.sh @@ -23,4 +23,6 @@ coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py --device=cuda -coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda \ No newline at end of file +coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda +coverage run -p --source=emerging_optimizers tests/test_psgd_contractions.py --device=cuda +coverage run -p --source=emerging_optimizers tests/test_psgd_utils.py --device=cuda \ No newline at end of file diff --git a/tests/test_procrustes_step.py b/tests/test_procrustes_step.py new file mode 100644 index 0000000..8593c11 --- /dev/null +++ b/tests/test_procrustes_step.py @@ -0,0 +1,161 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +from absl import flags, testing +from absl.testing import parameterized + +from emerging_optimizers.psgd.procrustes_step import procrustes_step +from emerging_optimizers.utils import fp32_matmul_precision + + +# Define command line flags +flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") + +FLAGS = flags.FLAGS + + +class ProcrustesStepTest(parameterized.TestCase): + """Test cases for procrustes_step function.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.device = FLAGS.device + + def _procrustes_objective(self, Q: torch.Tensor) -> torch.Tensor: + """Helper to compute the orthogonality residual ||Q^H Q - I||_F^2 used as the test objective.""" + return torch.linalg.matrix_norm(Q.H @ Q - torch.eye(Q.size(0), dtype=Q.dtype, device=Q.device), ord="fro") ** 2 + + def test_improves_orthogonality_simple_case(self) -> None: + """Test that procrustes_step doesn't worsen orthogonality for a simple case.""" + + # Make an SPD non-orthogonal matrix + Q = torch.randn(2, 2, device=self.device, dtype=torch.float32) + Q = Q @ Q.T + + initial_obj = self._procrustes_objective(Q) + + Q = procrustes_step(Q, max_step_size=1 / 16) + + final_obj = self._procrustes_objective(Q) + + self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-6) + + @parameterized.parameters( + (8,), + (128,), + (1024,), + ) + def test_minimal_change_when_already_orthogonal(self, size: int) -> None: + """Test that procrustes_step makes minimal changes to an already orthogonal matrix.""" + # Create an orthogonal matrix using QR decomposition + A = torch.randn(size, size, device=self.device, dtype=torch.float32) + with fp32_matmul_precision("highest"): + Q, _ = torch.linalg.qr(A) + + initial_obj = self._procrustes_objective(Q) + + Q = procrustes_step(Q, max_step_size=1 / 16) + + final_obj = self._procrustes_objective(Q) + + # For already orthogonal matrices, the objective should remain small + self.assertLess(final_obj.item(), 1e-5) + self.assertLess(final_obj.item(), initial_obj.item() + 1e-5) + + def test_handles_small_norm_gracefully(self) -> None: + """Test that procrustes_step gracefully handles matrices whose rotation generator R has a very small norm.""" + # Create a matrix very close to orthogonal + A = torch.randn(3, 3, device=self.device, dtype=torch.float32) + with fp32_matmul_precision("highest"): + Q, _ = torch.linalg.qr(A) + # Add tiny perturbation + Q += 1e-10 * torch.randn_like(Q, dtype=torch.float32) + + initial_obj = self._procrustes_objective(Q) + + Q = procrustes_step(Q, max_step_size=0.0625) + + final_obj = self._procrustes_objective(Q) + + self.assertLess(final_obj.item(), 1e-6) + self.assertLess(final_obj.item(), initial_obj.item() + 1e-6) + + @parameterized.parameters( + (0.015625,), + (0.03125,), + (0.0625,), + (0.125,), + ) + def test_different_step_sizes_reduces_objective(self, max_step_size: float) -> 
None: + """Test procrustes_step improvement with different step sizes.""" + perturbation = 1e-1 * torch.randn(10, 10, device=self.device, dtype=torch.float32) / math.sqrt(10) + Q = torch.linalg.qr(torch.randn(10, 10, device=self.device, dtype=torch.float32)).Q + perturbation + initial_obj = self._procrustes_objective(Q) + + Q = procrustes_step(Q, max_step_size=max_step_size) + + final_obj = self._procrustes_objective(Q) + + self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-4) + + @parameterized.parameters( + (8,), + (64,), + (512,), + (8192,), + ) + def test_different_matrix_sizes_reduces_objective(self, size: int) -> None: + """Test procrustes_step improvement with different matrix sizes.""" + # Create a non-orthogonal matrix by scaling an orthogonal one + A = torch.randn(size, size, device=self.device, dtype=torch.float32) + with fp32_matmul_precision("highest"): + Q_orth, _ = torch.linalg.qr(A) + # Add perturbation, we choose 1e-2 to be small enough to not affect the objective too much + # but large enough to make the matrix non-orthogonal. + Q = Q_orth + 1e-2 * torch.randn(size, size, device=self.device, dtype=torch.float32) / math.sqrt(size) + max_step_size = 0.5 * size ** (-1 / 3) + initial_obj = self._procrustes_objective(Q) + + Q = procrustes_step(Q, max_step_size=max_step_size) + + final_obj = self._procrustes_objective(Q) + + self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-3) + + def test_preserves_determinant_sign_for_real_matrices(self) -> None: + """Test that procrustes_step preserves the sign of determinant for real matrices.""" + # Create real matrices with positive and negative determinants + Q_pos = torch.tensor([[2.0, 0.1], [0.1, 1.5]], device=self.device, dtype=torch.float32) # det > 0 + Q_neg = torch.tensor([[-2.0, 0.1], [0.1, 1.5]], device=self.device, dtype=torch.float32) # det < 0 + + initial_det_pos = torch.det(Q_pos) + initial_det_neg = torch.det(Q_neg) + + Q_pos = procrustes_step(Q_pos, max_step_size=0.0625) + Q_neg = procrustes_step(Q_neg, max_step_size=0.0625) + + final_det_pos = torch.det(Q_pos) + final_det_neg = torch.det(Q_neg) + + # Signs should be preserved + self.assertGreater(initial_det_pos.item() * final_det_pos.item(), 0) + self.assertGreater(initial_det_neg.item() * final_det_neg.item(), 0) + + +if __name__ == "__main__": + torch.manual_seed(42) + testing.absltest.main() diff --git a/tests/test_psgd_contractions.py b/tests/test_psgd_contractions.py new file mode 100644 index 0000000..96e2744 --- /dev/null +++ b/tests/test_psgd_contractions.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +from absl import flags, testing +from absl.testing import parameterized + +from emerging_optimizers.psgd.psgd_kron_contractions import ( + _dim_n_mul_and_permute, + apply_kronecker_factors, + apply_preconditioner, + partial_contraction, +) +from emerging_optimizers.utils import fp32_matmul_precision + + +# Define command line flags +flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") + +FLAGS = flags.FLAGS + + +class TestPSGDKronContractions(parameterized.TestCase): + """Test cases for PSGD Kronecker contractions.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.device = FLAGS.device + + @parameterized.parameters( + (2, 3, 3), + (2, 3, 4), + (2, 3, 5), + ) + def test_partial_contraction_matches_reconstructed(self, size1: int, size2: int, size3: int) -> None: + """Test partial_contraction matches reconstructed.""" + G1 = torch.randn(size1, size2, size3, device=self.device) + G2 = torch.randn(size1, size2, size3, device=self.device) + with fp32_matmul_precision("highest"): + result = partial_contraction(G1, G2, axis=1) + reconstructed = torch.tensordot(G1, G2, dims=([0, 2], [0, 2])) + torch.testing.assert_close(result, reconstructed) + + def test_apply_kronecker_factors_matches_reconstructed(self) -> None: + """Test apply_kronecker_factors matches reconstructed.""" + Q_list = [ + torch.triu(torch.randn(2, 2, device=self.device)), + torch.triu(torch.randn(3, 3, device=self.device)), + torch.triu(torch.randn(3, 3, device=self.device)), + ] + X = torch.randn(2, 3, 3, device=self.device) + with fp32_matmul_precision("highest"): + result = apply_kronecker_factors(Q_list, X) + Y = X + + temp = torch.tensordot(Q_list[0], Y, dims=([1], [0])) + nd = Y.dim() + perm = list(range(1, 0 + 1)) + [0] + list(range(0 + 1, nd)) # [1, 0, 2] + Y = temp.permute(perm) + + temp = torch.tensordot(Q_list[1], Y, dims=([1], [1])) + nd = Y.dim() + perm = list(range(1, 1 + 1)) + [0] + list(range(1 + 1, nd)) # [1, 0, 2] + Y = temp.permute(perm) + + temp = torch.tensordot(Q_list[2], Y, dims=([1], [2])) + nd = Y.dim() + perm = list(range(1, 2 + 1)) + [0] + list(range(2 + 1, nd)) # [1, 2, 0] + reconstructed = temp.permute(perm) + + torch.testing.assert_close(result, reconstructed) + + def test_apply_preconditioner_matches_reconstructed(self) -> None: + """Test apply_preconditioner matches manual reconstruction for 2D tensor.""" + Q_list = [torch.triu(torch.randn(3, 3, device=self.device)), torch.triu(torch.randn(4, 4, device=self.device))] + X = torch.randn(3, 4, device=self.device) + + with fp32_matmul_precision("highest"): + result = apply_preconditioner(Q_list, X) + + # Manual reconstruction: precompute Q^T @ Q matrices then apply them + # Create full preconditioner matrices Q^T @ Q + QTQ_list = [q.T @ q for q in Q_list] + + # Apply the preconditioner matrices + Y = X + + # Apply QTQ_list[0] to dimension 0 + temp = torch.tensordot(QTQ_list[0], Y, dims=([1], [0])) + nd = Y.dim() + perm = list(range(1, 0 + 1)) + [0] + list(range(0 + 1, nd)) + Y = temp.permute(perm) + + # Apply QTQ_list[1] to dimension 1 + temp = torch.tensordot(QTQ_list[1], Y, dims=([1], [1])) + nd = Y.dim() + perm = list(range(1, 1 + 1)) + [0] + list(range(1 + 1, nd)) + reconstructed = temp.permute(perm) + + torch.testing.assert_close(result, reconstructed) + + @parameterized.parameters( + (2, 3, 5, 0), + (2, 3, 5, 1), + (2, 3, 5, 2), + (4, 6, 2, 1), + ) + def test_dim_n_mul_and_permute__matches_shapes(self, dim0: int, dim1: int, dim2: int, contract_dim: int) -> None: + """Test 
`_dim_n_mul_and_permute` with non-uniform shapes and different contract_dim.""" + X = torch.randn(dim0, dim1, dim2, device=self.device) + input_shape = X.shape + + input_dim = input_shape[contract_dim] + output_dim = 7 # arbitrary output dimension + M = torch.randn(output_dim, input_dim, device=self.device) + + result = _dim_n_mul_and_permute(X, M, contract_dim) + + # Verify output shape: same as input but dimension `contract_dim` replaced by output_dim + expected_shape = list(input_shape) + expected_shape[contract_dim] = output_dim + self.assertEqual(result.shape, torch.Size(expected_shape)) + + +if __name__ == "__main__": + torch.manual_seed(42) + testing.absltest.main() diff --git a/tests/test_psgd_utils.py b/tests/test_psgd_utils.py new file mode 100644 index 0000000..26413b0 --- /dev/null +++ b/tests/test_psgd_utils.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from absl import flags, testing +from absl.testing import parameterized + +from emerging_optimizers.psgd.psgd_utils import ( + norm_lower_bound_skew, + norm_lower_bound_spd, + uniformize_q_in_place, +) + + +# Define command line flags +flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") + +FLAGS = flags.FLAGS + + +class BalanceQTest(parameterized.TestCase): + """Test cases for uniformize_q_in_place function.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.device = FLAGS.device + + def test_normalization_on_empty_list(self) -> None: + """Test uniformize_q_in_place with empty list.""" + Q_list = [] + with self.assertRaises(TypeError): + uniformize_q_in_place(Q_list) + + def test_normalization_on_single_tensor(self) -> None: + """Test uniformize_q_in_place with single tensor.""" + Q = torch.randn(3, 3, device=self.device) + original_Q = Q.clone() + uniformize_q_in_place([Q]) + # for a single tensor, the result should be the same as the original + torch.testing.assert_close(Q, original_Q) + + def test_normalization_on_two_tensors(self) -> None: + """Test uniformize_q_in_place with two tensors.""" + Q1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device) + Q2 = torch.tensor([[0.1, 0.2], [0.3, 0.4]], device=self.device) + + orig_max1 = torch.max(torch.abs(Q1)) + orig_max2 = torch.max(torch.abs(Q2)) + + uniformize_q_in_place([Q1, Q2]) + + new_max1 = torch.max(torch.abs(Q1)) + new_max2 = torch.max(torch.abs(Q2)) + + # Should be equal to geometric mean of original maxima + expected_max = (orig_max1 * orig_max2) ** 0.5 + self.assertAlmostEqual(new_max1.item(), expected_max.item(), places=3) + self.assertAlmostEqual(new_max2.item(), expected_max.item(), places=3) + + @parameterized.parameters( + (32, 32, 32), + (256, 256, 256), + (4096, 4096, 4096), + ) + def test_normalization_on_three_tensors(self, size1: int, size2: int, size3: int) -> None: + """Test uniformize_q_in_place with multiple tensors of different 
dynamic ranges.""" + Q1 = torch.randn(size1, size1, device=self.device) * 10.0 + Q2 = torch.randn(size2, size2, device=self.device) * 0.01 + Q3 = torch.randn(size3, size3, device=self.device) * 1.0 + + orig_max1 = torch.max(torch.abs(Q1)) + orig_max2 = torch.max(torch.abs(Q2)) + orig_max3 = torch.max(torch.abs(Q3)) + + uniformize_q_in_place([Q1, Q2, Q3]) + + # All tensors should have the same max absolute value + new_max1 = torch.max(torch.abs(Q1)) + new_max2 = torch.max(torch.abs(Q2)) + new_max3 = torch.max(torch.abs(Q3)) + + # Should be equal to geometric mean + expected_max = (orig_max1 * orig_max2 * orig_max3) ** (1.0 / 3.0) + self.assertAlmostEqual(new_max1.item(), expected_max.item(), places=3) + self.assertAlmostEqual(new_max2.item(), expected_max.item(), places=3) + self.assertAlmostEqual(new_max3.item(), expected_max.item(), places=3) + + def test_modifies_in_place_on_three_tensors(self) -> None: + """Test that uniformize_q_in_place modifies tensors in place.""" + Q = torch.randn(3, 3, device=self.device) + original_id = id(Q) + uniformize_q_in_place([Q, torch.randn(2, 2, device=self.device)]) + + # Should be the same object (modified in place) + self.assertEqual(id(Q), original_id) + + +class NormLowerBoundSpdTest(parameterized.TestCase): + """Test cases for norm_lower_bound_spd function.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.device = FLAGS.device + + def test_diagonal_matrix(self) -> None: + """Test norm_lower_bound_spd with diagonal matrix.""" + # For diagonal matrix, spectral norm equals largest diagonal entry + diag_values = torch.tensor([1.0, 3.0, 2.0], device=self.device) + A = torch.diag(diag_values) + + bound = norm_lower_bound_spd(A) + actual_norm = torch.max(diag_values) + + # Bound should be <= actual norm + self.assertLessEqual(bound.item(), actual_norm.item() + 1e-4) + # For diagonal matrix, bound should be reasonably tight + self.assertGreater(bound.item(), 0.5 * actual_norm.item()) + + def test_identity_matrix(self) -> None: + """Test norm_lower_bound_spd with identity matrix.""" + A = torch.eye(3, device=self.device) + bound = norm_lower_bound_spd(A) + + # For identity matrix, spectral norm is 1 + self.assertAlmostEqual(bound.item(), 1.0, places=3) + + def test_zero_matrix(self) -> None: + """Test norm_lower_bound_spd with zero matrix.""" + A = torch.zeros(3, 3, device=self.device) + bound = norm_lower_bound_spd(A) + + # For zero matrix, bound should be 0 + self.assertAlmostEqual(bound.item(), 0.0, places=3) + + @parameterized.product( + dtype=[torch.float32, torch.bfloat16], + size=[32, 256, 4096], + ) + def test_norm_lower_bound_spd_is_lower_bound(self, dtype: torch.dtype, size: int) -> None: + """Test that norm_lower_bound_spd provides a valid lower bound.""" + # Create a random SPD matrix + B = torch.randn(size, size, dtype=dtype, device=self.device) + A = B @ B.T + 1e-3 * torch.eye( + size, dtype=dtype, device=self.device + ) # Ensure positive definite and well-conditioned + + bound = norm_lower_bound_spd(A) + # Spectral norm (largest singular value) + # Pytorch's matrix norm does not support bfloat16, so we convert to float32 + actual_norm = torch.linalg.matrix_norm(A.to(torch.float32), ord=2) + + # Bound should be <= actual norm + self.assertLessEqual(bound.item(), actual_norm.item() + 1e-4) + # Bound should be positive for positive definite matrix + self.assertGreater(bound.item(), 0.0) + + +class NormLowerBoundSkewTest(parameterized.TestCase): + """Test cases for norm_lower_bound_skew function.""" + + def setUp(self) -> 
None: + """Set up test fixtures.""" + self.device = FLAGS.device + + def test_zero_matrix(self) -> None: + """Test norm_lower_bound_skew with zero matrix.""" + A = torch.zeros(3, 3, device=self.device) + bound = norm_lower_bound_skew(A) + + # For zero matrix, bound should be 0 + self.assertAlmostEqual(bound.item(), 0.0, places=3) + + def test_small_skew_symmetric_matrix(self) -> None: + """Test norm_lower_bound_skew with a simple skew-symmetric matrix.""" + # Create a simple 3x3 skew-symmetric matrix + A = torch.tensor([[0.0, 1.0, -2.0], [-1.0, 0.0, 3.0], [2.0, -3.0, 0.0]], device=self.device) + + bound = norm_lower_bound_skew(A) + # Compute actual spectral norm + actual_norm = torch.linalg.matrix_norm(A, ord=2) + + # Bound should be <= actual norm + self.assertLessEqual(bound.item(), actual_norm.item() + 1e-3) + # Bound should be positive for non-zero matrix + self.assertGreater(bound.item(), 0.0) + + @parameterized.parameters([4, 16, 32]) + def test_random_based_skew_matrix_with_different_sizes(self, size: int) -> None: + """Test norm_lower_bound_skew with random skew-symmetric matrix.""" + # Create skew-symmetric matrix from anti-symmetric part of random matrix + n = size + B = torch.randn(n, n, device=self.device) + A = B - B.T # This creates a skew-symmetric matrix + + bound = norm_lower_bound_skew(A) + actual_norm = torch.linalg.matrix_norm(A, ord=2) + + # Bound should be <= actual norm + self.assertLessEqual(bound.item(), actual_norm.item() + 1e-3) + + @parameterized.product( + dtype=[torch.float32, torch.float64], + size=[32, 128, 256], + ) + def test_norm_lower_bound_skew_is_lower_bound(self, dtype: torch.dtype, size: int) -> None: + """Test that norm_lower_bound_skew provides a valid lower bound.""" + # Create a random skew-symmetric matrix + B = torch.randn(size, size, dtype=dtype, device=self.device) + A = B - B.T # Ensure skew-symmetric property: A^T = -A + + bound = norm_lower_bound_skew(A) + # Compute actual spectral norm + actual_norm = torch.linalg.matrix_norm(A.to(torch.float32), ord=2) + + # Bound should be <= actual norm (with small tolerance for numerical errors) + self.assertLessEqual(bound.item(), actual_norm.item() + 1e-4) + + # Bound should be non-negative + self.assertGreaterEqual(bound.item(), 0.0) + + @parameterized.parameters([4, 16, 32, 64]) + def test_different_subspace_dimensions(self, rank: int) -> None: + """Test norm_lower_bound_skew with different subspace dimensions.""" + # Create a skew-symmetric matrix + B = torch.randn(64, 64, device=self.device) + A = B - B.T + + bound = norm_lower_bound_skew(A, k=rank, half_iters=2) + + self.assertGreaterEqual(bound.item(), 0.0) + + actual_norm = torch.linalg.matrix_norm(A, ord=2) + self.assertLessEqual(bound.item(), actual_norm.item() + 1e-5) + + +if __name__ == "__main__": + torch.manual_seed(42) + testing.absltest.main()
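Appended reviewer sketch (not part of the diff): a minimal, hypothetical usage of procrustes_step, assuming the module layout added above. The update's fixed point is R = Q^T - Q = 0, and each truncated rotation is chosen so that tr(Q) does not decrease, so repeated steps should shrink the skew part of Q; the matrix size and loop length below are arbitrary.

import torch

from emerging_optimizers.psgd.procrustes_step import procrustes_step

torch.manual_seed(0)
Q = torch.randn(16, 16)

for step in range(200):
    Q = procrustes_step(Q, max_step_size=0.125)
    if step % 50 == 0:
        # tr(Q) should be non-decreasing and the skew part ||Q - Q^T||_F should shrink
        # as the iteration approaches its fixed point R = Q^T - Q = 0.
        skew_norm = torch.linalg.matrix_norm(Q - Q.T).item()
        print(f"step {step:3d}  tr(Q) = {torch.trace(Q).item():.4f}  ||Q - Q^T||_F = {skew_norm:.3e}")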
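A second sketch (also hypothetical usage, not from the PR): for a 2D tensor, apply_preconditioner computes (Q_1^T Q_1) X (Q_2^T Q_2), which under row-major flattening should match the explicit Kronecker-product preconditioner kron(Q_1^T Q_1, Q_2^T Q_2) applied to vec(X).

import torch

from emerging_optimizers.psgd.psgd_kron_contractions import apply_preconditioner

torch.manual_seed(0)
Q1 = torch.triu(torch.randn(3, 3, dtype=torch.float64))
Q2 = torch.triu(torch.randn(4, 4, dtype=torch.float64))
X = torch.randn(3, 4, dtype=torch.float64)

# Factored application: (Q1^T Q1) X (Q2^T Q2), never forming the full preconditioner.
PX = apply_preconditioner([Q1, Q2], X)

# Explicit preconditioner acting on the row-major flattened X.
# vec_row(A X B) = (A kron B^T) vec_row(X); both Gram factors here are symmetric.
P = torch.kron(Q1.T @ Q1, Q2.T @ Q2)
PX_explicit = (P @ X.reshape(-1)).reshape(3, 4)

torch.testing.assert_close(PX, PX_explicit)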
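One more sketch of the same flavor (hypothetical usage): partial_contraction(G1, G2, axis) contracts every dimension except axis, so it should agree with moving axis to the front, flattening the remaining dimensions, and taking an ordinary matmul.

import torch

from emerging_optimizers.psgd.psgd_kron_contractions import partial_contraction

torch.manual_seed(0)
G1 = torch.randn(2, 3, 5, dtype=torch.float64)
G2 = torch.randn(2, 3, 5, dtype=torch.float64)
axis = 1

out = partial_contraction(G1, G2, axis)  # shape (3, 3)

# Reference: contract all axes except `axis` via an explicit reshape and matmul.
A = G1.movedim(axis, 0).reshape(G1.shape[axis], -1)
B = G2.movedim(axis, 0).reshape(G2.shape[axis], -1)
torch.testing.assert_close(out, A @ B.T)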