@@ -29,86 +29,251 @@ def __new__(cls: type, *args: Any, **kwargs: Any) -> Any:
         return super().__new__(cls)
 
     def __init__(
-        self, start_lr: float, stop_lr: float, stop_steps: int, **kwargs: Any
+        self,
+        start_lr: float,
+        stop_lr: float | None = None,
+        stop_ratio: float | None = None,
+        stop_steps: int = 100000,
+        warmup_steps: int = 0,
+        warmup_ratio: float | None = None,
+        warmup_start_factor: float = 0.0,
+        **kwargs: Any,
     ) -> None:
         """
-        Base class for learning rate schedules.
+        Base class for learning rate schedules with warmup support.
 
         Parameters
         ----------
-        start_lr
-            The initial learning rate.
-        stop_lr
-            The final learning rate.
-        stop_steps
-            The total training steps for learning rate scheduler.
+        start_lr : float
+            The learning rate at the start of the training (after warmup).
+        stop_lr : float, optional
+            The final learning rate at the end of the training.
+            Mutually exclusive with stop_ratio.
+        stop_ratio : float, optional
+            The ratio of stop_lr to start_lr. stop_lr = start_lr * stop_ratio.
+            Mutually exclusive with stop_lr.
+            One of stop_lr or stop_ratio must be provided.
+        stop_steps : int
+            The total training steps (including warmup).
+        warmup_steps : int, optional
+            The number of steps for learning rate warmup.
+            Mutually exclusive with warmup_ratio. Default is 0 (no warmup).
+        warmup_ratio : float, optional
+            The ratio of warmup steps to total training steps.
+            warmup_steps = int(warmup_ratio * stop_steps).
+            Mutually exclusive with warmup_steps.
+        warmup_start_factor : float, optional
+            The factor of start_lr for the initial warmup learning rate.
+            The warmup learning rate starts from warmup_start_factor * start_lr.
+            Default is 0.0.
         """
+        # === Step 1. Compute stop_lr from stop_ratio if needed ===
+        # Mutual exclusion validated in argcheck.py
+        if stop_ratio is not None:
+            self.stop_lr = start_lr * stop_ratio
+        else:
+            self.stop_lr = stop_lr  # type: ignore[assignment]
+
+        # === Step 2. Compute warmup_steps from warmup_ratio if needed ===
+        # Mutual exclusion validated in argcheck.py
+        if warmup_ratio is not None:
+            self.warmup_steps = int(warmup_ratio * stop_steps)
+        else:
+            self.warmup_steps = warmup_steps
+
+        # === Step 3. Validate step ranges (runtime check) ===
+        if stop_steps <= 0:
+            raise ValueError("stop_steps must be positive")
+        if self.warmup_steps < 0:
+            raise ValueError("warmup_steps must be non-negative")
+        if self.warmup_steps >= stop_steps:
+            raise ValueError("warmup_steps must be smaller than stop_steps")
+
+        # === Step 4. Compute warmup_start_lr ===
+        self.warmup_start_lr = warmup_start_factor * start_lr
+
+        # === Step 5. Store core parameters ===
         self.start_lr = start_lr
-        self.stop_lr = stop_lr
         self.stop_steps = stop_steps
+        # Decay phase covers (stop_steps - warmup_steps) steps
+        self.decay_stop_steps = stop_steps - self.warmup_steps
 
     @abstractmethod
-    def value(self, step: int | Array) -> Array:
-        """Get the learning rate at the given step."""
-        # in optax, step will be a jnp.ndarray passed in JIT mode
+    def _decay_value(self, step: int | Array) -> Array:
+        """
+        Get the decayed learning rate at the given step (after warmup).
+
+        This method should implement the actual decay logic (exp, cosine, etc.)
+        without considering warmup.
+
+        Parameters
+        ----------
+        step : int or Array
+            The step index relative to the end of warmup.
+            For example, if warmup_steps=100 and total_step=150, this method
+            will be called with step=50.
+
+        Returns
+        -------
+        Array
+            The decayed learning rate (absolute value, not factor).
+        """
         pass
 
+    def value(self, step: int | Array) -> Array | float:
+        """
+        Get the learning rate at the given step, including warmup.
+
+        Parameters
+        ----------
+        step : int or Array
+            The absolute step index from the start of training.
+
+        Returns
+        -------
+        Array
+            The learning rate at the given step.
+        """
+        is_scalar = isinstance(step, (int, float))
+        if not array_api_compat.is_array_api_obj(step):
+            step = np.asarray(step)
+        xp = array_api_compat.array_namespace(step)
+
+        # === Step 1. Handle no-warmup case directly ===
+        if self.warmup_steps == 0:
+            lr = self._decay_value(xp.astype(step, xp.float64))
+        else:
+            # === Step 2. Warmup phase ===
+            # Linear warmup from warmup_start_lr to start_lr
+            warmup_progress = xp.astype(step, xp.float64) / self.warmup_steps
+            warmup_lr = (
+                self.warmup_start_lr
+                + (self.start_lr - self.warmup_start_lr) * warmup_progress
+            )
+
+            # === Step 3. Decay phase ===
+            # Call subclass decay logic for steps after warmup
+            decay_step = xp.maximum(
+                xp.astype(step, xp.float64) - self.warmup_steps, 0.0
+            )
+            decay_lr = self._decay_value(decay_step)
+
+            # === Step 4. Select warmup or decay based on step ===
+            lr = xp.where(step < self.warmup_steps, warmup_lr, decay_lr)
+
+        if is_scalar:
+            return float(lr)
+        return lr
+
 
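For reference, here is a minimal standalone sketch of the warmup behaviour implemented in `BaseLR.value` above: resolve the warmup start level from `warmup_start_factor * start_lr`, ramp linearly to `start_lr` over `warmup_steps`, then hand off to the decay schedule with the step shifted by `warmup_steps`. The helper name `warmup_then_decay`, the `decay_fn` callable, and the concrete numbers are illustrative assumptions, not part of this diff.

```python
import numpy as np


def warmup_then_decay(step, start_lr, decay_fn, warmup_steps=0, warmup_start_factor=0.0):
    """Sketch of BaseLR.value: linear warmup followed by an arbitrary decay schedule."""
    step = np.asarray(step, dtype=np.float64)
    if warmup_steps == 0:
        return decay_fn(step)
    warmup_start_lr = warmup_start_factor * start_lr
    # Linear ramp from warmup_start_lr to start_lr over warmup_steps
    warmup_lr = warmup_start_lr + (start_lr - warmup_start_lr) * step / warmup_steps
    # Decay phase sees the step relative to the end of warmup
    decay_lr = decay_fn(np.maximum(step - warmup_steps, 0.0))
    return np.where(step < warmup_steps, warmup_lr, decay_lr)


# Illustrative numbers: ramp from 0 to 1e-3 over 100 steps, then a flat "decay"
lr = warmup_then_decay(
    np.array([0, 50, 100, 150]),
    start_lr=1e-3,
    decay_fn=lambda s: np.full_like(s, 1e-3),
    warmup_steps=100,
)
# -> [0.0, 0.0005, 0.001, 0.001]
```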
 @BaseLR.register("exp")
 class LearningRateExp(BaseLR):
     def __init__(
         self,
         start_lr: float,
-        stop_lr: float,
-        decay_steps: int,
-        stop_steps: int,
+        stop_lr: float | None = None,
+        stop_ratio: float | None = None,
+        decay_steps: int = 5000,
+        stop_steps: int = 100000,
         decay_rate: float | None = None,
+        warmup_steps: int = 0,
+        warmup_ratio: float | None = None,
+        warmup_start_factor: float = 0.0,
         **kwargs: Any,
     ) -> None:
         """
-        Construct an exponential-decayed learning rate.
+        Construct an exponential-decayed learning rate with optional warmup.
 
         Parameters
         ----------
-        start_lr
-            The learning rate at the start of the training.
-        stop_lr
+        start_lr : float
+            The learning rate at the start of the training (after warmup).
+        stop_lr : float, optional
             The desired learning rate at the end of the training.
             When decay_rate is explicitly set, this value will serve as
-            the minimum learning rate during training. In other words,
-            if the learning rate decays below stop_lr, stop_lr will be applied instead.
-        decay_steps
+            the minimum learning rate during training.
+            Mutually exclusive with stop_ratio.
+        stop_ratio : float, optional
+            The ratio of stop_lr to start_lr.
+            Mutually exclusive with stop_lr.
+        decay_steps : int
             The learning rate is decaying every this number of training steps.
-        stop_steps
-            The total training steps for learning rate scheduler.
-        decay_rate
+            Default is 5000.
+        stop_steps : int
+            The total training steps (including warmup).
+        decay_rate : float, optional
             The decay rate for the learning rate.
             If provided, the decay rate will be set instead of
             calculating it through interpolation between start_lr and stop_lr.
+        warmup_steps : int, optional
+            The number of steps for learning rate warmup.
+            Mutually exclusive with warmup_ratio. Default is 0.
+        warmup_ratio : float, optional
+            The ratio of warmup steps to total training steps.
+            Mutually exclusive with warmup_steps.
+        warmup_start_factor : float, optional
+            The factor of start_lr for the initial warmup learning rate.
+            Default is 0.0.
+
+        Raises
+        ------
+        ValueError
+            If both stop_lr and stop_ratio are provided, or neither is provided.
+            If both warmup_steps and warmup_ratio are provided.
+            If decay_steps is larger than the decay phase total steps.
         """
-        super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
-        default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1
+        super().__init__(
+            start_lr=start_lr,
+            stop_lr=stop_lr,
+            stop_ratio=stop_ratio,
+            stop_steps=stop_steps,
+            warmup_steps=warmup_steps,
+            warmup_ratio=warmup_ratio,
+            warmup_start_factor=warmup_start_factor,
+            **kwargs,
+        )
+        # === Step 5. Compute decay_rate for exp scheduler ===
+        # Use decay_stop_steps (stop_steps - warmup_steps) for decay calculation
+        decay_total = self.decay_stop_steps
         self.decay_steps = decay_steps
-        if self.decay_steps >= stop_steps:
-            self.decay_steps = default_ds
+
+        if self.decay_steps > decay_total:
+            raise ValueError(
+                f"decay_steps ({self.decay_steps}) must not exceed decay phase steps ({decay_total})."
+            )
+
+        # Avoid log(0) issues by clamping stop_lr for computation
+        clamped_stop_lr = max(self.stop_lr, 1e-10)
+        self.min_lr = self.stop_lr
+
         self.decay_rate = np.exp(
-            np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps)
+            np.log(clamped_stop_lr / self.start_lr) / (decay_total / self.decay_steps)
         ).item()
         if decay_rate is not None:
             self.decay_rate = decay_rate
-        self.min_lr = self.stop_lr
 
-    def value(self, step: int | Array) -> Array:
-        """Get the learning rate at the given step."""
+    def _decay_value(self, step: int | Array) -> Array:
+        """
+        Get the exponential-decayed learning rate at the given step.
+
+        Parameters
+        ----------
+        step : int or Array
+            The step index relative to the end of warmup.
+
+        Returns
+        -------
+        Array
+            The decayed learning rate (absolute value).
+        """
         if not array_api_compat.is_array_api_obj(step):
             step = np.asarray(step)
         xp = array_api_compat.array_namespace(step)
         step_lr = self.start_lr * xp.pow(
             xp.asarray(self.decay_rate, device=array_api_compat.device(step)),
             xp.astype(step // self.decay_steps, xp.float64),
         )
-        # the original implementation `if step_lr < self.min_lr:`
-        # will cause a dynamic graph which is unsupported in JAX JIT
+        # Clip to min_lr for numerical stability in JIT
         step_lr = xp.clip(step_lr, self.min_lr, None)
         return step_lr
 
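A quick numeric check of the decay-rate interpolation used by `LearningRateExp` above: `decay_rate` is chosen so that after `decay_total = stop_steps - warmup_steps` steps the schedule reaches `stop_lr`, and the result is clipped so it never falls below `stop_lr`. This is a sketch with purely illustrative numbers; it reproduces the formula from the diff with plain NumPy rather than calling the class itself.

```python
import numpy as np

# Illustrative values: start_lr=1e-3, stop_lr=1e-5, 10000 decay-phase steps,
# decaying every 1000 steps, mirroring LearningRateExp's formula above.
start_lr, stop_lr, decay_total, decay_steps = 1e-3, 1e-5, 10000, 1000
decay_rate = np.exp(np.log(stop_lr / start_lr) / (decay_total / decay_steps))


def exp_lr(step):
    # lr = start_lr * decay_rate ** (step // decay_steps), never below stop_lr
    lr = start_lr * decay_rate ** (step // decay_steps)
    return np.clip(lr, stop_lr, None)


print(round(decay_rate, 4))                 # ~0.631, i.e. 10 ** (-2 / 10)
print(exp_lr(np.array([0, 5000, 10000])))   # ~[1e-3, 1e-4, 1e-5]
```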
@@ -118,29 +283,74 @@ class LearningRateCosine(BaseLR):
     def __init__(
         self,
         start_lr: float,
-        stop_lr: float,
-        stop_steps: int,
+        stop_lr: float | None = None,
+        stop_ratio: float | None = None,
+        stop_steps: int = 100000,
+        warmup_steps: int = 0,
+        warmup_ratio: float | None = None,
+        warmup_start_factor: float = 0.0,
         **kwargs: Any,
     ) -> None:
         """
-        Defines a cosine annealing learning rate schedule.
-        The learning rate starts at `start_lr` and gradually decreases to `stop_lr`
-        following a cosine curve over the training steps.
+        Defines a cosine annealing learning rate schedule with optional warmup.
+
+        The learning rate starts at `start_lr` (after warmup) and gradually
+        decreases to `stop_lr` following a cosine curve over the training steps.
 
         Parameters
         ----------
-        start_lr
-            The initial learning rate at the beginning of training.
-        stop_lr
+        start_lr : float
+            The learning rate at the start of the training (after warmup).
+        stop_lr : float, optional
             The final learning rate at the end of training.
-        stop_steps
-            The total number of training steps over which the learning rate
-            will be annealed from start_lr to stop_lr.
+            Mutually exclusive with stop_ratio.
+        stop_ratio : float, optional
+            The ratio of stop_lr to start_lr.
+            Mutually exclusive with stop_lr.
+        stop_steps : int
+            The total training steps (including warmup).
+        warmup_steps : int, optional
+            The number of steps for learning rate warmup.
+            Mutually exclusive with warmup_ratio. Default is 0.
+        warmup_ratio : float, optional
+            The ratio of warmup steps to total training steps.
+            Mutually exclusive with warmup_steps.
+        warmup_start_factor : float, optional
+            The factor of start_lr for the initial warmup learning rate.
+            Default is 0.0.
+
+        Raises
+        ------
+        ValueError
+            If both stop_lr and stop_ratio are provided, or neither is provided.
+            If both warmup_steps and warmup_ratio are provided.
         """
-        super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
-        self.lr_min_factor = stop_lr / start_lr
+        super().__init__(
+            start_lr=start_lr,
+            stop_lr=stop_lr,
+            stop_ratio=stop_ratio,
+            stop_steps=stop_steps,
+            warmup_steps=warmup_steps,
+            warmup_ratio=warmup_ratio,
+            warmup_start_factor=warmup_start_factor,
+            **kwargs,
+        )
+        self.lr_min_factor = self.stop_lr / self.start_lr
 
-    def value(self, step: int | Array) -> Array:
+    def _decay_value(self, step: int | Array) -> Array:
+        """
+        Get the cosine-annealed learning rate at the given step.
+
+        Parameters
+        ----------
+        step : int or Array
+            The step index relative to the end of warmup.
+
+        Returns
+        -------
+        Array
+            The annealed learning rate (absolute value).
+        """
         if not array_api_compat.is_array_api_obj(step):
             step = np.asarray(step)
         xp = array_api_compat.array_namespace(step)
@@ -153,11 +363,12 @@ def value(self, step: int | Array) -> Array:
                 1
                 + xp.cos(
                     xp.asarray(
-                        xp.pi * (xp.astype(step, xp.float64) / self.stop_steps),
+                        xp.pi * (xp.astype(step, xp.float64) / self.decay_stop_steps),
                         device=array_api_compat.device(step),
                     )
                 )
             )
         )
-        step_lr = xp.where(step >= self.stop_steps, min_lr, step_lr)
+        # Clip to min_lr for steps beyond decay_stop_steps
+        step_lr = xp.where(step >= self.decay_stop_steps, min_lr, step_lr)
         return step_lr
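The lines immediately before this hunk are unchanged and not shown, so the `min_lr` and `0.5 * (start_lr - min_lr)` prefactor are assumed from the standard cosine annealing formula that the visible fragment is consistent with. A minimal sketch with illustrative numbers:

```python
import numpy as np

# Assumed values for illustration only
start_lr, stop_lr, decay_stop_steps = 1e-3, 1e-5, 10000
min_lr = stop_lr  # lr_min_factor * start_lr == stop_lr


def cosine_lr(step):
    """Cosine annealing from start_lr to min_lr over decay_stop_steps, then flat."""
    step = np.asarray(step, dtype=np.float64)
    lr = min_lr + 0.5 * (start_lr - min_lr) * (1 + np.cos(np.pi * step / decay_stop_steps))
    # Hold at min_lr once the decay phase is over, matching the xp.where above
    return np.where(step >= decay_stop_steps, min_lr, lr)


print(cosine_lr([0, 5000, 10000, 12000]))  # ~[1e-3, 5.05e-4, 1e-5, 1e-5]
```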