11import math
22from functools import partial
3+ from typing import Literal
34
45from torch .optim import Optimizer
56from torch .optim .lr_scheduler import LambdaLR , LRScheduler
67
# Supported cooldown (decay-phase) curves for the WSD scheduler's final stage.
COOLDOWN_TYPE = Literal['cosine', '1-sqrt', 'linear', '1-square']
79
8- def get_wsd_scheduler_lambda (
10+
def get_cosine_cooldown_lr_ratio(
    current_step: int,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
    min_lr_ratio: float,
    num_cycles: float,
) -> float:
    r"""Get Cosine cooldown learning rate ratio.

    :param current_step: int. the current training step.
    :param num_warmup_steps: int. the number of warmup steps.
    :param num_stable_steps: int. the number of stable steps.
    :param num_decay_steps: int. the number of decay steps.
    :param min_lr_ratio: float. the minimum learning rate as a ratio of the initial learning rate.
    :param num_cycles: float. the number of waves in the cosine schedule.
    """
    # fraction of the decay phase completed; max(1, ...) guards a zero-length decay phase
    steps_into_decay: int = current_step - num_warmup_steps - num_stable_steps
    progress: float = float(steps_into_decay) / float(max(1, num_decay_steps))

    # half-cosine wave rescaled to [0, 1], clamped so it never goes negative
    cosine_wave: float = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
    decay_factor: float = max(0.0, 0.5 * (1.0 + cosine_wave))

    # interpolate between min_lr_ratio (floor) and 1.0 (start of decay)
    return min_lr_ratio + (1.0 - min_lr_ratio) * decay_factor
23+
24+
def get_1sqrt_cooldown_lr_ratio(
    current_step: int,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
) -> float:
    r"""Get 1-sqrt cooldown learning rate ratio.

    :param current_step: int. the current training step.
    :param num_warmup_steps: int. the number of warmup steps.
    :param num_stable_steps: int. the number of stable steps.
    :param num_decay_steps: int. the number of decay steps.
    """
    # max(1, ...) guards against ZeroDivisionError for a zero-length decay phase,
    # consistent with the cosine cooldown variant.
    progress: float = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
    return 1.0 - math.sqrt(progress)
33+
34+
def get_1square_cooldown_lr_ratio(
    current_step: int,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
) -> float:
    r"""Get 1-square cooldown learning rate ratio.

    :param current_step: int. the current training step.
    :param num_warmup_steps: int. the number of warmup steps.
    :param num_stable_steps: int. the number of stable steps.
    :param num_decay_steps: int. the number of decay steps.
    """
    # max(1, ...) guards against ZeroDivisionError for a zero-length decay phase,
    # consistent with the cosine cooldown variant.
    progress: float = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
    return 1.0 - math.pow(progress, 2)
43+
44+
def get_linear_cooldown_lr_ratio(
    current_step: int,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
) -> float:
    r"""Get linear cooldown learning rate ratio.

    :param current_step: int. the current training step.
    :param num_warmup_steps: int. the number of warmup steps.
    :param num_stable_steps: int. the number of stable steps.
    :param num_decay_steps: int. the number of decay steps.
    """
    # max(1, ...) guards against ZeroDivisionError for a zero-length decay phase,
    # consistent with the cosine cooldown variant.
    return 1.0 - float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
53+
54+
def get_wsd_scheduler_lambda(  # noqa: PLR0911
    current_step: int,
    *,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
    min_lr_ratio: float,
    num_cycles: float,
    cooldown_type: COOLDOWN_TYPE,
) -> float:
    r"""Get WSD learning rate.

    :param current_step: int. the current training step.
    :param num_warmup_steps: int. the number of warmup steps.
    :param num_stable_steps: int. the number of stable steps.
    :param num_decay_steps: int. the number of decay steps.
    :param min_lr_ratio: float. the minimum learning rate as a ratio of the initial learning rate.
    :param num_cycles: float. the number of waves in the cosine schedule (the defaults is to just decrease from the max
        value to 0 following a half-cosine)
    :param cooldown_type: COOLDOWN_TYPE. cooldown type of the learning rate scheduler.
    """
    # Warmup phase: ramp linearly from 0 to 1; max(1, ...) guards a zero-length warmup.
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    # Stable phase: hold the peak learning rate.
    if current_step < num_warmup_steps + num_stable_steps:
        return 1.0
    # Decay phase: dispatch to the configured cooldown curve.
    if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
        if cooldown_type == 'cosine':
            return get_cosine_cooldown_lr_ratio(
                current_step, num_warmup_steps, num_stable_steps, num_decay_steps, min_lr_ratio, num_cycles
            )
        if cooldown_type == '1-sqrt':
            return get_1sqrt_cooldown_lr_ratio(current_step, num_warmup_steps, num_stable_steps, num_decay_steps)
        if cooldown_type == '1-square':
            return get_1square_cooldown_lr_ratio(current_step, num_warmup_steps, num_stable_steps, num_decay_steps)
        if cooldown_type == 'linear':
            return get_linear_cooldown_lr_ratio(current_step, num_warmup_steps, num_stable_steps, num_decay_steps)
    # After decay (or an unrecognized cooldown_type past the stable phase): hold the floor.
    return min_lr_ratio
3692
3793
@@ -42,6 +98,7 @@ def get_wsd_schedule(
4298 num_decay_steps : int ,
4399 min_lr_ratio : float = 0.0 ,
44100 num_cycles : float = 0.5 ,
101+ cooldown_type : COOLDOWN_TYPE = '1-sqrt' ,
45102 last_epoch : int = - 1 ,
46103) -> LRScheduler :
47104 r"""Get Warmup-Stable-Decay learning rate scheduler.
@@ -53,6 +110,7 @@ def get_wsd_schedule(
53110 :param min_lr_ratio: float. the minimum learning rate as a ratio of the initial learning rate.
54111 :param num_cycles: float. the number of waves in the cosine schedule (the defaults is to just decrease from the max
55112 value to 0 following a half-cosine)
113+ :param cooldown_type: COOLDOWN_TYPE. cooldown type of the learning rate scheduler.
56114 :param last_epoch: int. the index of the last epoch when resuming training.
57115 """
58116 lr_scheduler = partial (
@@ -62,6 +120,7 @@ def get_wsd_schedule(
62120 num_decay_steps = num_decay_steps ,
63121 min_lr_ratio = min_lr_ratio ,
64122 num_cycles = num_cycles ,
123+ cooldown_type = cooldown_type ,
65124 )
66125
67126 return LambdaLR (optimizer , lr_scheduler , last_epoch )
0 commit comments