@@ -13,6 +13,7 @@
     SGDP,
     AdaBelief,
     AdaBound,
+    Adai,
     AdamP,
     Adan,
     AdaPNM,
@@ -36,6 +37,7 @@
     dummy_closure,
     ids,
     make_dataset,
+    names,
     tensor_to_numpy,
 )
 
@@ -50,6 +52,10 @@
     (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'fixed_decay': True}, 100),
     (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'weight_decouple': False}, 100),
     (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True}, 100),
+    (Adai, {'lr': 1e-1, 'weight_decay': 0.0}, 200),
+    (Adai, {'lr': 1e-1, 'weight_decay': 0.0, 'dampening': 0.9}, 200),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False}, 200),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True}, 200),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3}, 100),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 100),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'nesterov': True}, 100),
@@ -84,7 +90,6 @@
     (Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True}, 100),
     (Adan, {'lr': 1e-0, 'weight_decay': 1e-3, 'use_gc': True, 'weight_decouple': True}, 100),
 ]
-
 ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 100),
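For context, each entry in the optimizer tables above is an (optimizer_class, kwargs, iterations) tuple. A minimal sketch of how such an entry might be consumed by the shared training-loop test, reusing build_environment as it appears in the hunks below; the loop body itself is an assumption, not code from this diff:

    def run_entry(entry):
        # unpack one parametrized case: class, constructor kwargs, step budget
        optimizer_class, config, iterations = entry
        (x_data, y_data), model, loss_fn = build_environment()
        optimizer = optimizer_class(model.parameters(), **config)
        for _ in range(iterations):
            optimizer.zero_grad()
            loss = loss_fn(model(x_data), y_data)
            loss.backward()
            optimizer.step()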
@@ -167,6 +172,7 @@ def test_safe_f16_optimizers(optimizer_fp16_config):
         or (optimizer_name == 'Nero')
         or (optimizer_name == 'Adan' and 'weight_decay' not in config)
         or (optimizer_name == 'RAdam')
+        or (optimizer_name == 'Adai')
     ):
         pytest.skip(f'skip {optimizer_name}')
 
@@ -195,8 +201,10 @@ def test_sam_optimizers(adaptive, optimizer_sam_config):
     (x_data, y_data), model, loss_fn = build_environment()
 
     optimizer_class, config, iterations = optimizer_sam_config
-    if optimizer_class.__name__ == 'Shampoo':
-        pytest.skip(f'skip {optimizer_class.__name__}')
+
+    optimizer_name: str = optimizer_class.__name__
+    if (optimizer_name == 'Shampoo') or (optimizer_name == 'Adai'):
+        pytest.skip(f'skip {optimizer_name}')
 
     optimizer = SAM(model.parameters(), optimizer_class, **config, adaptive=adaptive)
 
@@ -221,8 +229,10 @@ def test_sam_optimizers_with_closure(adaptive, optimizer_sam_config):
     (x_data, y_data), model, loss_fn = build_environment()
 
     optimizer_class, config, iterations = optimizer_sam_config
-    if optimizer_class.__name__ == 'Shampoo':
-        pytest.skip(f'skip {optimizer_class.__name__}')
+
+    optimizer_name: str = optimizer_class.__name__
+    if (optimizer_name == 'Shampoo') or (optimizer_name == 'Adai'):
+        pytest.skip(f'skip {optimizer_name}')
 
     optimizer = SAM(model.parameters(), optimizer_class, **config, adaptive=adaptive)
 
@@ -335,26 +345,31 @@ def test_no_gradients(optimizer_config):
     assert tensor_to_numpy(init_loss) >= tensor_to_numpy(loss)
 
 
-@pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)
-def test_closure(optimizer_config):
+@pytest.mark.parametrize('optimizer', set(config[0] for config in OPTIMIZERS), ids=names)
+def test_closure(optimizer):
     _, model, _ = build_environment()
 
-    optimizer_class, config, _ = optimizer_config
-    if optimizer_class.__name__ == 'Ranger21':
-        pytest.skip(f'skip {optimizer_class.__name__}')
-
-    optimizer = optimizer_class(model.parameters(), **config)
+    if optimizer.__name__ == 'Ranger21':
+        optimizer = optimizer(model.parameters(), num_iterations=1)
+    else:
+        optimizer = optimizer(model.parameters())
 
     optimizer.zero_grad()
-    optimizer.step(closure=dummy_closure)
+
+    try:
+        optimizer.step(closure=dummy_closure)
+    except ValueError:  # in case of Ranger21, Adai optimizers
+        pass
 
 
 @pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)
 def test_reset(optimizer_config):
     _, model, _ = build_environment()
 
     optimizer_class, config, _ = optimizer_config
-    optimizer = optimizer_class(model.parameters(), **config)
+    if optimizer_class.__name__ == 'Ranger21':
+        config.update({'num_iterations': 1})
 
+    optimizer = optimizer_class(model.parameters(), **config)
     optimizer.zero_grad()
     optimizer.reset()
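Note: dummy_closure and names, imported at the top of the file, come from the shared test utilities; their bodies are not part of this diff. A plausible minimal sketch of what they provide, hypothetical and for illustration only:

    def dummy_closure() -> float:
        # step(closure=...) only needs a callable that returns the loss
        return 0.0

    def names(optimizer_class) -> str:
        # human-readable pytest ids for the de-duplicated optimizer classes
        return optimizer_class.__name__

With ids=names, each test_closure case is labeled by its optimizer class name, and the try/except absorbs the ValueError that Ranger21 and Adai can raise from step(), per the inline comment above.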