Skip to content

Commit d321fac

Browse files
committed
refactor: test_cosine_annealing_warmup_restarts
1 parent 0e7ca8a commit d321fac

File tree

1 file changed

+58
-58
lines changed

1 file changed

+58
-58
lines changed

tests/test_lr_schedulers.py

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,21 @@
1-
from typing import List
2-
31
import numpy as np
2+
import pytest
43

54
from pytorch_optimizer import AdamP, get_chebyshev_schedule
65
from pytorch_optimizer.lr_scheduler.chebyshev import chebyshev_perm
76
from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts
87
from tests.utils import Example
98

10-
11-
def test_cosine_annealing_warmup_restarts():
12-
model = Example()
13-
optimizer = AdamP(model.parameters())
14-
15-
def test_scheduler(
16-
first_cycle_steps: int,
17-
cycle_mult: float,
18-
max_lr: float,
19-
min_lr: float,
20-
warmup_steps: int,
21-
gamma: float,
22-
max_epochs: int,
23-
expected_lrs: List[float],
24-
):
25-
lr_scheduler = CosineAnnealingWarmupRestarts(
26-
optimizer=optimizer,
27-
first_cycle_steps=first_cycle_steps,
28-
cycle_mult=cycle_mult,
29-
max_lr=max_lr,
30-
min_lr=min_lr,
31-
warmup_steps=warmup_steps,
32-
gamma=gamma,
33-
)
34-
35-
if warmup_steps > 0:
36-
np.testing.assert_almost_equal(min_lr, round(lr_scheduler.get_lr()[0], 6))
37-
38-
for epoch in range(max_epochs):
39-
lr_scheduler.step(epoch)
40-
41-
lr: float = round(lr_scheduler.get_lr()[0], 6)
42-
np.testing.assert_almost_equal(expected_lrs[epoch], lr)
43-
44-
# case 1
45-
test_scheduler(
46-
first_cycle_steps=10,
47-
cycle_mult=1.0,
48-
max_lr=1e-3,
49-
min_lr=1e-6,
50-
warmup_steps=5,
51-
gamma=1.0,
52-
max_epochs=20,
53-
expected_lrs=[
9+
CAWR_RECIPES = [
10+
(
11+
10,
12+
1.0,
13+
1e-3,
14+
1e-6,
15+
5,
16+
1.0,
17+
20,
18+
[
5419
1e-06,
5520
0.000201,
5621
0.000401,
@@ -72,18 +37,16 @@ def test_scheduler(
7237
0.000346,
7338
9.6e-05,
7439
],
75-
)
76-
77-
# case 2
78-
test_scheduler(
79-
first_cycle_steps=10,
80-
cycle_mult=0.9,
81-
max_lr=1e-3,
82-
min_lr=1e-6,
83-
warmup_steps=5,
84-
gamma=0.5,
85-
max_epochs=20,
86-
expected_lrs=[
40+
),
41+
(
42+
10,
43+
0.9,
44+
1e-3,
45+
1e-6,
46+
5,
47+
0.5,
48+
20,
49+
[
8750
1e-06,
8851
0.000201,
8952
0.000401,
@@ -105,8 +68,45 @@ def test_scheduler(
10568
7.4e-05,
10669
1e-06,
10770
],
71+
),
72+
]
73+
74+
75+
@pytest.mark.parametrize('cosine_annealing_warmup_restart_param', CAWR_RECIPES)
76+
def test_cosine_annealing_warmup_restarts(cosine_annealing_warmup_restart_param):
77+
model = Example()
78+
optimizer = AdamP(model.parameters())
79+
80+
(
81+
first_cycle_steps,
82+
cycle_mult,
83+
max_lr,
84+
min_lr,
85+
warmup_steps,
86+
gamma,
87+
max_epochs,
88+
expected_lrs,
89+
) = cosine_annealing_warmup_restart_param
90+
91+
lr_scheduler = CosineAnnealingWarmupRestarts(
92+
optimizer=optimizer,
93+
first_cycle_steps=first_cycle_steps,
94+
cycle_mult=cycle_mult,
95+
max_lr=max_lr,
96+
min_lr=min_lr,
97+
warmup_steps=warmup_steps,
98+
gamma=gamma,
10899
)
109100

101+
if warmup_steps > 0:
102+
np.testing.assert_almost_equal(min_lr, round(lr_scheduler.get_lr()[0], 6))
103+
104+
for epoch in range(max_epochs):
105+
lr_scheduler.step(epoch)
106+
107+
lr: float = round(lr_scheduler.get_lr()[0], 6)
108+
np.testing.assert_almost_equal(expected_lrs[epoch], lr)
109+
110110

111111
def test_get_chebyshev_schedule():
112112
np.testing.assert_almost_equal(get_chebyshev_schedule(3), 1.81818182, decimal=6)

0 commit comments

Comments
 (0)