diff --git a/emerging_optimizers/psgd/__init__.py b/emerging_optimizers/psgd/__init__.py
new file mode 100644
index 0000000..65bbfc8
--- /dev/null
+++ b/emerging_optimizers/psgd/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from emerging_optimizers.psgd.psgd import *
diff --git a/emerging_optimizers/psgd/psgd.py b/emerging_optimizers/psgd/psgd.py
new file mode 100644
index 0000000..34dd628
--- /dev/null
+++ b/emerging_optimizers/psgd/psgd.py
@@ -0,0 +1,303 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Callable, List, Tuple, override
+
+import torch
+from torch.optim.optimizer import ParamsT
+
+from emerging_optimizers.psgd.procrustes_step import procrustes_step
+from emerging_optimizers.psgd.psgd_kron_contractions import apply_preconditioner, partial_contraction
+from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_spd, uniformize_q_in_place
+from emerging_optimizers.soap.soap import _clip_update_rms_in_place
+
+
+__all__ = [
+    "PSGDPro",
+]
+
+
+class PSGDPro(torch.optim.Optimizer):
+    """Implements a variant of the PSGD optimization algorithm (PSGD-Kron-Whiten with a Procrustes step for the preconditioner update).
+
+    Preconditioned Stochastic Gradient Descent (PSGD) (https://arxiv.org/abs/1512.04202) is a preconditioned optimization algorithm
+    that fits the amplitudes of perturbations of the preconditioned stochastic gradient to match those of the perturbations of the parameters.
+    PSGD with a Kronecker-factored preconditioner (PSGD-Kron-Whiten) is a variant of PSGD that reduces memory and computational complexity.
+    The Procrustes step is a preconditioner update that respects a particular geometry, Q^0.5 * E * Q^1.5; see Stochastic Hessian
+    Fittings with Lie Groups (https://arxiv.org/abs/2402.11858) for more details.
+
+    Args:
+        params: Iterable of parameters to optimize or dicts defining parameter groups.
+        lr: The learning rate to use.
+        weight_decay: Weight decay coefficient.
+        use_decoupled_weight_decay: Whether to use decoupled weight decay; see Decoupled Weight Decay Regularization:
+            https://arxiv.org/abs/1711.05101.
+        momentum: Momentum coefficient for the exponential moving average of the gradient.
+        beta_lip: EMA beta for the Lipschitz constants.
+        precond_lr: Inner learning rate for the preconditioner.
+        precond_init_scale: Scale of the initial preconditioner values.
+        min_precond_lr: Minimum learning rate for the preconditioner learning rate schedule.
+        warmup_steps: Decay timescale (in steps) for the preconditioner learning rate schedule.
+        damping_noise_scale: Scale of the damping noise added to gradients.
+        max_update_rms: Clip the update RMS to this value (0 means no clipping).
+    """
+
+    def __init__(
+        self,
+        params: ParamsT,
+        lr: float = 3e-3,
+        weight_decay: float = 0.01,
+        use_decoupled_weight_decay: bool = True,
+        momentum: float = 0.9,
+        beta_lip: float = 0.9,
+        precond_lr: float = 0.1,
+        precond_init_scale: float = 1.0,
+        damping_noise_scale: float = 0.1,
+        min_precond_lr: float = 0.01,
+        warmup_steps: int = 10000,
+        max_update_rms: float = 0.0,
+    ) -> None:
+        defaults = {
+            "lr": lr,
+            "beta_lip": beta_lip,
+            "weight_decay": weight_decay,
+            "use_decoupled_weight_decay": use_decoupled_weight_decay,
+            "momentum": momentum,
+            "precond_lr": precond_lr,
+            "precond_init_scale": precond_init_scale,
+            "max_update_rms": max_update_rms,
+            "min_precond_lr": min_precond_lr,
+            "warmup_steps": warmup_steps,
+            "damping_noise_scale": damping_noise_scale,
+        }
+        super().__init__(params, defaults)
+
+    @torch.no_grad()  # type: ignore[misc]
+    @override
+    def step(self, closure: Callable[[], float] | None = None) -> float | None:
+        """Performs a single optimization step.
+
+        Args:
+            closure: A closure that reevaluates the model and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                state = self.state[p]
+
+                # Optimizer state initialization
+                if "step" not in state:
+                    state["step"] = 0
+                # Momentum buffer
+                if "exp_avg" not in state:
+                    state["exp_avg"] = torch.zeros_like(grad)
+                # PSGD Kronecker factor matrices and Lipschitz constants initialization
+                if "Q" not in state or "L" not in state:
+                    state["Q"], state["L"] = _init_psgd_kron_states(
+                        grad,
+                        precond_init_scale=group["precond_init_scale"],
+                    )
+                state["step"] += 1
+
+                # weight decay
+                if group["weight_decay"] > 0.0:
+                    if group["use_decoupled_weight_decay"]:
+                        # Apply decoupled weight decay
+                        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+                    else:
+                        # add l2 regularization before preconditioning (i.e. adding a squared loss term)
+                        grad += group["weight_decay"] * p
+
+                # update momentum buffer with EMA of gradient
+                exp_avg = state["exp_avg"]
+                exp_avg.lerp_(grad, 1 - group["momentum"])
+
+                # Get hyperparameters for preconditioner update
+                damping_noise_scale = group["damping_noise_scale"]
+                precond_lr = _get_precond_lr(
+                    group["precond_lr"], state["step"], group["min_precond_lr"], group["warmup_steps"]
+                )
+
+                beta_lip = group["beta_lip"]
+                # Preconditioner update
+                state["Q"], state["L"] = _update_precond_procrustes(
+                    state["Q"], state["L"], exp_avg, damping_noise_scale, precond_lr, beta_lip
+                )
+                uniformize_q_in_place(state["Q"])
+
+                # Get weight update by preconditioning the momentum
+                update = apply_preconditioner(state["Q"], exp_avg)
+                _clip_update_rms_in_place(update, group["max_update_rms"])
+
+                # Apply weight update
+                p.add_(update, alpha=-group["lr"])
+
+        return loss
+
+
+def _init_psgd_kron_states(
+    grad: torch.Tensor,
+    precond_init_scale: float = 1.0,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    """Initialize the Kronecker factor matrices and Lipschitz constants.
+
+    Args:
+        grad: Gradient tensor.
+ precond_init_scale: Scale of preconditioner initialization. + + Returns: + q_list: List of Kronecker factors. + lip_const_list: List of Lipschitz constants for the Kronecker factors. + """ + q_list: List[torch.Tensor] = [] + lip_const_list: List[torch.Tensor] = [] + + # Create identity matrices scaled by precond_init_scale for each dimension + for size in grad.shape: + q_list.append(torch.eye(size, device=grad.device) * precond_init_scale) + lip_const_list.append(torch.ones((), device=grad.device)) + + return q_list, lip_const_list + + +def _update_precond_procrustes( + q_list: List[torch.Tensor], + lip_const_list: List[torch.Tensor], + exp_avg: torch.Tensor, + damping_noise_scale: float = 1e-9, + precond_lr: float = 0.1, + beta_lip: float = 0.9, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + r"""Update the Kron preconditioner Q using procrustes step and uniformization. + + Args: + q_list: List of Kronecker factors. + lip_const_list: List of Lipschitz constants for the Kronecker factors. + exp_avg: Exponential moving average of gradient. + damping_noise_scale: Scale of noise added to gradient. + precond_lr: Learning rate. + beta_lip: EMA beta for the Lipschitz constant. + + Returns: + q_list: List of Kronecker factors. + lip_const_list: List of Lipschitz constants for the Kronecker factors. + """ + dampened_momentum = exp_avg + (damping_noise_scale + 1e-7 * exp_avg.abs()) * torch.randn_like(exp_avg) + pg = apply_preconditioner(q_list, dampened_momentum) + total_numel = pg.numel() + updated_q_list: List[torch.Tensor] = [] + updated_lip_const_list: List[torch.Tensor] = [] + for dim, q in enumerate(q_list): + # compute gradient covariance + precond_grad_cov = partial_contraction(pg, pg, dim) + if q.dim() < 2: + # diagonal or scalar-structured preconditioner + q, updated_lip_const = _update_1d_preconditioner( + q, lip_const_list[dim], precond_grad_cov, total_numel, precond_lr, beta_lip + ) + else: + # matrix-structured preconditioner + q, updated_lip_const = _update_matrix_preconditioner( + q, lip_const_list[dim], precond_grad_cov, total_numel, precond_lr, beta_lip + ) + updated_q_list.append(q) + updated_lip_const_list.append(updated_lip_const) + + return updated_q_list, updated_lip_const_list + + +def _update_matrix_preconditioner( + q: torch.Tensor, + lip_const: torch.Tensor, + precond_grad_cov: torch.Tensor, + total_numel: int, + precond_lr: float, + beta_lip: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Update matrix-structured preconditioner with adaptive Lipschitz constant. + + Args: + q: Kronecker factor matrix for this dimension to update. + lip_const: Lipschitz constant for this dimension. + precond_grad_cov: Gradient covariance. + total_numel: Total number of elements in the gradient. + precond_lr: Learning rate. + beta_lip: EMA beta for the Lipschitz constant. + + Returns: + q: Updated Kronecker factor matrix for this dimension. + lip_const: Updated Lipschitz constant for this dimension. 
+ """ + normalization = total_numel / q.shape[0] + ell = norm_lower_bound_spd(precond_grad_cov) + normalization + lip_const = torch.max(beta_lip * lip_const + (1 - beta_lip) * ell, ell) + q = q - precond_lr / lip_const * (precond_grad_cov @ q - normalization * q) + q = procrustes_step(q) + return q, lip_const + + +def _update_1d_preconditioner( + q: torch.Tensor, + lip_const: torch.Tensor, + precond_grad_cov: torch.Tensor, + total_numel: int, + precond_lr: float, + beta_lip: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Update 1D preconditioner with adaptive Lipschitz constant. + + Args: + q: Kronecker factor 1D tensor for this dimension to update. + lip_const: Lipschitz constant for this dimension. + precond_grad_cov: Gradient covariance. + total_numel: Total number of elements in the gradient. + precond_lr: Learning rate. + beta_lip: EMA beta for the Lipschitz constant. + + Returns: + q: Updated Kronecker factor 1D tensor for this dimension. + lip_const: Updated Lipschitz constant for this dimension. + """ + normalization = total_numel / q.numel() + ell = torch.max(precond_grad_cov) + normalization + lip_const = torch.max(beta_lip * lip_const + (1 - beta_lip) * ell, ell) + q = q * (1 - precond_lr / lip_const * (precond_grad_cov - normalization)) + return q, lip_const + + +def _get_precond_lr(precond_lr: float, step: int, min_precond_lr: float = 0.01, warmup_steps: int = 10000) -> float: + r"""Helper function to get preconditioner learning rate for this optimization step based on a square root schedule. + + Decaying from a higher lr down to min_precond_lr improves accuracy. + + Args: + precond_lr: Learning rate. + step: Current step. + min_precond_lr: Minimum learning rate. + warmup_steps: Warmup steps. + + Returns: + The preconditioner learning rate. + """ + + scheduled_lr = precond_lr / math.sqrt(1.0 + step / warmup_steps) + return max(scheduled_lr, min_precond_lr) diff --git a/emerging_optimizers/psgd/psgd_kron_contractions.py b/emerging_optimizers/psgd/psgd_kron_contractions.py index b223eec..9e0a1f2 100644 --- a/emerging_optimizers/psgd/psgd_kron_contractions.py +++ b/emerging_optimizers/psgd/psgd_kron_contractions.py @@ -24,7 +24,6 @@ ] -@torch.compile # type: ignore[misc] def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.Tensor: """Compute the partial contraction of G1 and G2 along axis `axis`. This is the contraction of the two tensors, but with all axes except `axis` contracted. @@ -38,10 +37,9 @@ def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch. 
Tensor of shape (d_{axis}, d_{axis}) """ # dims_to_contract = all dims except `axis` - dims = list(range(G1.dim())) - dims.pop(axis) + dims_to_contract = [i for i in range(G1.dim()) if i != axis] # contraction is symmetric and has shape (d_{axis}, d_{axis}) - return torch.tensordot(G1, G2, dims=(dims, dims)) + return torch.tensordot(G1, G2, dims=(dims_to_contract, dims_to_contract)) @torch.compile # type: ignore[misc] diff --git a/emerging_optimizers/psgd/psgd_utils.py b/emerging_optimizers/psgd/psgd_utils.py index 535270a..831b9c4 100644 --- a/emerging_optimizers/psgd/psgd_utils.py +++ b/emerging_optimizers/psgd/psgd_utils.py @@ -70,7 +70,7 @@ def uniformize_q_in_place(Q_list: List[torch.Tensor]) -> None: @torch.compile # type: ignore[misc] -def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor: +def norm_lower_bound_spd(A: torch.Tensor, k: int = 32, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor: r"""A cheap lower bound for the spectral norm of a symmetric positive definite matrix. diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh index 53a64a1..3faf9b1 100644 --- a/tests/ci/L0_Tests_GPU.sh +++ b/tests/ci/L0_Tests_GPU.sh @@ -29,5 +29,6 @@ coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda || error=1 coverage run -p --source=emerging_optimizers tests/test_psgd_contractions.py --device=cuda || error=1 coverage run -p --source=emerging_optimizers tests/test_psgd_utils.py --device=cuda || error=1 +coverage run -p --source=emerging_optimizers tests/test_psgd_convergence.py --device=cuda || error=1 exit "${error}" diff --git a/tests/test_psgd_convergence.py b/tests/test_psgd_convergence.py new file mode 100644 index 0000000..fa1e349 --- /dev/null +++ b/tests/test_psgd_convergence.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch.nn as nn +import torch.nn.functional as F +from absl import flags +from absl.testing import absltest, parameterized +from torch.utils.data import DataLoader, TensorDataset + +from emerging_optimizers.psgd.psgd import PSGDPro + + +# Define command line flags +flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") + +FLAGS = flags.FLAGS + + +class SimpleMLP(nn.Module): + """Simple MLP for testing PSGD convergence.""" + + def __init__(self, input_size=784, hidden_size=128, num_classes=10): + super().__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, hidden_size) + self.fc3 = nn.Linear(hidden_size, num_classes) + + self._initialize_weights() + + def _initialize_weights(self): + """Initialize weights for stable training.""" + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = x.view(x.size(0), -1) # Flatten + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class PSGDConvergenceTest(parameterized.TestCase): + """Convergence tests for PSGD optimizer.""" + + def setUp(self): + """Set random seed and device before each test.""" + # Set seed for PyTorch + torch.manual_seed(1234) + # Set seed for CUDA if available + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(1234) + self.device = FLAGS.device + + def test_quadratic_function_convergence(self): + """Test PSGD convergence on a simple quadratic function: f(x) = (x - target)^2.""" + # Create a parameter to optimize + target = torch.tensor([2.0, -1.5, 3.2], device=self.device) + x = torch.nn.Parameter(torch.zeros(3, device=self.device)) + + # Create PSGD optimizer + optimizer = PSGDPro([x], lr=0.1, precond_lr=0.1, beta_lip=0.9, damping_noise_scale=0.1) + + initial_loss = None + final_loss = None + + # Optimization loop + for _ in range(100): + optimizer.zero_grad() + + # Compute quadratic loss + loss = torch.sum((x - target) ** 2) + loss.backward() + + if initial_loss is None: + initial_loss = loss.item() + + optimizer.step() + final_loss = loss.item() + + # Check convergence + self.assertLess(final_loss, initial_loss, "Loss should decrease during optimization") + self.assertLess(final_loss, 0.01, "Should converge reasonably close to minimum") + + # Check that x is close to target + torch.testing.assert_close(x, target, atol=1e-1, rtol=1e-1) + + def test_matrix_optimization_convergence(self): + """Test PSGD convergence on matrix optimization problem.""" + # Create target matrix and parameter matrix + target = torch.randn(4, 6, device=self.device) + A = torch.nn.Parameter(torch.randn(4, 6, device=self.device)) + per_element_eps = 1e-3 + # Create PSGD optimizer + optimizer = PSGDPro([A], lr=0.05) + + initial_loss = None + final_loss = None + + # Optimization loop + for iteration in range(100): + optimizer.zero_grad() + + # Frobenius norm loss + loss = torch.norm(A - target, p="fro") ** 2 + loss.backward() + + if initial_loss is None: + initial_loss = loss.item() + + optimizer.step() + final_loss = loss.item() + + # Check per-element MSE for convergence + per_element_mse = final_loss / A.numel() + self.assertLess( + per_element_mse, + per_element_eps, + f"Per-element MSE should be < {per_element_eps}, got {per_element_mse:.6f}", + ) + + def _create_synthetic_mnist_data(self, num_samples: int = 1000) -> TensorDataset: + """Create synthetic MNIST-like data for testing.""" + 
torch.manual_seed(1234)
+        X = torch.randn(num_samples, 784, device=self.device)
+        # Random class labels over 10 classes (independent of X)
+        y = torch.randint(0, 10, (num_samples,), device=self.device)
+        return TensorDataset(X, y)
+
+    def _train_model(
+        self, model: SimpleMLP, optimizer: torch.optim.Optimizer, num_epochs: int = 5
+    ) -> tuple[float, float, float]:
+        """Train model with given optimizer and return initial loss, final loss, and final accuracy."""
+        # Create data
+        dataset = self._create_synthetic_mnist_data(num_samples=500)
+        dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
+
+        criterion = nn.CrossEntropyLoss()
+
+        initial_loss = None
+        final_loss = None
+        final_accuracy = 0.0
+
+        model.train()
+        for _ in range(num_epochs):
+            epoch_loss = 0.0
+            correct = 0
+            total = 0
+
+            for batch_x, batch_y in dataloader:
+                # Zero gradients
+                optimizer.zero_grad()
+
+                # Forward pass
+                outputs = model(batch_x)
+                loss = criterion(outputs, batch_y)
+
+                # Backward pass
+                loss.backward()
+
+                # Update parameters
+                optimizer.step()
+
+                # Track metrics
+                epoch_loss += loss.item()
+                _, predicted = torch.max(outputs.data, 1)
+                total += batch_y.size(0)
+                correct += (predicted == batch_y).sum().item()
+
+            avg_loss = epoch_loss / len(dataloader)
+            accuracy = 100 * correct / total
+
+            if initial_loss is None:
+                initial_loss = avg_loss
+            final_loss = avg_loss
+            final_accuracy = accuracy
+
+        return initial_loss, final_loss, final_accuracy
+
+    def test_mnist_convergence(self):
+        """Test PSGD convergence on a simple neural network classification task."""
+        model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10).to(self.device)
+
+        # Create PSGD optimizer
+        optimizer = PSGDPro(model.parameters(), lr=0.01, weight_decay=0.001)
+
+        # Train model
+        initial_loss, final_loss, final_accuracy = self._train_model(model, optimizer, num_epochs=10)
+
+        # Check convergence
+        self.assertLess(final_loss, initial_loss, "Loss should decrease during training")
+        self.assertGreater(final_accuracy, 80.0, "Training accuracy should be well above the 10% chance level")
+
+
+if __name__ == "__main__":
+    absltest.main()
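
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new PSGDPro optimizer is expected to be driven in an ordinary training loop, using only the constructor arguments documented in its docstring. The toy model, data, and hyperparameter values are assumptions made for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

from emerging_optimizers.psgd import PSGDPro

# Toy regression model and data; sizes and values are illustrative only.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
inputs = torch.randn(256, 16)
targets = torch.randn(256, 1)

# Arguments mirror the defaults documented in the PSGDPro docstring.
optimizer = PSGDPro(
    model.parameters(),
    lr=3e-3,
    weight_decay=0.01,
    momentum=0.9,
    precond_lr=0.1,
)

for _ in range(100):
    optimizer.zero_grad()
    loss = F.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()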
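Sanity-check sketch for the contraction the preconditioner update relies on (illustrative, not part of the patch): partial_contraction(G1, G2, axis) contracts every axis except `axis`, so calling it with G1 == G2 returns the (d_axis, d_axis) Gram matrix of the slices along that axis. The explicit einsum below is an assumption used only as a reference for comparison.

import torch

from emerging_optimizers.psgd.psgd_kron_contractions import partial_contraction

G = torch.randn(3, 4, 5)

# Contract axes 0 and 2, keeping axis 1: the result is a (4, 4) Gram matrix.
out = partial_contraction(G, G, axis=1)

# Reference computed with an explicit einsum over the remaining axes.
reference = torch.einsum("iaj,ibj->ab", G, G)

torch.testing.assert_close(out, reference)
assert out.shape == (4, 4)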
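Schedule sketch (illustrative, not part of the patch): _get_precond_lr decays the preconditioner learning rate as precond_lr / sqrt(1 + step / warmup_steps) and floors it at min_precond_lr, so with the defaults it starts at 0.1, reaches 0.1 / sqrt(2) at step 10000, 0.05 at step 30000, and eventually clamps at 0.01. A quick numeric check against the private helper added in psgd.py:

from emerging_optimizers.psgd.psgd import _get_precond_lr

# Defaults in this patch: min_precond_lr=0.01, warmup_steps=10000.
assert abs(_get_precond_lr(0.1, 0) - 0.1) < 1e-12
assert abs(_get_precond_lr(0.1, 10_000) - 0.1 / 2**0.5) < 1e-9
assert abs(_get_precond_lr(0.1, 30_000) - 0.05) < 1e-12
# Far past the decay horizon, the schedule clamps at min_precond_lr.
assert _get_precond_lr(0.1, 10_000_000) == 0.01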