2727)
2828
2929
30+ @pytest .fixture (scope = 'function' )
31+ def environment ():
32+ return build_environment ()
33+
34+
3035@pytest .mark .parametrize ('optimizer_fp32_config' , OPTIMIZERS , ids = ids )
31- def test_f32_optimizers (optimizer_fp32_config ):
36+ def test_f32_optimizers (optimizer_fp32_config , environment ):
3237 def closure (x ):
3338 def _closure () -> float :
3439 return x
3540
3641 return _closure
3742
38- (x_data , y_data ), model , loss_fn = build_environment ()
43+ (x_data , y_data ), model , loss_fn = environment
3944
4045 optimizer_class , config , iterations = optimizer_fp32_config
4146
@@ -62,17 +67,14 @@ def _closure() -> float:
6267
6368 loss .backward ()
6469
65- if optimizer_name == 'AliG' :
66- optimizer .step (closure (loss ))
67- else :
68- optimizer .step ()
70+ optimizer .step (closure (loss ) if optimizer_name == 'AliG' else None )
6971
7072 assert tensor_to_numpy (init_loss ) > 1.5 * tensor_to_numpy (loss )
7173
7274
7375@pytest .mark .parametrize ('pullback_momentum' , PULLBACK_MOMENTUM )
74- def test_lookahead (pullback_momentum ):
75- (x_data , y_data ), model , loss_fn = build_environment ()
76+ def test_lookahead (pullback_momentum , environment ):
77+ (x_data , y_data ), model , loss_fn = environment
7678
7779 optimizer = Lookahead (load_optimizer ('adamp' )(model .parameters (), lr = 5e-1 ), pullback_momentum = pullback_momentum )
7880
@@ -94,8 +96,8 @@ def test_lookahead(pullback_momentum):
9496
9597
9698@pytest .mark .parametrize ('adaptive' , ADAPTIVE_FLAGS )
97- def test_sam_optimizers (adaptive ):
98- (x_data , y_data ), model , loss_fn = build_environment ()
99+ def test_sam_optimizers (adaptive , environment ):
100+ (x_data , y_data ), model , loss_fn = environment
99101
100102 optimizer = SAM (model .parameters (), load_optimizer ('asgd' ), lr = 5e-1 , adaptive = adaptive )
101103
@@ -115,8 +117,8 @@ def test_sam_optimizers(adaptive):
115117
116118
117119@pytest .mark .parametrize ('adaptive' , ADAPTIVE_FLAGS )
118- def test_sam_optimizers_with_closure (adaptive ):
119- (x_data , y_data ), model , loss_fn = build_environment ()
120+ def test_sam_optimizers_with_closure (adaptive , environment ):
121+ (x_data , y_data ), model , loss_fn = environment
120122
121123 optimizer = SAM (model .parameters (), load_optimizer ('adamp' ), lr = 5e-1 , adaptive = adaptive )
122124
@@ -140,14 +142,10 @@ def closure():
140142
141143
142144@pytest .mark .parametrize ('adaptive' , ADAPTIVE_FLAGS )
143- def test_gsam_optimizers (adaptive ):
145+ def test_gsam_optimizers (adaptive , environment ):
144146 pytest .skip ('skip GSAM optimizer' )
145147
146- (x_data , y_data ), model , loss_fn = build_environment ()
147-
148- x_data = x_data .cuda ()
149- y_data = y_data .cuda ()
150- model .cuda ()
148+ (x_data , y_data ), model , loss_fn = environment
151149
152150 lr : float = 5e-1
153151 num_iterations : int = 25
@@ -174,8 +172,8 @@ def test_gsam_optimizers(adaptive):
174172
175173
176174@pytest .mark .parametrize ('optimizer_config' , ADANORM_SUPPORTED_OPTIMIZERS , ids = ids )
177- def test_adanorm_optimizers (optimizer_config ):
178- (x_data , y_data ), model , loss_fn = build_environment ()
175+ def test_adanorm_optimizers (optimizer_config , environment ):
176+ (x_data , y_data ), model , loss_fn = environment
179177
180178 optimizer_class , config , num_iterations = optimizer_config
181179 if optimizer_class .__name__ == 'Ranger21' :
@@ -215,8 +213,8 @@ def test_adanorm_condition(optimizer_config):
215213
216214
217215@pytest .mark .parametrize ('optimizer_config' , ADAMD_SUPPORTED_OPTIMIZERS , ids = ids )
218- def test_adamd_optimizers (optimizer_config ):
219- (x_data , y_data ), model , loss_fn = build_environment ()
216+ def test_adamd_optimizers (optimizer_config , environment ):
217+ (x_data , y_data ), model , loss_fn = environment
220218
221219 optimizer_class , config , num_iterations = optimizer_config
222220 if optimizer_class .__name__ == 'Ranger21' :
@@ -343,13 +341,14 @@ def test_d_adapt_reset(require_gradient, sparse_gradient, optimizer_name):
343341 param .grad = None
344342
345343 optimizer = load_optimizer (optimizer_name )([param ])
346- assert str (optimizer ) == optimizer_name
347344 optimizer .reset ()
348345
346+ assert str (optimizer ) == optimizer_name
347+
349348
350349@pytest .mark .parametrize ('pre_conditioner_type' , [0 , 1 , 2 ])
351- def test_scalable_shampoo_pre_conditioner_with_svd (pre_conditioner_type ):
352- (x_data , y_data ), _ , loss_fn = build_environment ()
350+ def test_scalable_shampoo_pre_conditioner_with_svd (pre_conditioner_type , environment ):
351+ (x_data , y_data ), _ , loss_fn = environment
353352
354353 model = nn .Sequential (
355354 nn .Linear (2 , 4096 ),
@@ -382,5 +381,6 @@ def test_sm3_make_sparse():
382381
383382def test_sm3_rank0 ():
384383 optimizer = load_optimizer ('sm3' )([simple_zero_rank_parameter (True )])
385- assert str (optimizer ) == 'SM3'
386384 optimizer .step ()
385+
386+ assert str (optimizer ) == 'SM3'
0 commit comments