@@ -107,18 +107,18 @@ def test_lookahead_k(optimizer_name):
107107 optimizer (None , k = - 1 )
108108
109109
110- @pytest .mark .parametrize ('optimizer_name' , ['ranger21' ])
111- def test_beta0 (optimizer_name ):
112- optimizer = load_optimizer (optimizer_name )
113- with pytest .raises (ValueError ):
114- optimizer (None , num_iterations = 200 , beta0 = - 0.1 )
115-
116-
117- @pytest .mark .parametrize ('optimizer_name' , ['nero' , 'apollo' , 'sm3' , 'msvag' ])
110+ @pytest .mark .parametrize ('optimizer_name' , ['nero' , 'apollo' , 'sm3' , 'msvag' , 'ranger21' ])
118111def test_beta (optimizer_name ):
119112 optimizer = load_optimizer (optimizer_name )
120- with pytest .raises (ValueError ):
121- optimizer (None , beta = - 0.1 )
113+
114+ if optimizer_name == 'ranger21' :
115+ # test beta0
116+ with pytest .raises (ValueError ):
117+ optimizer (None , num_iterations = 200 , beta0 = - 0.1 )
118+ else :
119+ # test beta
120+ with pytest .raises (ValueError ):
121+ optimizer (None , beta = - 0.1 )
122122
123123
124124@pytest .mark .parametrize ('optimizer_name' , BETA_OPTIMIZER_NAMES )
@@ -137,6 +137,9 @@ def test_betas(optimizer_name):
137137
138138 with pytest .raises (ValueError ):
139139 optimizer (None , ** config2 )
140+ elif optimizer_name == 'prodigy' :
141+ with pytest .raises (ValueError ):
142+ optimizer (None , beta3 = - 0.1 )
140143 else :
141144 with pytest .raises (ValueError ):
142145 optimizer (None , betas = (0.1 , 0.1 , - 0.1 ))
0 commit comments