Skip to content

Commit 6eab498

Browse files
committed
moved model to device appropriately
Signed-off-by: mikail <mkhona@nvidia.com>
1 parent 344689f commit 6eab498

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/normalized_optimizer_convergence_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def _verify_norms_preserved(self, model: SimpleMLP) -> None:
149149

150150
def test_oblique_sgd_convergence(self) -> None:
151151
"""Test that ObliqueSGD can train a simple MLP and maintain norms."""
152-
model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10)
152+
model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10).to(self.device)
153153

154154
# Train with ObliqueSGD
155155
initial_loss, final_loss, final_accuracy = self._train_model(
@@ -165,7 +165,7 @@ def test_oblique_sgd_convergence(self) -> None:
165165

166166
def test_oblique_adam_convergence(self) -> None:
167167
"""Test that ObliqueAdam can train a simple MLP and maintain norms."""
168-
model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10)
168+
model = SimpleMLP(input_size=784, hidden_size=64, num_classes=10).to(self.device)
169169

170170
# Train with ObliqueAdam
171171
initial_loss, final_loss, final_accuracy = self._train_model(
@@ -187,7 +187,7 @@ def test_oblique_adam_convergence(self) -> None:
187187
)
188188
def test_optimizer_modes_convergence(self, optimizer_class: torch.optim.Optimizer, optimizer_kwargs: dict) -> None:
189189
"""Test that both row and column modes work for both optimizers."""
190-
model = SimpleMLP(input_size=784, hidden_size=32, num_classes=10)
190+
model = SimpleMLP(input_size=784, hidden_size=32, num_classes=10).to(self.device)
191191

192192
# Re-initialize for row normalization
193193
with torch.no_grad():

0 commit comments

Comments
 (0)