@@ -62,16 +62,6 @@ def test_pcgrad_parameters():
6262 PCGrad (opt , reduction = 'invalid' )
6363
6464
65- def test_sam_parameters ():
66- with pytest .raises (ValueError ):
67- SAM (None , load_optimizer ('adamp' ), rho = - 0.1 )
68-
69-
70- def test_wsam_parameters ():
71- with pytest .raises (ValueError ):
72- WSAM (None , None , load_optimizer ('adamp' ), rho = - 0.1 )
73-
74-
7565def test_lookahead_parameters ():
7666 optimizer_instance = load_optimizer ('adamp' )
7767 optimizer = optimizer_instance ([simple_parameter ()])
@@ -96,22 +86,16 @@ def test_lookahead_parameters():
9686 Lookahead (optimizer , pullback_momentum = 'invalid' )
9787
9888
99- def test_sam_methods ():
100- optimizer = SAM ([simple_parameter ()], load_optimizer ('adamp' ))
101- optimizer .init_group ()
102- optimizer .load_state_dict (optimizer .state_dict ())
103-
104-
105- def test_wsam_methods ():
106- optimizer = WSAM (None , [simple_parameter ()], load_optimizer ('adamp' ))
107- optimizer .init_group ()
108- optimizer .load_state_dict (optimizer .state_dict ())
89+ @pytest .mark .parametrize ('optimizer' , [SAM , WSAM , LookSAM ])
90+ def test_sam_family_methods (optimizer ):
91+ base_optimizer = load_optimizer ('lion' )
10992
93+ opt = optimizer (params = [simple_parameter ()], model = None , base_optimizer = base_optimizer )
94+ opt .init_group ({})
95+ opt .load_state_dict (opt .state_dict ())
11096
111- def test_looksam_methods ():
112- optimizer = LookSAM ([simple_parameter ()], load_optimizer ('adamp' ))
113- optimizer .init_group ()
114- optimizer .load_state_dict (optimizer .state_dict ())
97+ with pytest .raises (ValueError ):
98+ optimizer (model = None , params = None , base_optimizer = base_optimizer , rho = - 0.1 )
11599
116100
117101def test_safe_fp16_methods ():
0 commit comments