@@ -60,8 +60,9 @@ def get_regular_parameters(self):
         return [self.fc3.weight, self.fc3.bias]


-# Base class for tests requiring seeding for determinism
-class BaseTestCase(parameterized.TestCase):
+class NormalizedOptimizerConvergenceTest(parameterized.TestCase):
+    """Convergence tests for normalized optimizers on a simple MLP task."""
+
     def setUp(self):
         """Set random seed before each test."""
         # Set seed for PyTorch
@@ -71,19 +72,17 @@ def setUp(self):
             torch.cuda.manual_seed_all(1234)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-
-class NormalizedOptimizerConvergenceTest(BaseTestCase):
-    """Convergence tests for normalized optimizers on a simple MLP task."""
-
-    def _create_synthetic_mnist_data(self, num_samples=1000):
+    def _create_synthetic_mnist_data(self, num_samples: int = 1000) -> TensorDataset:
         """Create synthetic MNIST-like data for testing."""
         torch.manual_seed(1234)
         X = torch.randn(num_samples, 784, device=self.device)
         # Uniformly random integer class targets in [0, 10)
         y = torch.randint(0, 10, (num_samples,))
         return TensorDataset(X, y)

-    def _train_model(self, model, optimizer_class, optimizer_kwargs, num_epochs=5):
+    def _train_model(
+        self, model: SimpleMLP, optimizer_class: type[torch.optim.Optimizer], optimizer_kwargs: dict, num_epochs: int = 5
+    ) -> tuple[float, float, float]:
         """Train model with given optimizer and return final loss and accuracy."""
         # Create data
         dataset = self._create_synthetic_mnist_data(num_samples=500)
@@ -140,7 +139,7 @@ def _train_model(self, model, optimizer_class, optimizer_kwargs, num_epochs=5):

         return initial_loss, final_loss, final_accuracy

-    def _verify_norms_preserved(self, model):
+    def _verify_norms_preserved(self, model: SimpleMLP) -> None:
         """Verify that oblique parameters maintain unit column norms."""
         for param in model.get_oblique_parameters():
             column_norms = param.data.norm(dim=0)  # Column norms
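Note: a minimal standalone sketch of the invariant this helper asserts; the tensor and shapes below are illustrative, not part of the PR.

```python
import torch

# Columns projected onto the oblique manifold: each column of W has unit norm.
W = torch.randn(784, 64)
W = W / W.norm(dim=0, keepdim=True)  # normalize along dim=0 (columns)

# The check _verify_norms_preserved performs after training:
# every column norm should still be (numerically) 1.
torch.testing.assert_close(
    W.norm(dim=0),
    torch.ones(W.shape[1]),
    atol=1e-5,
    rtol=1e-5,
)
```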
@@ -152,7 +151,7 @@ def _verify_norms_preserved(self, model):
                 rtol=1e-5,
             )

-    def test_oblique_sgd_convergence(self):
+    def test_oblique_sgd_convergence(self) -> None:
         """Test that ObliqueSGD can train a simple MLP and maintain norms."""
         model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10)

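For readers skimming the diff, a hedged sketch of how such a test drives the optimizer. `_train_model`'s body is outside these hunks, so the two-optimizer wiring and the `ObliqueSGD(params, lr=..., dim=...)` signature are assumptions inferred from the surrounding code, not the PR's actual implementation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumed wiring: ObliqueSGD updates the norm-constrained parameters,
# a plain SGD updates the unconstrained ones.
model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10)
oblique_opt = ObliqueSGD(model.get_oblique_parameters(), lr=0.1, dim=0)
regular_opt = torch.optim.SGD(model.get_regular_parameters(), lr=0.1)

data = TensorDataset(torch.randn(64, 784), torch.randint(0, 10, (64,)))
for X, y in DataLoader(data, batch_size=32):
    oblique_opt.zero_grad()
    regular_opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(X), y)
    loss.backward()
    oblique_opt.step()  # gradient step plus renormalization along `dim`
    regular_opt.step()
```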
@@ -168,7 +167,7 @@ def test_oblique_sgd_convergence(self):
         # Check norm preservation
         self._verify_norms_preserved(model)

-    def test_oblique_adam_convergence(self):
+    def test_oblique_adam_convergence(self) -> None:
         """Test that ObliqueAdam can train a simple MLP and maintain norms."""
         model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10)

@@ -190,7 +189,7 @@ def test_oblique_adam_convergence(self):
190189 ("adam_col" , ObliqueAdam , {"lr" : 0.1 , "betas" : (0.9 , 0.999 ), "weight_decay" : 0.1 , "dim" : 0 }),
191190 ("adam_row" , ObliqueAdam , {"lr" : 0.1 , "betas" : (0.9 , 0.999 ), "weight_decay" : 0.1 , "dim" : 1 }),
192191 )
193- def test_optimizer_modes_convergence (self , optimizer_class , optimizer_kwargs ) :
192+ def test_optimizer_modes_convergence (self , optimizer_class : torch . optim . Optimizer , optimizer_kwargs : dict ) -> None :
194193 """Test that both row and column modes work for both optimizers."""
195194 model = SimpleMLP (input_size = 784 , hidden_size = 32 , num_classes = 10 )
196195
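A hedged reading of the `dim` argument in the parameterized cases above, illustrated on a plain tensor (names below are not from the PR): `dim=0` constrains column norms, `dim=1` constrains row norms.

```python
import torch

W = torch.randn(32, 784)

# dim=0: each column of W rescaled to unit norm (the "col" cases).
W_col = W / W.norm(dim=0, keepdim=True)
assert torch.allclose(W_col.norm(dim=0), torch.ones(784))

# dim=1: each row of W rescaled to unit norm (the "row" cases).
W_row = W / W.norm(dim=1, keepdim=True)
assert torch.allclose(W_row.norm(dim=1), torch.ones(32))
```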