File tree Expand file tree Collapse file tree 2 files changed +9
-2
lines changed
Expand file tree Collapse file tree 2 files changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -22,5 +22,5 @@ coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
2222coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
2323coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda
2424coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
25- coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py
25+ coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py --device=cuda
2626coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda
Original file line number Diff line number Diff line change 1414# limitations under the License.
1515
1616import torch
17+ from absl import flags
1718from absl .testing import absltest , parameterized
1819
1920from emerging_optimizers .riemannian_optimizers .normalized_optimizer import ObliqueAdam , ObliqueSGD
2021
2122
23+ # Define command line flags
24+ flags .DEFINE_string ("device" , "cpu" , "Device to run tests on: 'cpu' or 'cuda'" )
25+
26+ FLAGS = flags .FLAGS
27+
28+
2229class NormalizedOptimizerFunctionalTest (parameterized .TestCase ):
2330 """Tests for ObliqueSGD and ObliqueAdam optimizers that preserve row/column norms."""
2431
@@ -29,7 +36,7 @@ def setUp(self):
2936 # Set seed for CUDA if available
3037 if torch .cuda .is_available ():
3138 torch .cuda .manual_seed_all (1234 )
32- self .device = torch .device ( "cuda" if torch . cuda . is_available () else "cpu" )
39+ self .device = FLAGS .device
3340
3441 @parameterized .parameters (
3542 (0 ),
You can’t perform that action at this time.
0 commit comments