1- from typing import List
2-
31import numpy as np
2+ import pytest
43
54from pytorch_optimizer import AdamP , get_chebyshev_schedule
65from pytorch_optimizer .lr_scheduler .chebyshev import chebyshev_perm
76from pytorch_optimizer .lr_scheduler .cosine_anealing import CosineAnnealingWarmupRestarts
87from tests .utils import Example
98
10-
11- def test_cosine_annealing_warmup_restarts ():
12- model = Example ()
13- optimizer = AdamP (model .parameters ())
14-
15- def test_scheduler (
16- first_cycle_steps : int ,
17- cycle_mult : float ,
18- max_lr : float ,
19- min_lr : float ,
20- warmup_steps : int ,
21- gamma : float ,
22- max_epochs : int ,
23- expected_lrs : List [float ],
24- ):
25- lr_scheduler = CosineAnnealingWarmupRestarts (
26- optimizer = optimizer ,
27- first_cycle_steps = first_cycle_steps ,
28- cycle_mult = cycle_mult ,
29- max_lr = max_lr ,
30- min_lr = min_lr ,
31- warmup_steps = warmup_steps ,
32- gamma = gamma ,
33- )
34-
35- if warmup_steps > 0 :
36- np .testing .assert_almost_equal (min_lr , round (lr_scheduler .get_lr ()[0 ], 6 ))
37-
38- for epoch in range (max_epochs ):
39- lr_scheduler .step (epoch )
40-
41- lr : float = round (lr_scheduler .get_lr ()[0 ], 6 )
42- np .testing .assert_almost_equal (expected_lrs [epoch ], lr )
43-
44- # case 1
45- test_scheduler (
46- first_cycle_steps = 10 ,
47- cycle_mult = 1.0 ,
48- max_lr = 1e-3 ,
49- min_lr = 1e-6 ,
50- warmup_steps = 5 ,
51- gamma = 1.0 ,
52- max_epochs = 20 ,
53- expected_lrs = [
9+ CAWR_RECIPES = [
10+ (
11+ 10 ,
12+ 1.0 ,
13+ 1e-3 ,
14+ 1e-6 ,
15+ 5 ,
16+ 1.0 ,
17+ 20 ,
18+ [
5419 1e-06 ,
5520 0.000201 ,
5621 0.000401 ,
@@ -72,18 +37,16 @@ def test_scheduler(
7237 0.000346 ,
7338 9.6e-05 ,
7439 ],
75- )
76-
77- # case 2
78- test_scheduler (
79- first_cycle_steps = 10 ,
80- cycle_mult = 0.9 ,
81- max_lr = 1e-3 ,
82- min_lr = 1e-6 ,
83- warmup_steps = 5 ,
84- gamma = 0.5 ,
85- max_epochs = 20 ,
86- expected_lrs = [
40+ ),
41+ (
42+ 10 ,
43+ 0.9 ,
44+ 1e-3 ,
45+ 1e-6 ,
46+ 5 ,
47+ 0.5 ,
48+ 20 ,
49+ [
8750 1e-06 ,
8851 0.000201 ,
8952 0.000401 ,
@@ -105,8 +68,45 @@ def test_scheduler(
10568 7.4e-05 ,
10669 1e-06 ,
10770 ],
71+ ),
72+ ]
73+
74+
@pytest.mark.parametrize('cosine_annealing_warmup_restart_param', CAWR_RECIPES)
def test_cosine_annealing_warmup_restarts(cosine_annealing_warmup_restart_param):
    """Check the per-epoch learning rates of CosineAnnealingWarmupRestarts.

    Each entry of ``CAWR_RECIPES`` bundles a scheduler configuration together
    with the expected learning rate (rounded to 6 decimals) for every epoch.
    """
    (
        first_cycle_steps,
        cycle_mult,
        max_lr,
        min_lr,
        warmup_steps,
        gamma,
        max_epochs,
        expected_lrs,
    ) = cosine_annealing_warmup_restart_param

    optimizer = AdamP(Example().parameters())

    lr_scheduler = CosineAnnealingWarmupRestarts(
        optimizer=optimizer,
        first_cycle_steps=first_cycle_steps,
        cycle_mult=cycle_mult,
        max_lr=max_lr,
        min_lr=min_lr,
        warmup_steps=warmup_steps,
        gamma=gamma,
    )

    # with a warmup phase, the schedule must start from min_lr
    if warmup_steps > 0:
        np.testing.assert_almost_equal(min_lr, round(lr_scheduler.get_lr()[0], 6))

    for epoch in range(max_epochs):
        lr_scheduler.step(epoch)

        observed_lr: float = round(lr_scheduler.get_lr()[0], 6)
        np.testing.assert_almost_equal(expected_lrs[epoch], observed_lr)
109+
110110
111111def test_get_chebyshev_schedule ():
112112 np .testing .assert_almost_equal (get_chebyshev_schedule (3 ), 1.81818182 , decimal = 6 )
0 commit comments