diff --git a/tests/ci/L0_Tests_CPU.sh b/tests/ci/L0_Tests_CPU.sh
index afd9f1a..594773f 100644
--- a/tests/ci/L0_Tests_CPU.sh
+++ b/tests/ci/L0_Tests_CPU.sh
@@ -15,3 +15,5 @@ export TORCH_COMPILE_DISABLE=1
 set -o pipefail
 torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
 torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
+coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu
+
diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh
index bd1a818..9d80bcf 100644
--- a/tests/ci/L0_Tests_GPU.sh
+++ b/tests/ci/L0_Tests_GPU.sh
@@ -20,5 +20,5 @@ coverage run -p --source=emerging_optimizers tests/test_soap_functions.py
 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py
 coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
-coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py
+coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda
 coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
\ No newline at end of file
diff --git a/tests/ci/L1_Tests_GPU.sh b/tests/ci/L1_Tests_GPU.sh
index 7af079e..add9c89 100644
--- a/tests/ci/L1_Tests_GPU.sh
+++ b/tests/ci/L1_Tests_GPU.sh
@@ -18,5 +18,5 @@ python tests/test_orthogonalized_optimizer.py
 python tests/test_soap_functions.py
 python tests/test_soap_utils.py
 python tests/soap_smoke_test.py
-python tests/test_scalar_optimizers.py
-python tests/test_spectral_clipping_utils.py
\ No newline at end of file
+python tests/test_scalar_optimizers.py --device=cuda
+python tests/test_spectral_clipping_utils.py
diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py
index 07a6f88..8e8244b 100644
--- a/tests/test_scalar_optimizers.py
+++ b/tests/test_scalar_optimizers.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from absl import flags
 from absl.testing import absltest, parameterized

 from emerging_optimizers.scalar_optimizers import (
@@ -23,22 +24,29 @@
 )


-# Base class for tests requiring seeding for determinism
-class BaseTestCase(parameterized.TestCase):
+# Define command line flags
+flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
+flags.DEFINE_integer("seed", 42, "Random seed for reproducible tests")
+
+FLAGS = flags.FLAGS
+
+
+class ScalarOptimizerTest(parameterized.TestCase):
     def setUp(self):
-        """Set random seed before each test."""
-        # Set seed for PyTorch
-        torch.manual_seed(42)
+        """Set random seed and device before each test."""
+        # Set seed for PyTorch (using seed from flags)
+        torch.manual_seed(FLAGS.seed)
         # Set seed for CUDA if available
         if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(42)
+            torch.cuda.manual_seed_all(FLAGS.seed)
+        # Set up device based on flags
+        self.device = FLAGS.device


-class ScalarOptimizerTest(BaseTestCase):
     def test_calculate_adam_update_simple(self) -> None:
-        exp_avg_initial = torch.tensor([[1.0]])
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_initial = torch.tensor([[1.0]], device=self.device)
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.9, 0.99)
         eps = 1e-8
         step = 10
@@ -59,7 +67,7 @@ def test_calculate_adam_update_simple(self) -> None:
             eps=eps,
         )

-        initial_param_val_tensor = torch.tensor([[10.0]])
+        initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
         param = torch.nn.Parameter(initial_param_val_tensor.clone())
         param.grad = grad.clone()

@@ -73,7 +81,7 @@ def test_calculate_adam_update_simple(self) -> None:
         )

         # Manually set Adam's internal state to match conditions *before* the current update
-        adam_optimizer.state[param]["step"] = torch.tensor(float(step - 1))
+        adam_optimizer.state[param]["step"] = torch.tensor(float(step - 1), device=self.device)
         adam_optimizer.state[param]["exp_avg"] = exp_avg_initial.clone()
         adam_optimizer.state[param]["exp_avg_sq"] = exp_avg_sq_initial.clone()

@@ -96,9 +104,9 @@ def test_calculate_adam_update_simple(self) -> None:

     def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None:
         # LaProp with momentum (beta1) = 0 should be equivalent to RMSProp.
-        exp_avg_initial = torch.tensor([[0.0]])  # Momentum is 0, so exp_avg starts at 0
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_initial = torch.tensor([[0.0]], device=self.device)  # Momentum is 0, so exp_avg starts at 0
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.0, 0.99)  # beta1=0 for momentum
         eps = 1e-8
         step = 10
@@ -119,7 +127,7 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None
         )

         # Manually verify with RMSProp logic
-        initial_param_val_tensor = torch.tensor([[10.0]])
+        initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
         param = torch.nn.Parameter(initial_param_val_tensor.clone())
         param.grad = grad.clone()

@@ -134,7 +142,7 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None
         )

         # Manually set RMSProp's internal state
-        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step))
+        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step), device=self.device)
         rmsprop_optimizer.state[param]["square_avg"] = exp_avg_sq_initial.clone()
         rmsprop_optimizer.state[param]["momentum_buffer"] = exp_avg_initial.clone()

@@ -150,10 +158,10 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None

     def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
         # AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam.
-        exp_avg_fast_initial = torch.tensor([[1.0]])
-        exp_avg_slow_initial = torch.tensor([[1.0]])
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device)
+        exp_avg_slow_initial = torch.tensor([[1.0]], device=self.device)
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.9, 0.99, 0.999)
         eps = 1e-8
         step = 10
@@ -195,9 +203,9 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:

     def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop(self) -> None:
         # sim_ademamix with momentum (beta_fast) = 0 and alpha = 0 should be equivalent to RMSProp.
-        exp_avg_initial = torch.tensor([[0.0]])  # Momentum is 0, so exp_avg starts at 0
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_initial = torch.tensor([[0.0]], device=self.device)  # Momentum is 0, so exp_avg starts at 0
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.0, 0.99)  # beta1=0 for momentum
         eps = 1e-8
         step = 10
@@ -221,7 +229,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr
         )

         # Manually verify with RMSProp logic
-        initial_param_val_tensor = torch.tensor([[10.0]])
+        initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
         param = torch.nn.Parameter(initial_param_val_tensor.clone())
         param.grad = grad.clone()

@@ -236,7 +244,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr
         )

         # Manually set RMSProp's internal state
-        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step))
+        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step), device=self.device)
         rmsprop_optimizer.state[param]["square_avg"] = exp_avg_sq_initial.clone()

         rmsprop_optimizer.step()
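
Usage note: the new --device and --seed flags are parsed on the command line (assuming the file's usual absltest.main() entry point, which the CI invocations above rely on); --device defaults to "cpu" and --seed defaults to 42, preserving the previous hard-coded seeding. A minimal sketch of the invocations, mirroring the CI scripts in this patch (the explicit --seed value is illustrative):

    # GPU run, as in tests/ci/L1_Tests_GPU.sh
    python tests/test_scalar_optimizers.py --device=cuda

    # CPU run under coverage, as in tests/ci/L0_Tests_CPU.sh, with an explicit seed
    coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu --seed=42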