diff --git a/emerging_optimizers/riemannian_optimizers/normalized_optimizer.py b/emerging_optimizers/riemannian_optimizers/normalized_optimizer.py new file mode 100644 index 0000000..e2667ba --- /dev/null +++ b/emerging_optimizers/riemannian_optimizers/normalized_optimizer.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable + +import torch +from torch.optim.optimizer import Optimizer + + +class ObliqueSGD(Optimizer): + """SGD optimizer for row- or column-normalized 2D parameters on oblique manifolds. + + This optimizer performs SGD on oblique manifolds, where parameters are constrained + to have unit-norm rows or columns. It implements Riemannian SGD with manifold-aware + gradient updates and retraction operations. + + References: + - An Introduction to Optimization on Smooth Manifolds (Nicolas Boumal) + - EDM2: https://arxiv.org/abs/2312.02696 + - Jianlin Su: https://kexue.fm/archives/11196 + - Raman et al.: https://arxiv.org/abs/1909.06463 + - Franz Cesista: https://leloykun.github.io/ponder/steepest-descent-stiefel/#6-bonus-a-muon-like-optimizer-for-the-embedding-and-unembedding-layers + + Args: + lr: learning rate + momentum: momentum coefficient + weight_decay: weight decay coefficient + dim: The dimension to normalize over + eps: epsilon for numerical stability + """ + + def __init__( + self, + params: list[torch.nn.Parameter], + lr: float = 1e-3, + momentum: float = 0.9, + weight_decay: float = 0.0, + dim: int = 0, + eps: float = 1e-8, + ) -> None: + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0 or momentum >= 1.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + dim=dim, + eps=eps, + ) + super().__init__(params, defaults) + + @torch.no_grad() # type: ignore[misc] + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
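+
+        Returns:
+            The loss returned by the closure, or ``None`` if no closure was provided.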
+ """ + loss = closure() if closure is not None else None + + for group in self.param_groups: + lr = group["lr"] + mom = group["momentum"] + wd = group["weight_decay"] + dim = group["dim"] + eps = group["eps"] + + for param in group["params"]: + if param.grad is None: + continue + if param.ndim != 2: + raise ValueError("ObliqueSGD only supports 2D parameters") + grad = param.grad + + # Initialize momentum buffer if needed + state = self.state[param] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(param) + + buf = state["momentum_buffer"] + + # theory style momentum + buf = torch.add(grad, buf, alpha=mom) + + # Apply Riemannian gradient update + _compute_riemannian_grad_and_update(param, buf, dim, lr, wd) + + # Retraction back to the manifold, the hyper-sphere + torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param) + + return loss + + +class ObliqueAdam(Optimizer): + """Adam optimizer for row- or column-normalized 2D parameters on oblique manifolds. + + This optimizer adapts an Adam-like algorithm to work on oblique manifolds, where + parameters are constrained to have unit-norm rows or columns. It combines + adaptive momentum estimation with Riemannian gradient computation and manifold retraction. + """ + + def __init__( + self, + params: list[torch.nn.Parameter], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + dim: int = 0, + eps: float = 1e-8, + correct_bias: bool = True, + ) -> None: + """An Adam-like optimizer for Normalized 2d Parameters + + Args: + lr: The learning rate. + betas: The coefficients used for computing running averages of gradient and its square. + weight_decay: The weight decay coefficient. + dim: The dimension to normalize over. + eps: The epsilon for numerical stability. + correct_bias: Whether to correct bias in Adam-like computation. + """ + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if betas[0] < 0.0 or betas[0] >= 1.0: + raise ValueError(f"Invalid beta1 value: {betas[0]}") + if betas[1] < 0.0 or betas[1] >= 1.0: + raise ValueError(f"Invalid beta2 value: {betas[1]}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + weight_decay=weight_decay, + dim=dim, + eps=eps, + correct_bias=correct_bias, + ) + super().__init__(params, defaults) + + @torch.no_grad() # type: ignore[misc] + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = closure() if closure is not None else None + + for group in self.param_groups: + lr = group["lr"] + betas = group["betas"] + wd = group["weight_decay"] + dim = group["dim"] + eps = group["eps"] + correct_bias = group["correct_bias"] + + for param in group["params"]: + if param.grad is None: + continue + if param.ndim != 2: + raise ValueError("ObliqueAdam only supports 2D parameters") + + state = self.state[param] + if "step" not in state: + state["step"] = 0 + + grad = param.grad + + # Initialize momentum buffer if needed + if "exp_avg" not in state: + state["exp_avg"] = torch.zeros_like(param) + if "exp_avg_sq" not in state: + state["exp_avg_sq"] = torch.zeros_like(param) + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + + # Increment step counter + state["step"] += 1 + step = state["step"] + + # Update biased first and second moment estimates + exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0]) + exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1]) + + if correct_bias: + # step size correction for ADAM moments EMA + bias_correction1 = 1.0 - betas[0] ** step + bias_correction2 = 1.0 - betas[1] ** step + else: + bias_correction1 = 1.0 + bias_correction2 = 1.0 + + norm_grad = (exp_avg / bias_correction1) / (exp_avg_sq.sqrt() / bias_correction2 + eps) + + # Apply Riemannian gradient update + _compute_riemannian_grad_and_update(param, norm_grad, dim, lr, wd) + + # Retraction back to the manifold, i.e. the hyper-sphere + torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param) + + return loss + + +def _compute_riemannian_grad_and_update( + param: torch.Tensor, grad_like: torch.Tensor, dim: int, lr: float, wd: float +) -> None: + """Compute Riemannian gradient for oblique manifold and update parameter in-place. 
+ + Args: + param: Parameter tensor (2D) + grad_like: Gradient-like tensor (momentum buffer or normalized gradient) + dim: The dimension to normalize over + lr: Learning rate + wd: Weight decay coefficient + """ + + inner = (param * grad_like).sum(dim=dim, keepdim=True) + riem_grad = torch.add(grad_like, param * inner, alpha=-1) + + # Add decoupled weight decay + param.mul_(1 - lr * wd) + + # Apply update in-place + param.add_(riem_grad, alpha=-lr) diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh index 9d80bcf..785eddd 100644 --- a/tests/ci/L0_Tests_GPU.sh +++ b/tests/ci/L0_Tests_GPU.sh @@ -21,4 +21,6 @@ coverage run -p --source=emerging_optimizers tests/test_soap_utils.py coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda -coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py \ No newline at end of file +coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py +coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py --device=cuda +coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda \ No newline at end of file diff --git a/tests/ci/L1_Tests_GPU.sh b/tests/ci/L1_Tests_GPU.sh index add9c89..6677098 100644 --- a/tests/ci/L1_Tests_GPU.sh +++ b/tests/ci/L1_Tests_GPU.sh @@ -20,3 +20,5 @@ python tests/test_soap_utils.py python tests/soap_smoke_test.py python tests/test_scalar_optimizers.py --device=cuda python tests/test_spectral_clipping_utils.py +python tests/test_normalized_optimizer.py +python tests/normalized_optimizer_convergence_test.py \ No newline at end of file diff --git a/tests/normalized_optimizer_convergence_test.py b/tests/normalized_optimizer_convergence_test.py new file mode 100644 index 0000000..01ad223 --- /dev/null +++ b/tests/normalized_optimizer_convergence_test.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
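+
+"""Convergence tests for ObliqueSGD and ObliqueAdam on a small synthetic classification task."""
+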
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from absl import flags
+from absl.testing import absltest, parameterized
+from torch.utils.data import DataLoader, TensorDataset
+
+from emerging_optimizers.riemannian_optimizers.normalized_optimizer import ObliqueAdam, ObliqueSGD
+
+
+# Define command line flags
+flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
+
+FLAGS = flags.FLAGS
+
+
+class SimpleMLP(nn.Module):
+    """Simple MLP with oblique-optimized layers for testing."""
+
+    def __init__(self, input_size=784, hidden_size=128, num_classes=10, dim=0):
+        super().__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
+        self.fc2 = nn.Linear(hidden_size, hidden_size, bias=False)
+        self.fc3 = nn.Linear(hidden_size, num_classes, bias=True)  # Final layer with bias
+        self.dim = dim
+        # Initialize weights for oblique optimization
+        self._initialize_oblique_weights(dim)
+
+    def _initialize_oblique_weights(self, dim):
+        """Initialize weights to be normalized for oblique optimization."""
+        with torch.no_grad():
+            # Normalize in-place for oblique layers
+            self.fc1.weight.data /= self.fc1.weight.data.norm(dim=dim, keepdim=True).clamp(min=1e-8)
+            self.fc2.weight.data /= self.fc2.weight.data.norm(dim=dim, keepdim=True).clamp(min=1e-8)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)  # Flatten
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+    def get_oblique_parameters(self):
+        """Return parameters that should use oblique optimization."""
+        return [self.fc1.weight, self.fc2.weight]
+
+    def get_regular_parameters(self):
+        """Return parameters that should use regular optimization."""
+        return [self.fc3.weight, self.fc3.bias]
+
+
+class NormalizedOptimizerConvergenceTest(parameterized.TestCase):
+    """Convergence tests for normalized optimizers on a simple MLP task."""
+
+    def setUp(self):
+        """Set random seed before each test."""
+        # Set seed for PyTorch
+        torch.manual_seed(1234)
+        # Set seed for CUDA if available
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(1234)
+        self.device = FLAGS.device
+
+    def _create_synthetic_mnist_data(self, num_samples: int = 1000) -> TensorDataset:
+        """Create synthetic MNIST-like data for testing."""
+        torch.manual_seed(1234)
+        X = torch.randn(num_samples, 784, device=self.device)
+        # Uniformly random class labels in [0, 10)
+        y = torch.randint(0, 10, (num_samples,), device=self.device)
+        return TensorDataset(X, y)
+
+    def _train_model(
+        self, model: SimpleMLP, optimizer_class: type[torch.optim.Optimizer], optimizer_kwargs: dict, num_epochs: int = 5
+    ) -> tuple[float, float, float]:
+        """Train the model with the given optimizer and return (initial loss, final loss, final accuracy)."""
+        # Create data
+        dataset = self._create_synthetic_mnist_data(num_samples=500)
+        dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
+
+        # Setup optimizers - separate for oblique and regular parameters
+        oblique_params = model.get_oblique_parameters()
+        regular_params = model.get_regular_parameters()
+
+        oblique_optimizer = optimizer_class(oblique_params, **optimizer_kwargs)
+        regular_optimizer = torch.optim.Adam(regular_params, lr=optimizer_kwargs.get("lr", 0.001))
+
+        criterion = nn.CrossEntropyLoss()
+
+        initial_loss = None
+        final_loss = None
+        final_accuracy = 0.0
+
+        model.train()
+        for epoch in range(num_epochs):
+            epoch_loss = 0.0
+            correct = 0
+            total = 0
+
+            for batch_x, batch_y in dataloader:
+                # Zero gradients
+                oblique_optimizer.zero_grad()
regular_optimizer.zero_grad() + + # Forward pass + outputs = model(batch_x) + loss = criterion(outputs, batch_y) + + # Backward pass + loss.backward() + + # Update parameters + oblique_optimizer.step() + regular_optimizer.step() + + # Track metrics + epoch_loss += loss.item() + _, predicted = torch.max(outputs.data, 1) + total += batch_y.size(0) + correct += (predicted == batch_y).sum().item() + + avg_loss = epoch_loss / len(dataloader) + accuracy = 100 * correct / total + + if initial_loss is None: + initial_loss = avg_loss + final_loss = avg_loss + final_accuracy = accuracy + + return initial_loss, final_loss, final_accuracy + + def _verify_norms_preserved(self, model: SimpleMLP) -> None: + """Verify that oblique parameters maintain unit column norms.""" + for param in model.get_oblique_parameters(): + column_norms = param.data.norm(dim=0) # Column norms + expected_norms = torch.ones_like(column_norms) + torch.testing.assert_close( + column_norms, + expected_norms, + atol=0, + rtol=1e-5, + ) + + def test_oblique_sgd_convergence(self) -> None: + """Test that ObliqueSGD can train a simple MLP and maintain norms.""" + model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10).to(self.device) + + # Train with ObliqueSGD + initial_loss, final_loss, final_accuracy = self._train_model( + model, ObliqueSGD, {"lr": 0.01, "momentum": 0.9, "dim": 0}, num_epochs=10 + ) + + # Check convergence + self.assertLess(final_loss, initial_loss, "Loss should decrease during training") + self.assertGreater(final_accuracy, 5.0, "Accuracy should be better than random (10%)") + + # Check norm preservation + self._verify_norms_preserved(model) + + def test_oblique_adam_convergence(self) -> None: + """Test that ObliqueAdam can train a simple MLP and maintain norms.""" + model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10).to(self.device) + + # Train with ObliqueAdam + initial_loss, final_loss, final_accuracy = self._train_model( + model, ObliqueAdam, {"lr": 0.001, "betas": (0.9, 0.999), "dim": 0}, num_epochs=10 + ) + + # Check convergence + self.assertLess(final_loss, initial_loss, "Loss should decrease during training") + self.assertGreater(final_accuracy, 5.0, "Accuracy should be better than random (10%)") + + # Check norm preservation + self._verify_norms_preserved(model) + + @parameterized.named_parameters( + ("sgd_col", ObliqueSGD, {"lr": 0.1, "momentum": 0.75, "weight_decay": 0.1, "dim": 0}), + ("sgd_row", ObliqueSGD, {"lr": 0.1, "momentum": 0.75, "weight_decay": 0.1, "dim": 1}), + ("adam_col", ObliqueAdam, {"lr": 0.1, "betas": (0.9, 0.999), "weight_decay": 0.1, "dim": 0}), + ("adam_row", ObliqueAdam, {"lr": 0.1, "betas": (0.9, 0.999), "weight_decay": 0.1, "dim": 1}), + ) + def test_optimizer_modes_convergence(self, optimizer_class: torch.optim.Optimizer, optimizer_kwargs: dict) -> None: + """Test that both row and column modes work for both optimizers.""" + model = SimpleMLP(input_size=784, hidden_size=32, num_classes=10).to(self.device) + + # Re-initialize for row normalization + with torch.no_grad(): + for param in model.get_oblique_parameters(): + param.data /= param.data.norm(dim=optimizer_kwargs["dim"], keepdim=True).clamp(min=1e-8) + + # Train model + initial_loss, final_loss, final_accuracy = self._train_model( + model, optimizer_class, optimizer_kwargs, num_epochs=8 + ) + + # Basic convergence check + self.assertLess(final_loss, initial_loss * 1.01, "Loss should decrease or stay stable") + print(f"Final accuracy: {final_accuracy}") + self.assertGreater(final_accuracy, 50.0, 
"Should achieve reasonable accuracy") + + # Verify norm preservation based on mode + for param in model.get_oblique_parameters(): + norms = param.data.norm(dim=optimizer_kwargs["dim"]) + + expected_norms = torch.ones_like(norms) + torch.testing.assert_close( + norms, + expected_norms, + atol=0, + rtol=1e-5, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_normalized_optimizer.py b/tests/test_normalized_optimizer.py new file mode 100644 index 0000000..5f999a3 --- /dev/null +++ b/tests/test_normalized_optimizer.py @@ -0,0 +1,273 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from absl import flags +from absl.testing import absltest, parameterized + +from emerging_optimizers.riemannian_optimizers.normalized_optimizer import ObliqueAdam, ObliqueSGD + + +# Define command line flags +flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'") + +FLAGS = flags.FLAGS + + +class NormalizedOptimizerFunctionalTest(parameterized.TestCase): + """Tests for ObliqueSGD and ObliqueAdam optimizers that preserve row/column norms.""" + + def setUp(self): + """Set random seed before each test.""" + # Set seed for PyTorch + torch.manual_seed(1234) + # Set seed for CUDA if available + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(1234) + self.device = FLAGS.device + + @parameterized.parameters( + (0), + (1), + ) + def test_oblique_sgd_preserves_norms(self, dim: int) -> None: + """Test that ObliqueSGD preserves row or column norms after one optimization step.""" + # Create a 4x6 matrix for testing + matrix_size = (4, 6) + + # Initialize with random values then normalize + param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) + + # Normalize according to dim + torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=1e-8, out=param) + + # Create optimizer + param = torch.nn.Parameter(param) + optimizer = ObliqueSGD([param], lr=0.1, momentum=0.9, dim=dim) + + # Generate random gradient + torch.manual_seed(1234) # For reproducible gradients + param.grad = torch.randn_like(param.data, device=self.device) + + # Perform one optimization step + optimizer.step() + + # Check that norms are preserved (should be 1.0 within tolerance) + final_norms = param.norm(dim=dim) + + # All norms should be approximately 1.0 (unit norm constraint) + expected_norms = torch.ones_like(final_norms) + torch.testing.assert_close( + final_norms, + expected_norms, + atol=0, + rtol=1e-6, + ) + + @parameterized.parameters( + (0), + (1), + ) + def test_oblique_adam_preserves_norms(self, dim: int) -> None: + """Test that ObliqueAdam preserves row or column norms after one optimization step.""" + # Create a 3x5 matrix for testing + matrix_size = (3, 5) + + # Initialize with random values then normalize + param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) + + # Normalize + 
torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=1e-8, out=param) + # Create optimizer + param = torch.nn.Parameter(param) + optimizer = ObliqueAdam([param], lr=0.01, betas=(0.9, 0.999), dim=dim) + + # Generate random gradient + torch.manual_seed(1234) # For reproducible gradients + param.grad = torch.randn_like(param.data, device=self.device) + + # Perform one optimization step + optimizer.step() + + # Check that norms are preserved (should be 1.0 within tolerance) + final_norms = param.norm(dim=dim) + + # All norms should be approximately 1.0 (unit norm constraint) + expected_norms = torch.ones_like(final_norms) + torch.testing.assert_close( + final_norms, + expected_norms, + atol=0, + rtol=1e-6, + ) + + def test_oblique_sgd_zero_gradient(self) -> None: + """Test that ObliqueSGD handles zero gradients correctly.""" + matrix_size = (2, 4) + param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) + + # Normalize + torch.nn.functional.normalize(param, p=2.0, dim=0, eps=1e-8, out=param) + initial_param = param.clone() + + param = torch.nn.Parameter(param) + optimizer = ObliqueSGD([param], lr=0.1, dim=0) + + # Set zero gradient + param.grad = torch.zeros_like(param.data, device=self.device) + + # Perform optimization step + optimizer.step() + + # Parameter should remain unchanged with zero gradient + torch.testing.assert_close(param.data, initial_param, atol=0, rtol=1e-8) + + # Norms should still be 1.0 + final_norms = param.norm(dim=0) + expected_norms = torch.ones_like(final_norms) + torch.testing.assert_close(final_norms, expected_norms, atol=0, rtol=1e-6) + + def test_oblique_adam_zero_gradient(self) -> None: + """Test that ObliqueAdam handles zero gradients correctly.""" + matrix_size = (2, 3) + param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) + + # Normalize + torch.nn.functional.normalize(param, p=2.0, dim=1, eps=1e-8, out=param) + initial_param = param.clone() + + # Keep as tensor, not parameter, but enable gradients + param.requires_grad_(True) + optimizer = ObliqueAdam([param], lr=0.01, dim=1) + + # Set zero gradient + param.grad = torch.zeros_like(param.data, device=self.device) + + # Perform optimization step + optimizer.step() + + # Parameter should remain unchanged with zero gradient + torch.testing.assert_close(param.data, initial_param, atol=0, rtol=1e-6) + + # Norms should still be 1.0 + final_norms = param.norm(dim=1) + expected_norms = torch.ones_like(final_norms) + torch.testing.assert_close(final_norms, expected_norms, atol=0, rtol=1e-6) + + def test_oblique_sgd_large_gradient(self) -> None: + """Test that ObliqueSGD handles large gradients correctly.""" + matrix_size = (3, 4) + param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) + + # Normalize + param = param / param.norm(dim=0, keepdim=True).clamp(min=1e-8) + + param = torch.nn.Parameter(param) + optimizer = ObliqueSGD([param], lr=0.1, dim=0) + + # Set large gradient + param.grad = 100.0 * torch.randn_like(param.data, device=self.device) + + # Perform optimization step + optimizer.step() + + # Norms should still be preserved despite large gradient + final_norms = param.norm(dim=0) + expected_norms = torch.ones_like(final_norms) + torch.testing.assert_close(final_norms, expected_norms, atol=0, rtol=1e-6) + + def test_oblique_adam_large_gradient(self) -> None: + """Test that ObliqueAdam handles large gradients correctly.""" + matrix_size = (2, 5) + param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) + + # Normalize rows + param 
= param / param.norm(dim=1, keepdim=True).clamp(min=1e-8) + + param = torch.nn.Parameter(param) + optimizer = ObliqueAdam([param], lr=0.01, dim=1) + + # Set large gradient + param.grad = 1000.0 * torch.randn_like(param.data, device=self.device) + + # Perform optimization step + optimizer.step() + + # Norms should still be preserved despite large gradient + final_norms = param.norm(dim=1) + expected_norms = torch.ones_like(final_norms) + torch.testing.assert_close( + final_norms, + expected_norms, + atol=0, + rtol=1e-6, + ) + + def test_multiple_optimization_steps_preserve_norms(self) -> None: + """Test that norms are preserved across multiple optimization steps.""" + matrix_size = (4, 4) + param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) + + # Normalize + param = param / param.norm(dim=0, keepdim=True).clamp(min=1e-8) + + param = torch.nn.Parameter(param) + optimizer = ObliqueSGD([param], lr=0.05, momentum=0.8, dim=0) + + # Perform multiple optimization steps + for step in range(10): + param.grad = torch.randn_like(param.data, device=self.device) + optimizer.step() + + # Check norms after each step + final_norms = param.norm(dim=0) + expected_norms = torch.ones_like(final_norms) + torch.testing.assert_close( + final_norms, + expected_norms, + atol=0, + rtol=1e-6, + ) + + def test_weight_decay_with_norm_preservation(self) -> None: + """Test that weight decay doesn't break norm preservation.""" + matrix_size = (3, 3) + param = torch.randn(matrix_size, dtype=torch.float32, device=self.device) + + # Normalize + param = param / param.norm(dim=1, keepdim=True).clamp(min=1e-8) + + param = torch.nn.Parameter(param) + optimizer = ObliqueAdam([param], lr=0.01, weight_decay=0.01, dim=1) + + # Generate random gradient + param.grad = torch.randn_like(param.data, device=self.device) + + # Perform optimization step + optimizer.step() + + # Norms should still be preserved with weight decay + final_norms = param.norm(dim=1) + expected_norms = torch.ones_like(final_norms) + torch.testing.assert_close( + final_norms, + expected_norms, + atol=0, + rtol=1e-6, + ) + + +if __name__ == "__main__": + absltest.main()