File tree Expand file tree Collapse file tree 1 file changed +24
-0
lines changed Expand file tree Collapse file tree 1 file changed +24
-0
lines changed Original file line number Diff line number Diff line change 1010from pytorch_optimizer .lr_scheduler .experimental .deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
1111from pytorch_optimizer .lr_scheduler .linear_warmup import CosineScheduler , LinearScheduler , PolyScheduler
1212from pytorch_optimizer .lr_scheduler .proportion import ProportionScheduler
13+ from pytorch_optimizer .lr_scheduler .rex import REXScheduler
1314from tests .utils import Example
1415
1516CAWR_RECIPES = [
@@ -263,6 +264,29 @@ def test_proportion_no_last_lr_scheduler():
263264 np .testing .assert_almost_equal (2.0 , rho_scheduler .get_lr (), 6 )
264265
265266
267+ def test_rex_lr_scheduler ():
268+ lrs = [
269+ 0.888888 ,
270+ 0.749999 ,
271+ 0.571428 ,
272+ 0.333333 ,
273+ 0.0 ,
274+ ]
275+
276+ base_optimizer = AdamP (Example ().parameters ())
277+
278+ lr_scheduler = REXScheduler (
279+ base_optimizer ,
280+ total_steps = 5 ,
281+ max_lr = 1.0 ,
282+ min_lr = 0.0 ,
283+ )
284+
285+ for expected_lr in lrs :
286+ lr : float = lr_scheduler .step ()
287+ np .testing .assert_almost_equal (expected_lr , lr , 6 )
288+
289+
266290def test_deberta_v3_large_lr_scheduler ():
267291 model = nn .Sequential (* [nn .Linear (1 , 1 , bias = False ) for _ in range (400 )])
268292 deberta_v3_large_lr_scheduler (model )
You can’t perform that action at this time.
0 commit comments