@@ -200,17 +200,19 @@ def test_get_chebyshev_lr():
200200 optimizer .step ()
201201
202202 lr_scheduler = get_chebyshev_schedule (optimizer , num_epochs = 16 , is_warmup = True )
203- lr_scheduler .step (0 )
203+ lr_scheduler .last_epoch = 0
204+ lr_scheduler .step ()
204205
205206 np .testing .assert_almost_equal (lr_scheduler .get_last_lr (), 1e-3 )
206207
207208 optimizer = AdamW (Example ().parameters ())
208209 optimizer .step ()
209210
210211 lr_scheduler = get_chebyshev_schedule (optimizer , num_epochs = 16 , is_warmup = False )
212+ lr_scheduler .last_epoch = 0
211213
212- for i , expected_lr in enumerate ( recipes , start = 1 ) :
213- lr_scheduler .step (i )
214+ for expected_lr in recipes :
215+ lr_scheduler .step ()
214216 np .testing .assert_almost_equal (lr_scheduler .get_last_lr (), expected_lr )
215217
216218
@@ -311,10 +313,10 @@ def test_wsd_lr_scheduler():
311313
312314 lr_scheduler = get_wsd_schedule (optimizer , 2 , 2 , 3 , min_lr_ratio = 0.1 )
313315
314- expected_lrs = [0.0 , 0. 0005 , 0.001 , 0.001 , 0.001 , 0.000775 , 0.000325 , 0.0001 , 0.0001 , 0.0001 ]
316+ expected_lrs = [0.0005 , 0.001 , 0.001 , 0.001 , 0.000775 , 0.000325 , 0.0001 , 0.0001 , 0.0001 ]
315317
316- for step , expected_lr in enumerate ( expected_lrs ) :
317- lr_scheduler .step (step )
318+ for expected_lr in expected_lrs :
319+ lr_scheduler .step ()
318320 np .testing .assert_almost_equal (expected_lr , lr_scheduler .get_last_lr (), 6 )
319321
320322
0 commit comments