@@ -149,7 +149,7 @@ def _verify_norms_preserved(self, model: SimpleMLP) -> None:
 
     def test_oblique_sgd_convergence(self) -> None:
         """Test that ObliqueSGD can train a simple MLP and maintain norms."""
-        model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10)
+        model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10).to(self.device)
 
         # Train with ObliqueSGD
         initial_loss, final_loss, final_accuracy = self._train_model(
@@ -165,7 +165,7 @@ def test_oblique_sgd_convergence(self) -> None:
 
     def test_oblique_adam_convergence(self) -> None:
         """Test that ObliqueAdam can train a simple MLP and maintain norms."""
-        model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10)
+        model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10).to(self.device)
 
         # Train with ObliqueAdam
         initial_loss, final_loss, final_accuracy = self._train_model(
@@ -187,7 +187,7 @@ def test_oblique_adam_convergence(self) -> None:
     )
     def test_optimizer_modes_convergence(self, optimizer_class: torch.optim.Optimizer, optimizer_kwargs: dict) -> None:
         """Test that both row and column modes work for both optimizers."""
-        model = SimpleMLP(input_size=784, hidden_size=32, num_classes=10)
+        model = SimpleMLP(input_size=784, hidden_size=32, num_classes=10).to(self.device)
 
         # Re-initialize for row normalization
         with torch.no_grad():
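
For context, each changed line moves the freshly constructed model onto `self.device`, so the test case is expected to expose a `device` attribute. A minimal sketch of how such a fixture could look, assuming a `unittest.TestCase`-based suite and a `setUp` hook (the class name and placement here are assumptions, not taken from this diff):

```python
import unittest

import torch


class TestObliqueOptimizers(unittest.TestCase):  # hypothetical name for illustration
    def setUp(self) -> None:
        # Prefer CUDA when available so the .to(self.device) calls exercise
        # the same code path on both GPU and CPU-only CI machines.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

Creating the model on the CPU and then calling `.to(self.device)` keeps the tests device-agnostic without scattering device checks through each test method.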