Merged
2 changes: 2 additions & 0 deletions tests/ci/L0_Tests_CPU.sh
@@ -15,3 +15,5 @@ export TORCH_COMPILE_DISABLE=1
set -o pipefail
torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu

2 changes: 1 addition & 1 deletion tests/ci/L0_Tests_GPU.sh
@@ -20,5 +20,5 @@ coverage run -p --source=emerging_optimizers tests/test_soap_functions.py
coverage run -p --source=emerging_optimizers tests/test_soap_utils.py
coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda
coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
4 changes: 2 additions & 2 deletions tests/ci/L1_Tests_GPU.sh
@@ -18,5 +18,5 @@ python tests/test_orthogonalized_optimizer.py
python tests/test_soap_functions.py
python tests/test_soap_utils.py
python tests/soap_smoke_test.py
python tests/test_scalar_optimizers.py
python tests/test_spectral_clipping_utils.py
python tests/test_scalar_optimizers.py --device=cuda
python tests/test_spectral_clipping_utils.py
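
All three CI scripts now forward a `--device` argument to `tests/test_scalar_optimizers.py` (both `coverage run` and plain `python` pass anything after the script path through to the script's own argv). The test file's `__main__` block is not part of this diff, so the following is only a minimal sketch, assuming the standard `absltest.main()` entry point, of how a flag defined with `absl.flags` becomes available inside the tests; the class and test names here are illustrative, not from the repository:

```python
# Minimal sketch (assumption: the real test file ends with the usual absltest entry point).
# absltest.main() parses sys.argv, so flags defined with absl.flags (e.g. --device)
# are populated before setUp() and the test methods run.
import torch
from absl import flags
from absl.testing import absltest, parameterized

flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
FLAGS = flags.FLAGS


class DevicePlacementTest(parameterized.TestCase):
    def setUp(self):
        # FLAGS is already parsed by absltest.main() at this point.
        self.device = FLAGS.device

    def test_tensor_lands_on_requested_device(self) -> None:
        x = torch.zeros(3, device=self.device)
        self.assertEqual(x.device.type, torch.device(self.device).type)


if __name__ == "__main__":
    absltest.main()  # e.g. `python this_test.py --device=cuda`
```
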
66 changes: 40 additions & 26 deletions tests/test_scalar_optimizers.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from absl import flags
from absl.testing import absltest, parameterized

from emerging_optimizers.scalar_optimizers import (
@@ -23,22 +24,35 @@
)


# Base class for tests requiring seeding for determinism
class BaseTestCase(parameterized.TestCase):
# Define command line flags
flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
flags.DEFINE_integer("seed", 42, "Random seed for reproducible tests")

FLAGS = flags.FLAGS


class ScalarOptimizerTest(parameterized.TestCase):
def setUp(self):
"""Set random seed before each test."""
# Set seed for PyTorch
torch.manual_seed(42)
"""Set random seed and device before each test."""
# Set seed for PyTorch (using seed from flags)
torch.manual_seed(FLAGS.seed)
# Set seed for CUDA if available
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42)
torch.cuda.manual_seed_all(FLAGS.seed)

# Set up device based on flags
self.device = FLAGS.device

class ScalarOptimizerTest(BaseTestCase):
def test_calculate_adam_update_simple(self) -> None:
exp_avg_initial = torch.tensor([[1.0]])
exp_avg_sq_initial = torch.tensor([[2.0]])
grad = torch.tensor([[0.5]])
exp_avg_initial = torch.tensor([[1.0]], device=self.device)
exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
grad = torch.tensor([[0.5]], device=self.device)

# Move tensors to the test device
exp_avg_initial = exp_avg_initial.to(self.device)
exp_avg_sq_initial = exp_avg_sq_initial.to(self.device)
grad = grad.to(self.device)

betas = (0.9, 0.99)
eps = 1e-8
step = 10
@@ -59,7 +73,7 @@ def test_calculate_adam_update_simple(self) -> None:
eps=eps,
)

initial_param_val_tensor = torch.tensor([[10.0]])
initial_param_val_tensor = torch.tensor([[10.0]]).to(self.device)
param = torch.nn.Parameter(initial_param_val_tensor.clone())
param.grad = grad.clone()

@@ -73,7 +87,7 @@ def test_calculate_adam_update_simple(self) -> None:
)

# Manually set Adam's internal state to match conditions *before* the current update
adam_optimizer.state[param]["step"] = torch.tensor(float(step - 1))
adam_optimizer.state[param]["step"] = torch.tensor(float(step - 1), device=self.device)
adam_optimizer.state[param]["exp_avg"] = exp_avg_initial.clone()
adam_optimizer.state[param]["exp_avg_sq"] = exp_avg_sq_initial.clone()

@@ -96,9 +110,9 @@ def test_calculate_adam_update_simple(self) -> None:

def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None:
# LaProp with momentum (beta1) = 0 should be equivalent to RMSProp.
exp_avg_initial = torch.tensor([[0.0]]) # Momentum is 0, so exp_avg starts at 0
exp_avg_sq_initial = torch.tensor([[2.0]])
grad = torch.tensor([[0.5]])
exp_avg_initial = torch.tensor([[0.0]], device=self.device) # Momentum is 0, so exp_avg starts at 0
exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
grad = torch.tensor([[0.5]], device=self.device)
betas = (0.0, 0.99) # beta1=0 for momentum
eps = 1e-8
step = 10
@@ -119,7 +133,7 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None:
)

# Manually verify with RMSProp logic
initial_param_val_tensor = torch.tensor([[10.0]])
initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
param = torch.nn.Parameter(initial_param_val_tensor.clone())
param.grad = grad.clone()

@@ -134,7 +148,7 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None:
)

# Manually set RMSProp's internal state
rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step))
rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step), device=self.device)
rmsprop_optimizer.state[param]["square_avg"] = exp_avg_sq_initial.clone()
rmsprop_optimizer.state[param]["momentum_buffer"] = exp_avg_initial.clone()

@@ -150,10 +164,10 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None:

def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
# AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam.
exp_avg_fast_initial = torch.tensor([[1.0]])
exp_avg_slow_initial = torch.tensor([[1.0]])
exp_avg_sq_initial = torch.tensor([[2.0]])
grad = torch.tensor([[0.5]])
exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device)
exp_avg_slow_initial = torch.tensor([[1.0]], device=self.device)
exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
grad = torch.tensor([[0.5]], device=self.device)
betas = (0.9, 0.99, 0.999)
eps = 1e-8
step = 10
@@ -195,9 +209,9 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:

def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop(self) -> None:
# sim_ademamix with momentum (beta_fast) = 0 and alpha = 0 should be equivalent to RMSProp.
exp_avg_initial = torch.tensor([[0.0]]) # Momentum is 0, so exp_avg starts at 0
exp_avg_sq_initial = torch.tensor([[2.0]])
grad = torch.tensor([[0.5]])
exp_avg_initial = torch.tensor([[0.0]], device=self.device) # Momentum is 0, so exp_avg starts at 0
exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
grad = torch.tensor([[0.5]], device=self.device)
betas = (0.0, 0.99) # beta1=0 for momentum
eps = 1e-8
step = 10
@@ -221,7 +235,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop(self) -> None:
)

# Manually verify with RMSProp logic
initial_param_val_tensor = torch.tensor([[10.0]])
initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
param = torch.nn.Parameter(initial_param_val_tensor.clone())
param.grad = grad.clone()

@@ -236,7 +250,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop(self) -> None:
)

# Manually set RMSProp's internal state
rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step))
rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step), device=self.device)
rmsprop_optimizer.state[param]["square_avg"] = exp_avg_sq_initial.clone()

rmsprop_optimizer.step()
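
One case the merged change does not appear to handle is `--device=cuda` on a machine without a usable GPU; the CI scripts only pass `cuda` in GPU jobs, so this may be intentional. Purely as a hedged sketch (hypothetical, not part of this PR), a guard in `setUp()` could skip the tests cleanly instead of failing on the first tensor allocation:

```python
# Hypothetical guard, not part of the merged diff: skip cleanly when --device=cuda
# is requested but no CUDA runtime is visible to PyTorch.
import torch
from absl import flags
from absl.testing import absltest, parameterized

flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
FLAGS = flags.FLAGS


class GuardedDeviceTest(parameterized.TestCase):
    def setUp(self):
        if FLAGS.device == "cuda" and not torch.cuda.is_available():
            # skipTest() inside setUp marks the test as skipped rather than errored.
            self.skipTest("--device=cuda requested but torch.cuda.is_available() is False")
        self.device = FLAGS.device

    def test_device_flag_is_valid(self) -> None:
        self.assertIn(self.device, ("cpu", "cuda"))


if __name__ == "__main__":
    absltest.main()
```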