Commit 02fc0af

[Feature] Implement more cooldown types for WSD lr scheduler (#386)
* feature: lots of cooldown types
* update: test case
* docs: v3.6.1 changelog
* build(deps): packages
1 parent 5aa4d13 commit 02fc0af

File tree: 6 files changed (+179, -103 lines)

docs/changelogs/v3.6.1.md

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,9 @@
 ## Change Log
 
+### Feature
+
+* Implement more cooldown types for WSD learning rate scheduler. (#382, #386)
+
 ### Fix
 
 * Fix to use `momentum buffer` instead of the gradient to calculate LMO. (#385)
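
For orientation before the code diff below: each cooldown shape maps the decay-phase progress p = (current_step - warmup_steps - stable_steps) / decay_steps, with p in [0, 1), to an LR ratio. Here is a quick standalone comparison sketch; it restates the formulas from the wsd.py diff below (it is not the library code), and the `min_lr_ratio`/`num_cycles` values mirror the updated test:

```python
import math

def cooldown_ratio(kind: str, p: float, min_lr_ratio: float = 0.1, num_cycles: float = 0.5) -> float:
    """Restates the four cooldown formulas added in pytorch_optimizer/lr_scheduler/wsd.py."""
    if kind == 'cosine':
        value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * p)))
        return (1.0 - min_lr_ratio) * value + min_lr_ratio
    if kind == '1-sqrt':
        return 1.0 - math.sqrt(p)
    if kind == '1-square':
        return 1.0 - p ** 2
    if kind == 'linear':
        return 1.0 - p
    raise ValueError(f'unknown cooldown type: {kind}')

# ratios at the start, one-third, and two-thirds of the decay window
for kind in ('cosine', '1-sqrt', '1-square', 'linear'):
    print(kind, [round(cooldown_ratio(kind, p), 4) for p in (0.0, 1 / 3, 2 / 3)])
```

Note that only the 'cosine' shape blends toward `min_lr_ratio` inside the decay window; the other three run their curve downward and rely on the scheduler returning `min_lr_ratio` once the decay window ends.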

poetry.lock

Lines changed: 93 additions & 89 deletions
Some generated files are not rendered by default.

pytorch_optimizer/lr_scheduler/wsd.py

Lines changed: 63 additions & 4 deletions
@@ -1,18 +1,66 @@
 import math
 from functools import partial
+from typing import Literal
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
+COOLDOWN_TYPE = Literal['cosine', '1-sqrt', 'linear', '1-square']
 
-def get_wsd_scheduler_lambda(
+
+def get_cosine_cooldown_lr_ratio(
+    current_step: int,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+    min_lr_ratio: float,
+    num_cycles: float,
+) -> float:
+    r"""Get Cosine cooldown learning rate ratio."""
+    progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
+    value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+    return (1.0 - min_lr_ratio) * value + min_lr_ratio
+
+
+def get_1sqrt_cooldown_lr_ratio(
+    current_step: int,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+) -> float:
+    r"""Get 1-sqrt cooldown learning rate ratio."""
+    return 1.0 - math.sqrt((current_step - num_warmup_steps - num_stable_steps) / num_decay_steps)
+
+
+def get_1square_cooldown_lr_ratio(
+    current_step: int,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+) -> float:
+    r"""Get 1-square cooldown learning rate ratio."""
+    return 1.0 - math.pow((current_step - num_warmup_steps - num_stable_steps) / num_decay_steps, 2)
+
+
+def get_linear_cooldown_lr_ratio(
+    current_step: int,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+) -> float:
+    r"""Get linear cooldown learning rate ratio."""
+    return 1.0 - (current_step - num_warmup_steps - num_stable_steps) / num_decay_steps
+
+
+def get_wsd_scheduler_lambda(  # noqa: PLR0911
     current_step: int,
     *,
     num_warmup_steps: int,
     num_stable_steps: int,
     num_decay_steps: int,
     min_lr_ratio: float,
     num_cycles: float,
+    cooldown_type: COOLDOWN_TYPE,
 ) -> float:
     r"""Get WSD learning rate.
 
@@ -23,15 +71,23 @@ def get_wsd_scheduler_lambda(
     :param min_lr_ratio: float. the minimum learning rate as a ratio of the initial learning rate.
     :param num_cycles: float. the number of waves in the cosine schedule (the default is to just decrease from the max
         value to 0 following a half-cosine)
+    :param cooldown_type: COOLDOWN_TYPE. cooldown type of the learning rate scheduler.
     """
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))
     if current_step < num_warmup_steps + num_stable_steps:
         return 1.0
     if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
-        progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
-        value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
-        return (1.0 - min_lr_ratio) * value + min_lr_ratio
+        if cooldown_type == 'cosine':
+            return get_cosine_cooldown_lr_ratio(
+                current_step, num_warmup_steps, num_stable_steps, num_decay_steps, min_lr_ratio, num_cycles
+            )
+        if cooldown_type == '1-sqrt':
+            return get_1sqrt_cooldown_lr_ratio(current_step, num_warmup_steps, num_stable_steps, num_decay_steps)
+        if cooldown_type == '1-square':
+            return get_1square_cooldown_lr_ratio(current_step, num_warmup_steps, num_stable_steps, num_decay_steps)
+        if cooldown_type == 'linear':
+            return get_linear_cooldown_lr_ratio(current_step, num_warmup_steps, num_stable_steps, num_decay_steps)
     return min_lr_ratio
 
 
@@ -42,6 +98,7 @@ def get_wsd_schedule(
     num_decay_steps: int,
     min_lr_ratio: float = 0.0,
     num_cycles: float = 0.5,
+    cooldown_type: COOLDOWN_TYPE = '1-sqrt',
     last_epoch: int = -1,
 ) -> LRScheduler:
     r"""Get Warmup-Stable-Decay learning rate scheduler.
@@ -53,6 +110,7 @@ def get_wsd_schedule(
     :param min_lr_ratio: float. the minimum learning rate as a ratio of the initial learning rate.
     :param num_cycles: float. the number of waves in the cosine schedule (the default is to just decrease from the max
         value to 0 following a half-cosine)
+    :param cooldown_type: COOLDOWN_TYPE. cooldown type of the learning rate scheduler.
     :param last_epoch: int. the index of the last epoch when resuming training.
     """
     lr_scheduler = partial(
@@ -62,6 +120,7 @@ def get_wsd_schedule(
         num_decay_steps=num_decay_steps,
         min_lr_ratio=min_lr_ratio,
         num_cycles=num_cycles,
+        cooldown_type=cooldown_type,
     )
 
     return LambdaLR(optimizer, lr_scheduler, last_epoch)
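
For context, here is a minimal usage sketch of the new `cooldown_type` parameter, mirroring the updated test further below. The module import path is taken from this diff; the toy model and the reliance on AdamW's default lr of 1e-3 are illustrative assumptions:

```python
import torch
from torch.optim import AdamW

from pytorch_optimizer.lr_scheduler.wsd import get_wsd_schedule  # module path per this diff

model = torch.nn.Linear(4, 4)  # illustrative toy model
optimizer = AdamW(model.parameters())  # default lr: 1e-3

# 2 warmup steps, 2 stable steps, 3 decay steps, LR floor at 10% of the base rate
lr_scheduler = get_wsd_schedule(optimizer, 2, 2, 3, min_lr_ratio=0.1, cooldown_type='1-square')

for _ in range(9):
    optimizer.step()
    lr_scheduler.step()
    print(lr_scheduler.get_last_lr()[0])
```

With these settings the printed values should follow the '1-square' row of the parametrized test: 0.0005, 0.001, 0.001, 0.001, ~0.000889, ~0.000556, then 0.0001.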

requirements-dev.txt

Lines changed: 4 additions & 4 deletions
@@ -5,12 +5,12 @@ black==25.1.0 ; python_version >= "3.9"
 click==8.1.8 ; python_version >= "3.8"
 colorama==0.4.6 ; python_version >= "3.8" and (sys_platform == "win32" or platform_system == "Windows")
 coverage[toml]==7.6.1 ; python_version == "3.8"
-coverage[toml]==7.8.0 ; python_version >= "3.9"
+coverage[toml]==7.8.2 ; python_version >= "3.9"
 exceptiongroup==1.3.0 ; python_version < "3.11" and python_version >= "3.8"
 filelock==3.16.1 ; python_version == "3.8"
 filelock==3.18.0 ; python_version >= "3.9"
 fsspec==2025.3.0 ; python_version == "3.8"
-fsspec==2025.3.2 ; python_version >= "3.9"
+fsspec==2025.5.1 ; python_version >= "3.9"
 iniconfig==2.1.0 ; python_version >= "3.8"
 isort==5.13.2 ; python_version == "3.8"
 isort==6.0.1 ; python_version >= "3.9"
@@ -31,8 +31,8 @@ pluggy==1.5.0 ; python_version == "3.8"
 pluggy==1.6.0 ; python_version >= "3.9"
 pytest-cov==5.0.0 ; python_version >= "3.8"
 pytest==8.3.5 ; python_version >= "3.8"
-ruff==0.11.10 ; python_version >= "3.8"
-setuptools==80.7.1 ; python_version >= "3.12"
+ruff==0.11.12 ; python_version >= "3.8"
+setuptools==80.9.0 ; python_version >= "3.12"
 sympy==1.13.3 ; python_version == "3.8"
 sympy==1.14.0 ; python_version >= "3.9"
 tomli==2.2.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 filelock==3.16.1 ; python_version == "3.8"
 filelock==3.18.0 ; python_version >= "3.9"
 fsspec==2025.3.0 ; python_version == "3.8"
-fsspec==2025.3.2 ; python_version >= "3.9"
+fsspec==2025.5.1 ; python_version >= "3.9"
 jinja2==3.1.6 ; python_version >= "3.8"
 markupsafe==2.1.5 ; python_version == "3.8"
 markupsafe==3.0.2 ; python_version >= "3.9"
@@ -12,7 +12,7 @@ networkx==3.1 ; python_version == "3.8"
 networkx==3.2.1 ; python_version >= "3.9"
 numpy==1.24.4 ; python_version == "3.8"
 numpy==2.0.2 ; python_version >= "3.9"
-setuptools==80.7.1 ; python_version >= "3.12"
+setuptools==80.9.0 ; python_version >= "3.12"
 sympy==1.13.3 ; python_version == "3.8"
 sympy==1.14.0 ; python_version >= "3.9"
 torch==2.4.1+cpu ; python_version == "3.8"

tests/test_lr_schedulers.py

Lines changed: 13 additions & 4 deletions
@@ -307,17 +307,26 @@ def test_rex_lr_scheduler():
     np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_lr(), 6)
 
 
-def test_wsd_lr_scheduler():
+@pytest.mark.parametrize(
+    'recipe',
+    [
+        ('cosine', [0.0005, 0.001, 0.001, 0.001, 0.000775, 0.000325, 0.0001, 0.0001, 0.0001]),
+        ('1-sqrt', [0.0005, 0.001, 0.001, 0.001, 0.0004226, 0.0001835, 0.0001, 0.0001, 0.0001]),
+        ('1-square', [0.0005, 0.001, 0.001, 0.001, 0.0008888, 0.0005555, 0.0001, 0.0001, 0.0001]),
+        ('linear', [0.0005, 0.001, 0.001, 0.001, 0.0006666, 0.0003333, 0.0001, 0.0001, 0.0001]),
+    ],
+)
+def test_wsd_lr_scheduler(recipe):
     optimizer = AdamW(Example().parameters())
     optimizer.step()
 
-    lr_scheduler = get_wsd_schedule(optimizer, 2, 2, 3, min_lr_ratio=0.1)
+    cooldown_type, expected_lrs = recipe
 
-    expected_lrs = [0.0005, 0.001, 0.001, 0.001, 0.000775, 0.000325, 0.0001, 0.0001, 0.0001]
+    lr_scheduler = get_wsd_schedule(optimizer, 2, 2, 3, min_lr_ratio=0.1, cooldown_type=cooldown_type)
 
     for expected_lr in expected_lrs:
         lr_scheduler.step()
-        np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_last_lr(), 6)
+        np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_last_lr()[0], 7)
 
 
 def test_deberta_v3_large_lr_scheduler():
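
As a sanity check on the expected values above, the 'linear' row can be recomputed by hand. This standalone sketch restates the schedule lambda (it is not the library code) and assumes AdamW's default base lr of 1e-3 plus LambdaLR's convention of evaluating the lambda at the epoch count after each `step()`:

```python
base_lr, min_lr_ratio = 1e-3, 0.1
num_warmup, num_stable, num_decay = 2, 2, 3  # matches get_wsd_schedule(optimizer, 2, 2, 3, ...)

def linear_wsd_lambda(step: int) -> float:
    # warmup -> stable -> linear cooldown -> floor at min_lr_ratio
    if step < num_warmup:
        return step / max(1, num_warmup)
    if step < num_warmup + num_stable:
        return 1.0
    if step < num_warmup + num_stable + num_decay:
        return 1.0 - (step - num_warmup - num_stable) / num_decay
    return min_lr_ratio

# steps 1..9 -> [0.0005, 0.001, 0.001, 0.001, 0.0006667, 0.0003333, 0.0001, 0.0001, 0.0001]
print([round(base_lr * linear_wsd_lambda(t), 7) for t in range(1, 10)])
```

The 0.0006667 and 0.0003333 values agree, to the test's 7-decimal tolerance, with the 0.0006666 and 0.0003333 entries in the 'linear' row.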
