 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from absl import flags
 from absl.testing import absltest, parameterized
 
 from emerging_optimizers.scalar_optimizers import (
 )
 
 
-# Base class for tests requiring seeding for determinism
-class BaseTestCase(parameterized.TestCase):
+# Define command line flags
+flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
+flags.DEFINE_integer("seed", 42, "Random seed for reproducible tests")
+
+FLAGS = flags.FLAGS
+
+
+class ScalarOptimizerTest(parameterized.TestCase):
     def setUp(self):
-        """Set random seed before each test."""
-        # Set seed for PyTorch
-        torch.manual_seed(42)
+        """Set random seed and device before each test."""
+        # Set seed for PyTorch (using seed from flags)
+        torch.manual_seed(FLAGS.seed)
         # Set seed for CUDA if available
         if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(42)
+            torch.cuda.manual_seed_all(FLAGS.seed)
 
+        # Set up device based on flags
+        self.device = FLAGS.device
 
-class ScalarOptimizerTest(BaseTestCase):
     def test_calculate_adam_update_simple(self) -> None:
-        exp_avg_initial = torch.tensor([[1.0]])
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_initial = torch.tensor([[1.0]], device=self.device)
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
+
         betas = (0.9, 0.99)
         eps = 1e-8
         step = 10
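
Because the device and seed are absl flags, they are parsed at startup when the suite is launched through absltest.main() (the customary entry point for absl-based test files, assumed here), so both values can be overridden from the command line; a minimal sketch, with a hypothetical file name:

    python scalar_optimizers_test.py --device=cuda --seed=123
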
@@ -59,7 +73,7 @@ def test_calculate_adam_update_simple(self) -> None:
             eps=eps,
         )
 
-        initial_param_val_tensor = torch.tensor([[10.0]])
+        initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
         param = torch.nn.Parameter(initial_param_val_tensor.clone())
         param.grad = grad.clone()
 
@@ -73,7 +87,7 @@ def test_calculate_adam_update_simple(self) -> None:
         )
 
         # Manually set Adam's internal state to match conditions *before* the current update
-        adam_optimizer.state[param]["step"] = torch.tensor(float(step - 1))
+        adam_optimizer.state[param]["step"] = torch.tensor(float(step - 1), device=self.device)
         adam_optimizer.state[param]["exp_avg"] = exp_avg_initial.clone()
         adam_optimizer.state[param]["exp_avg_sq"] = exp_avg_sq_initial.clone()
 
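
For reference, the expected value this test reconstructs is the standard bias-corrected Adam step. A minimal sketch using the test's own constants, assuming the calculate_adam_update helper suggested by the test name follows the conventional formulation (an illustration, not the library's code):

    import math

    beta1, beta2, eps, t = 0.9, 0.99, 1e-8, 10
    g, m_prev, v_prev = 0.5, 1.0, 2.0           # grad, exp_avg_initial, exp_avg_sq_initial
    m = beta1 * m_prev + (1 - beta1) * g        # first-moment EMA after this step
    v = beta2 * v_prev + (1 - beta2) * g ** 2   # second-moment EMA after this step
    m_hat = m / (1 - beta1 ** t)                # bias corrections at step t
    v_hat = v / (1 - beta2 ** t)
    update = m_hat / (math.sqrt(v_hat) + eps)   # roughly 0.32 for these inputs
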
@@ -96,9 +110,9 @@ def test_calculate_adam_update_simple(self) -> None:
 
     def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None:
         # LaProp with momentum (beta1) = 0 should be equivalent to RMSProp.
-        exp_avg_initial = torch.tensor([[0.0]])  # Momentum is 0, so exp_avg starts at 0
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_initial = torch.tensor([[0.0]], device=self.device)  # Momentum is 0, so exp_avg starts at 0
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.0, 0.99)  # beta1=0 for momentum
         eps = 1e-8
         step = 10
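
The equivalence asserted here follows from LaProp applying momentum to the already-normalized gradient, so zero momentum leaves only the normalization; a sketch of the reduction, assuming the formulation of Ziyin et al. (2020) with bias corrections omitted:

    # LaProp:  v_t = beta2 * v_{t-1} + (1 - beta2) * g**2
    #          m_t = beta1 * m_{t-1} + (1 - beta1) * g / sqrt(v_t + eps)
    # With beta1 = 0 no history enters the momentum buffer, so
    #          m_t = g / sqrt(v_t + eps),
    # which matches RMSProp's normalized step (up to where eps enters the denominator).
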
@@ -119,7 +133,7 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None
         )
 
         # Manually verify with RMSProp logic
-        initial_param_val_tensor = torch.tensor([[10.0]])
+        initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
         param = torch.nn.Parameter(initial_param_val_tensor.clone())
         param.grad = grad.clone()
 
@@ -134,7 +148,7 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None
         )
 
         # Manually set RMSProp's internal state
-        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step))
+        rmsprop_optimizer.state[param]["step"] = torch.tensor(float(step), device=self.device)
         rmsprop_optimizer.state[param]["square_avg"] = exp_avg_sq_initial.clone()
         rmsprop_optimizer.state[param]["momentum_buffer"] = exp_avg_initial.clone()
 
@@ -150,10 +164,10 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None
 
     def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
         # AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam.
-        exp_avg_fast_initial = torch.tensor([[1.0]])
-        exp_avg_slow_initial = torch.tensor([[1.0]])
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device)
+        exp_avg_slow_initial = torch.tensor([[1.0]], device=self.device)
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.9, 0.99, 0.999)
         eps = 1e-8
         step = 10
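
AdEMAMix forms its numerator from a fast and a slow first-moment EMA, so alpha = 0 removes the slow component and the update collapses to Adam's; a sketch, assuming the update of Pagliardini et al. (2024) with bias corrections omitted:

    # AdEMAMix:  update = (m_fast + alpha * m_slow) / (sqrt(v) + eps)
    # With alpha = 0:  update = m_fast / (sqrt(v) + eps), i.e. the Adam step.
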
@@ -195,9 +209,9 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
 
     def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop(self) -> None:
         # sim_ademamix with momentum (beta_fast) = 0 and alpha = 0 should be equivalent to RMSProp.
-        exp_avg_initial = torch.tensor([[0.0]])  # Momentum is 0, so exp_avg starts at 0
-        exp_avg_sq_initial = torch.tensor([[2.0]])
-        grad = torch.tensor([[0.5]])
+        exp_avg_initial = torch.tensor([[0.0]], device=self.device)  # Momentum is 0, so exp_avg starts at 0
+        exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device)
+        grad = torch.tensor([[0.5]], device=self.device)
         betas = (0.0, 0.99)  # beta1=0 for momentum
         eps = 1e-8
         step = 10
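
This test stacks the two reductions sketched above: beta_fast = 0 removes the fast momentum term and alpha = 0 removes the slow EMA, so only the RMSProp-style normalized gradient remains:

    # With beta_fast = 0 and alpha = 0:  update = g / (sqrt(v) + eps)
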
@@ -221,7 +235,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr
         )
 
         # Manually verify with RMSProp logic
-        initial_param_val_tensor = torch.tensor([[10.0]])
+        initial_param_val_tensor = torch.tensor([[10.0]], device=self.device)
         param = torch.nn.Parameter(initial_param_val_tensor.clone())
         param.grad = grad.clone()
 
@@ -236,7 +250,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr
236250 )
237251
238252 # Manually set RMSProp's internal state
239- rmsprop_optimizer .state [param ]["step" ] = torch .tensor (float (step ))
253+ rmsprop_optimizer .state [param ]["step" ] = torch .tensor (float (step ), device = self . device )
240254 rmsprop_optimizer .state [param ]["square_avg" ] = exp_avg_sq_initial .clone ()
241255
242256 rmsprop_optimizer .step ()