
Commit 0fb8d94

Tony Sunchanglan authored and committed
feat: warmup-stable-decay (wsd) lr schedule
GitOrigin-RevId: f8131a40f89f38c236ad1f01c52b5c1a0c3e1c2b
1 parent 403e417 commit 0fb8d94

File tree

2 files changed: +133 −5 lines changed


axlearn/common/schedule.py

Lines changed: 66 additions & 5 deletions
@@ -7,6 +7,7 @@
 # Licensed under the Apache License, Version 2.0 (the "License").
 
 """Optimizer schedules."""
+
 import math
 from typing import Callable, Optional, Union
 
@@ -45,22 +46,27 @@ def polynomial(
     Args:
         begin_step: The first step of polynomial schedule.
        begin_value: The begin value of polynomial schedule.
-        end_step: The end step of polynomial schedule. Must be > begin_step.
+        end_step: The end step of polynomial schedule. Must be >= begin_step.
+            If equal to begin_step, the schedule will always return `begin_value`.
         end_value: The end value of polynomial schedule.
         power: The polynomial power.
 
     Returns:
         A ScheduleFn according to the spec.
 
     Raises:
-        ValueError: If begin_step >= end_step.
+        ValueError: If begin_step > end_step.
     """
-    if begin_step >= end_step:
-        raise ValueError(f"begin_step {begin_step} must be < end_step {end_step}.")
+    if begin_step > end_step:
+        raise ValueError(f"begin_step ({begin_step}) must be <= end_step ({end_step}).")
+
+    if begin_step == end_step:
+        # For a zero-duration schedule, always return the starting value.
+        return lambda step: jnp.array(begin_value, dtype=jnp.float32)
 
     def fn(step: Tensor) -> Tensor:
         frac = (step - begin_step) / (end_step - begin_step)
-        frac = jnp.minimum(1.0, jnp.maximum(0.0, frac))
+        frac = jnp.minimum(1.0, jnp.maximum(0.0, frac))  # Clamp progress to [0, 1].
         return begin_value + (frac**power) * (end_value - begin_value)
 
     return fn
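
The relaxation from > to >= above lets polynomial express a zero-duration segment, which warmup_stable_decay (next hunk) relies on when warmup_steps=0. A minimal sketch of the new behavior, assuming keyword arguments as named in the docstring (hypothetical call, not from the commit):

    from axlearn.common.schedule import polynomial

    # begin_step == end_step no longer raises; the schedule is constant at begin_value.
    fn = polynomial(begin_step=0, begin_value=0.0, end_step=0, end_value=0.1, power=1.0)
    fn(0)   # -> 0.0, regardless of step.
    fn(10)  # -> 0.0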
@@ -348,6 +354,61 @@ def cosine_with_linear_warmup(
     return segment_wise(segments=segments, segment_steps=segment_steps)
 
 
+def warmup_stable_decay(
+    peak_lr: float,
+    *,
+    max_step: int,
+    decay_begin_step: int,
+    warmup_steps: int = 500,
+    begin_value: float = 0.0,
+    alpha: float = 0.0,
+) -> ScheduleFn:
+    """Warmup stable decay (WSD) learning rate schedule. Linear warmup + constant lr + linear decay.
+
+    Args:
+        peak_lr: The peak learning rate corresponding to the stable part of the schedule.
+        max_step: The total number of steps from warmup + stable + decay.
+        decay_begin_step: The step to begin linear decay. The learning rate is kept constant
+            in [warmup_steps, decay_begin_step).
+        warmup_steps: The number of steps of the warm-up schedule. Skip warm-up if set to 0.
+        begin_value: The begin value of the linear warm-up.
+        alpha: The multiplier of peak_lr used to determine the final lr at the end of decay phase.
+
+    Returns:
+        A composite schedule.
+
+    Raises:
+        ValueError: If decay_begin_step < warmup_steps, or if max_step < decay_begin_step.
+    """
+    if decay_begin_step < warmup_steps:
+        raise ValueError(
+            f"decay_begin_step ({decay_begin_step}) must be >= warmup_steps ({warmup_steps})."
+        )
+    if max_step < decay_begin_step:
+        raise ValueError(f"max_step ({max_step}) must be >= decay_begin_step ({decay_begin_step}).")
+
+    return segment_wise(
+        segments=[
+            config_for_function(polynomial).set(
+                begin_step=0,
+                begin_value=begin_value,
+                end_step=warmup_steps,
+                end_value=peak_lr,
+            ),
+            config_for_function(constant_schedule).set(
+                value=peak_lr,
+            ),
+            config_for_function(polynomial).set(
+                begin_step=0,
+                begin_value=peak_lr,
+                end_step=max_step - decay_begin_step,
+                end_value=peak_lr * alpha,
+            ),
+        ],
+        segment_steps=[warmup_steps, decay_begin_step - warmup_steps],
+    )
+
+
 def constant_with_linear_warmup(
     peak_lr: float,
     *,
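
The new function composes three existing primitives via segment_wise: a polynomial segment for the linear warmup, a constant_schedule for the stable plateau, and a second polynomial segment decaying linearly to peak_lr * alpha. The decay segment uses begin_step=0 and end_step=max_step - decay_begin_step, so segment_wise evidently hands each segment a segment-relative step count (the decay expectations in the test below are computed the same way). For intuition, here is a minimal standalone sketch of the same piecewise shape in plain jax.numpy — an illustrative helper, not the axlearn implementation:

    import jax.numpy as jnp

    def wsd_sketch(step, *, peak_lr, max_step, decay_begin_step,
                   warmup_steps=500, begin_value=0.0, alpha=0.0):
        step = jnp.asarray(step, dtype=jnp.float32)
        # Phase 1: linear warmup from begin_value to peak_lr over [0, warmup_steps].
        warmup_frac = step / warmup_steps if warmup_steps > 0 else 1.0
        warmup_lr = begin_value + jnp.clip(warmup_frac, 0.0, 1.0) * (peak_lr - begin_value)
        # Phase 3: linear decay from peak_lr to peak_lr * alpha over (decay_begin_step, max_step].
        num_decay_steps = max(max_step - decay_begin_step, 1)
        decay_frac = jnp.clip((step - decay_begin_step) / num_decay_steps, 0.0, 1.0)
        decay_lr = peak_lr + decay_frac * (peak_lr * alpha - peak_lr)
        # Phase 2 (stable) holds peak_lr between warmup and decay; select by step.
        return jnp.where(
            step <= warmup_steps,
            warmup_lr,
            jnp.where(step <= decay_begin_step, peak_lr, decay_lr),
        )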

axlearn/common/schedule_test.py

Lines changed: 67 additions & 0 deletions
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 """Tests optimizer schedules."""
+
 import math
 
 import jax
@@ -164,6 +165,72 @@ def test_cosine_with_linear_warmup(self, warmup_steps, decay_begin_step):
         )
         self.assertAlmostEqual(cosine_rate, value)
 
+    @parameterized.named_parameters(
+        {
+            "testcase_name": "full_schedule",
+            "warmup_steps": 100,
+            "decay_begin_step": 200,
+        },
+        {
+            "testcase_name": "no_stable_phase",
+            "warmup_steps": 100,
+            "decay_begin_step": 100,
+        },
+        {
+            "testcase_name": "no_warmup_phase",
+            "warmup_steps": 0,
+            "decay_begin_step": 200,
+        },
+    )
+    def test_warmup_stable_decay(self, warmup_steps, decay_begin_step):
+        peak_lr = 0.1
+        max_step = 300
+        alpha = 0.1
+        begin_value = 0.0
+
+        s = jax.jit(
+            schedule.warmup_stable_decay(
+                peak_lr=peak_lr,
+                max_step=max_step,
+                warmup_steps=warmup_steps,
+                begin_value=begin_value,
+                decay_begin_step=decay_begin_step,
+                alpha=alpha,
+            )
+        )
+
+        for step in range(1, max_step + 1, 25):
+            lr = s(jnp.array(step, dtype=jnp.int32))
+
+            if warmup_steps > 0 and step <= warmup_steps:  # Linear warmup.
+                warmup_progress = step / warmup_steps
+                expected_lr = begin_value + (peak_lr - begin_value) * warmup_progress
+                self.assertAlmostEqual(expected_lr, lr, places=6)
+
+            elif warmup_steps < step <= decay_begin_step:  # Stable at peak_lr.
+                self.assertAlmostEqual(peak_lr, lr, places=6)
+
+            else:  # Linear decay.
+                num_decay_steps = max_step - decay_begin_step
+                decay_progress = (step - decay_begin_step) / num_decay_steps
+                end_lr = peak_lr * alpha
+                expected_lr = peak_lr + (end_lr - peak_lr) * decay_progress
+                self.assertAlmostEqual(expected_lr, lr, places=6)
+
+    def test_warmup_stable_decay_errors(self):
+        """Tests error conditions for warmup_stable_decay."""
+        # decay_begin_step < warmup_steps should raise.
+        with self.assertRaises(ValueError):
+            schedule.warmup_stable_decay(
+                peak_lr=0.1, warmup_steps=200, decay_begin_step=100, max_step=300
+            )
+
+        # max_step < decay_begin_step should raise.
+        with self.assertRaises(ValueError):
+            schedule.warmup_stable_decay(
+                peak_lr=0.1, warmup_steps=100, decay_begin_step=300, max_step=200
+            )
+
     def test_constant_with_linear_warmup(self):
         peak_lr = 0.1
         warmup_steps = 100
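
For reference, calling the new schedule directly mirrors the "full_schedule" test case above; the commented values follow from the warmup/decay formulas the test asserts (illustrative snippet, not part of the commit):

    import jax.numpy as jnp
    from axlearn.common import schedule

    s = schedule.warmup_stable_decay(
        peak_lr=0.1, max_step=300, warmup_steps=100, decay_begin_step=200, alpha=0.1
    )
    s(jnp.array(50))   # Warmup: 0.1 * 50 / 100 = 0.05.
    s(jnp.array(150))  # Stable: peak_lr = 0.1.
    s(jnp.array(250))  # Decay, halfway down: 0.1 + (0.01 - 0.1) * 0.5 = 0.055.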
