
Commit 06fc893

cleaned up test cases

Signed-off-by: mikail <mkhona@nvidia.com>

1 parent 3c9bec5

2 files changed: +22 additions, -26 deletions


tests/normalized_optimizer_convergence_test.py

Lines changed: 11 additions & 12 deletions
@@ -60,8 +60,9 @@ def get_regular_parameters(self):
         return [self.fc3.weight, self.fc3.bias]
 
 
-# Base class for tests requiring seeding for determinism
-class BaseTestCase(parameterized.TestCase):
+class NormalizedOptimizerConvergenceTest(parameterized.TestCase):
+    """Convergence tests for normalized optimizers on a simple MLP task."""
+
     def setUp(self):
         """Set random seed before each test."""
         # Set seed for PyTorch
@@ -71,19 +72,17 @@ def setUp(self):
         torch.cuda.manual_seed_all(1234)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
-class NormalizedOptimizerConvergenceTest(BaseTestCase):
-    """Convergence tests for normalized optimizers on a simple MLP task."""
-
-    def _create_synthetic_mnist_data(self, num_samples=1000):
+    def _create_synthetic_mnist_data(self, num_samples: int = 1000) -> TensorDataset:
         """Create synthetic MNIST-like data for testing."""
         torch.manual_seed(1234)
         X = torch.randn(num_samples, 784, device=self.device)
         # Create somewhat realistic targets with class imbalance
         y = torch.randint(0, 10, (num_samples,))
         return TensorDataset(X, y)
 
-    def _train_model(self, model, optimizer_class, optimizer_kwargs, num_epochs=5):
+    def _train_model(
+        self, model: SimpleMLP, optimizer_class: torch.optim.Optimizer, optimizer_kwargs: dict, num_epochs: int = 5
+    ) -> tuple[float, float, float]:
         """Train model with given optimizer and return final loss and accuracy."""
         # Create data
         dataset = self._create_synthetic_mnist_data(num_samples=500)
@@ -140,7 +139,7 @@ def _train_model(self, model, optimizer_class, optimizer_kwargs, num_epochs=5):
 
         return initial_loss, final_loss, final_accuracy
 
-    def _verify_norms_preserved(self, model):
+    def _verify_norms_preserved(self, model: SimpleMLP) -> None:
         """Verify that oblique parameters maintain unit column norms."""
         for param in model.get_oblique_parameters():
             column_norms = param.data.norm(dim=0)  # Column norms
@@ -152,7 +151,7 @@ def _verify_norms_preserved(self, model):
                 rtol=1e-5,
             )
 
-    def test_oblique_sgd_convergence(self):
+    def test_oblique_sgd_convergence(self) -> None:
         """Test that ObliqueSGD can train a simple MLP and maintain norms."""
         model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10)
 
@@ -168,7 +167,7 @@ def test_oblique_sgd_convergence(self):
         # Check norm preservation
         self._verify_norms_preserved(model)
 
-    def test_oblique_adam_convergence(self):
+    def test_oblique_adam_convergence(self) -> None:
         """Test that ObliqueAdam can train a simple MLP and maintain norms."""
         model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10)
 
@@ -190,7 +189,7 @@ def test_oblique_adam_convergence(self):
         ("adam_col", ObliqueAdam, {"lr": 0.1, "betas": (0.9, 0.999), "weight_decay": 0.1, "dim": 0}),
         ("adam_row", ObliqueAdam, {"lr": 0.1, "betas": (0.9, 0.999), "weight_decay": 0.1, "dim": 1}),
     )
-    def test_optimizer_modes_convergence(self, optimizer_class, optimizer_kwargs):
+    def test_optimizer_modes_convergence(self, optimizer_class: torch.optim.Optimizer, optimizer_kwargs: dict) -> None:
         """Test that both row and column modes work for both optimizers."""
         model = SimpleMLP(input_size=784, hidden_size=32, num_classes=10)

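Aside for readers skimming the diff: `_verify_norms_preserved` asserts that every parameter returned by `get_oblique_parameters()` still has unit column norms after training. The sketch below illustrates why that invariant can hold exactly under the usual oblique-manifold recipe (project the gradient onto the tangent space, take a Euclidean step, retract by renormalizing). It is a minimal illustration only, not the library's ObliqueSGD implementation; `oblique_sgd_step` is a hypothetical helper.

import torch


def oblique_sgd_step(param: torch.Tensor, grad: torch.Tensor, lr: float, dim: int = 0) -> torch.Tensor:
    """One SGD-like step that keeps unit norms along `dim` (dim=0 means columns)."""
    # Tangent-space projection: subtract each column's radial component,
    # i.e. the part of the gradient parallel to that unit-norm column.
    radial = (param * grad).sum(dim=dim, keepdim=True) * param
    riemannian_grad = grad - radial
    # Euclidean step, then retraction: renormalizing puts the iterate back on
    # the manifold, which is what makes the norm invariant exact.
    updated = param - lr * riemannian_grad
    return updated / updated.norm(dim=dim, keepdim=True)


# The invariant the test checks, in miniature: start on the manifold, take a
# step, and the column norms are still 1 up to floating-point error.
w = torch.randn(4, 6)
w = w / w.norm(dim=0, keepdim=True)
w = oblique_sgd_step(w, torch.randn(4, 6), lr=0.1, dim=0)
torch.testing.assert_close(w.norm(dim=0), torch.ones(6), atol=1e-6, rtol=1e-5)
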
tests/test_normalized_optimizer.py

Lines changed: 11 additions & 14 deletions
@@ -19,8 +19,9 @@
 from emerging_optimizers.riemannian_optimizers.normalized_optimizer import ObliqueAdam, ObliqueSGD
 
 
-# Base class for tests requiring seeding for determinism
-class BaseTestCase(parameterized.TestCase):
+class NormalizedOptimizerFunctionalTest(parameterized.TestCase):
+    """Tests for ObliqueSGD and ObliqueAdam optimizers that preserve row/column norms."""
+
     def setUp(self):
         """Set random seed before each test."""
         # Set seed for PyTorch
@@ -30,15 +31,11 @@ def setUp(self):
         torch.cuda.manual_seed_all(1234)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
-class NormalizedOptimizerFunctionalTest(BaseTestCase):
-    """Tests for ObliqueSGD and ObliqueAdam optimizers that preserve row/column norms."""
-
     @parameterized.parameters(
         (0),
         (1),
     )
-    def test_oblique_sgd_preserves_norms(self, dim):
+    def test_oblique_sgd_preserves_norms(self, dim: int) -> None:
         """Test that ObliqueSGD preserves row or column norms after one optimization step."""
         # Create a 4x6 matrix for testing
         matrix_size = (4, 6)
@@ -76,7 +73,7 @@ def test_oblique_sgd_preserves_norms(self, dim):
         (0),
         (1),
     )
-    def test_oblique_adam_preserves_norms(self, dim):
+    def test_oblique_adam_preserves_norms(self, dim: int) -> None:
         """Test that ObliqueAdam preserves row or column norms after one optimization step."""
         # Create a 3x5 matrix for testing
         matrix_size = (3, 5)
@@ -109,7 +106,7 @@ def test_oblique_adam_preserves_norms(self, dim):
             rtol=1e-6,
         )
 
-    def test_oblique_sgd_zero_gradient(self):
+    def test_oblique_sgd_zero_gradient(self) -> None:
         """Test that ObliqueSGD handles zero gradients correctly."""
         matrix_size = (2, 4)
         param = torch.randn(matrix_size, dtype=torch.float32, device=self.device)
@@ -135,7 +132,7 @@ def test_oblique_sgd_zero_gradient(self):
         expected_norms = torch.ones_like(final_norms)
         torch.testing.assert_close(final_norms, expected_norms, atol=0, rtol=1e-6)
 
-    def test_oblique_adam_zero_gradient(self):
+    def test_oblique_adam_zero_gradient(self) -> None:
         """Test that ObliqueAdam handles zero gradients correctly."""
         matrix_size = (2, 3)
         param = torch.randn(matrix_size, dtype=torch.float32, device=self.device)
@@ -161,7 +158,7 @@ def test_oblique_adam_zero_gradient(self):
         expected_norms = torch.ones_like(final_norms)
         torch.testing.assert_close(final_norms, expected_norms, atol=0, rtol=1e-6)
 
-    def test_oblique_sgd_large_gradient(self):
+    def test_oblique_sgd_large_gradient(self) -> None:
         """Test that ObliqueSGD handles large gradients correctly."""
         matrix_size = (3, 4)
         param = torch.randn(matrix_size, dtype=torch.float32, device=self.device)
@@ -183,7 +180,7 @@ def test_oblique_sgd_large_gradient(self):
         expected_norms = torch.ones_like(final_norms)
         torch.testing.assert_close(final_norms, expected_norms, atol=0, rtol=1e-6)
 
-    def test_oblique_adam_large_gradient(self):
+    def test_oblique_adam_large_gradient(self) -> None:
         """Test that ObliqueAdam handles large gradients correctly."""
         matrix_size = (2, 5)
         param = torch.randn(matrix_size, dtype=torch.float32, device=self.device)
@@ -210,7 +207,7 @@ def test_oblique_adam_large_gradient(self):
             rtol=1e-6,
         )
 
-    def test_multiple_optimization_steps_preserve_norms(self):
+    def test_multiple_optimization_steps_preserve_norms(self) -> None:
         """Test that norms are preserved across multiple optimization steps."""
         matrix_size = (4, 4)
         param = torch.randn(matrix_size, dtype=torch.float32, device=self.device)
@@ -237,7 +234,7 @@ def test_multiple_optimization_steps_preserve_norms(self):
             rtol=1e-6,
         )
 
-    def test_weight_decay_with_norm_preservation(self):
+    def test_weight_decay_with_norm_preservation(self) -> None:
         """Test that weight decay doesn't break norm preservation."""
         matrix_size = (3, 3)
         param = torch.randn(matrix_size, dtype=torch.float32, device=self.device)

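The functional tests above all follow the same pattern: put a random matrix on the manifold, take one (or several) optimizer steps, and assert that the norms along `dim` are still exactly 1, with `atol=0` so only relative error is tolerated. A condensed, standalone version of that pattern is sketched below. The `ObliqueSGD([param], lr=..., dim=...)` call shape is inferred from the optimizer kwargs visible in the diff; treat the exact constructor signature as an assumption if you adapt this.

import torch

from emerging_optimizers.riemannian_optimizers.normalized_optimizer import ObliqueSGD


def check_norms_after_step(matrix_size: tuple[int, int], dim: int) -> None:
    # Retract a random init onto the oblique manifold so norms start at exactly 1.
    init = torch.randn(matrix_size, dtype=torch.float32)
    param = torch.nn.Parameter(init / init.norm(dim=dim, keepdim=True))

    optimizer = ObliqueSGD([param], lr=0.1, dim=dim)
    param.grad = torch.randn_like(param)  # arbitrary gradient; could also be zero or very large
    optimizer.step()

    # Norms along `dim` must still be 1; atol=0 mirrors the tests' strictness.
    final_norms = param.data.norm(dim=dim)
    torch.testing.assert_close(final_norms, torch.ones_like(final_norms), atol=0, rtol=1e-6)


check_norms_after_step((4, 6), dim=0)  # column norms, as in test_oblique_sgd_preserves_norms
check_norms_after_step((4, 6), dim=1)  # row norms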