
Commit e1a8def

Improve tests for scalar optimizers (#39)
* test on both CPU and GPU

Signed-off-by: mikail <[email protected]>
1 parent e7bc96e commit e1a8def
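The tests now take the target device from an absl command-line flag instead of hard-coding CPU tensors. Below is a minimal, self-contained sketch of that pattern; the ExampleDeviceTest class and its assertion are illustrative only (not part of this commit), while the flag name and default mirror the ones added in tests/test_scalar_optimizers.py.

# Minimal sketch of the device-flag pattern this commit adopts.
# ExampleDeviceTest and its assertion are hypothetical; only the flag name
# and default follow the real test file.
import torch
from absl import flags
from absl.testing import absltest, parameterized

flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
FLAGS = flags.FLAGS


class ExampleDeviceTest(parameterized.TestCase):
    def setUp(self):
        # Seed RNGs for determinism and record the requested device.
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(42)
        self.device = FLAGS.device

    def test_tensor_on_requested_device(self):
        x = torch.tensor([[1.0]], device=self.device)
        self.assertEqual(x.device.type, torch.device(self.device).type)


if __name__ == "__main__":
    # absltest.main() parses absl flags, so `--device=cuda` is picked up here.
    absltest.main()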

4 files changed: +45 / -29 lines changed


tests/ci/L0_Tests_CPU.sh

Lines changed: 2 additions & 0 deletions
@@ -15,3 +15,5 @@ export TORCH_COMPILE_DISABLE=1
 set -o pipefail
 torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
 torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
+coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu
+

tests/ci/L0_Tests_GPU.sh

Lines changed: 1 addition & 1 deletion
@@ -20,5 +20,5 @@ coverage run -p --source=emerging_optimizers tests/test_soap_functions.py
 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py
 coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
-coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py
+coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda
 coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py

tests/ci/L1_Tests_GPU.sh

Lines changed: 2 additions & 2 deletions
@@ -18,5 +18,5 @@ python tests/test_orthogonalized_optimizer.py
 python tests/test_soap_functions.py
 python tests/test_soap_utils.py
 python tests/soap_smoke_test.py
-python tests/test_scalar_optimizers.py
-python tests/test_spectral_clipping_utils.py
+python tests/test_scalar_optimizers.py --device=cuda
+python tests/test_spectral_clipping_utils.py

tests/test_scalar_optimizers.py

Lines changed: 40 additions & 26 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from absl import flags
 from absl.testing import absltest, parameterized
 
 from emerging_optimizers.scalar_optimizers import (
@@ -23,22 +24,35 @@
 )
 
 
-# Base class for tests requiring seeding for determinism
-class BaseTestCase(parameterized.TestCase):
+# Define command line flags
+flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
+flags.DEFINE_integer("seed", 42, "Random seed for reproducible tests")
+
+FLAGS = flags.FLAGS
+
+
+class ScalarOptimizerTest(parameterized.TestCase):
     def setUp(self):
-        """Set random seed before each test."""
-        # Set seed for PyTorch
-        torch.manual_seed(42)
+        """Set random seed and device before each test."""
+        # Set seed for PyTorch (using seed from flags)
+        torch.manual_seed(FLAGS.seed)
         # Set seed for CUDA if available
         if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(42)
+            torch.cuda.manual_seed_all(FLAGS.seed)
 
+        # Set up device based on flags
+        self.device = FLAGS.device
 
-class ScalarOptimizerTest(BaseTestCase):
     def test_calculate_adam_update_simple(self) -> None:
-        exp_avg_initial = torch.tensor([[1.0]])
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_initial = torch.tensor([[1.0]], device=self.device)
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
+
+        # Move tensors to the test device
+        exp_avg_initial = exp_avg_initial.to(self.device)
+        exp_avg_sq_initial = exp_avg_sq_initial.to(self.device)
+        grad = grad.to(self.device)
+
         betas = (0.9, 0.99)
         eps = 1e-8
         step = 10
@@ -59,7 +73,7 @@ def test_calculate_adam_update_simple(self) -> None:
             eps=eps,
         )
 
-        initial_param_val_tensor = torch.tensor([[10.0]])
+        initial_param_val_tensor = torch.tensor([[10.0]]).to(self.device)
         param = torch.nn.Parameter(initial_param_val_tensor.clone())
         param.grad = grad.clone()
 
@@ -73,7 +87,7 @@ def test_calculate_adam_update_simple(self) -> None:
         )
 
         # Manually set Adam's internal state to match conditions *before* the current update
-        adam_optimizer.state[param]["step"] = torch.tensor(float(step - 1))
+        adam_optimizer.state[param]["step"] = torch.tensor(float(step - 1), device=self.device)
         adam_optimizer.state[param]["exp_avg"] = exp_avg_initial.clone()
         adam_optimizer.state[param]["exp_avg_sq"] = exp_avg_sq_initial.clone()
 
@@ -96,9 +110,9 @@ def test_calculate_adam_update_simple(self) -> None:
 
     def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None:
         # LaProp with momentum (beta1) = 0 should be equivalent to RMSProp.
-        exp_avg_initial = torch.tensor([[0.0]])  # Momentum is 0, so exp_avg starts at 0
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_initial = torch.tensor([[0.0]], device=self.device)  # Momentum is 0, so exp_avg starts at 0
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.0, 0.99)  # beta1=0 for momentum
         eps = 1e-8
         step = 10
@@ -119,7 +133,7 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None
         )
 
         # Manually verify with RMSProp logic
-        initial_param_val_tensor = torch.tensor([[10.0]])
+        initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
         param = torch.nn.Parameter(initial_param_val_tensor.clone())
         param.grad = grad.clone()
 
@@ -134,7 +148,7 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None
         )
 
         # Manually set RMSProp's internal state
-        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step))
+        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step), device=self.device)
         rmsprop_optimizer.state[param]["square_avg"] = exp_avg_sq_initial.clone()
         rmsprop_optimizer.state[param]["momentum_buffer"] = exp_avg_initial.clone()
 
@@ -150,10 +164,10 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None
 
     def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
        # AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam.
-        exp_avg_fast_initial = torch.tensor([[1.0]])
-        exp_avg_slow_initial = torch.tensor([[1.0]])
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device)
+        exp_avg_slow_initial = torch.tensor([[1.0]], device=self.device)
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.9, 0.99, 0.999)
         eps = 1e-8
         step = 10
@@ -195,9 +209,9 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
 
     def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop(self) -> None:
         # sim_ademamix with momentum (beta_fast) = 0 and alpha = 0 should be equivalent to RMSProp.
-        exp_avg_initial = torch.tensor([[0.0]])  # Momentum is 0, so exp_avg starts at 0
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_initial = torch.tensor([[0.0]], device=self.device)  # Momentum is 0, so exp_avg starts at 0
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.0, 0.99)  # beta1=0 for momentum
         eps = 1e-8
         step = 10
@@ -221,7 +235,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr
         )
 
         # Manually verify with RMSProp logic
-        initial_param_val_tensor = torch.tensor([[10.0]])
+        initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
         param = torch.nn.Parameter(initial_param_val_tensor.clone())
         param.grad = grad.clone()
 
@@ -236,7 +250,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr
         )
 
         # Manually set RMSProp's internal state
-        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step))
+        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step), device=self.device)
         rmsprop_optimizer.state[param]["square_avg"] = exp_avg_sq_initial.clone()
 
         rmsprop_optimizer.step()
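
The hunks above do not include the module's entry point. Assuming the file ends with the usual absltest.main() call (an assumption, not shown in this diff), absl parses the new flags before the tests run, which is what lets the CI scripts pass --device through to FLAGS.device:

# Assumed entry point (not shown in the hunks above). absltest.main() parses
# absl flags, so invocations like these reach FLAGS.device:
#
#   coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu
#   python tests/test_scalar_optimizers.py --device=cuda
from absl.testing import absltest

if __name__ == "__main__":
    absltest.main()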
