Skip to content

Commit fc595b7

Browse files
authored
Fix LR scheduler when warmup exceeds cycle length (#3086)
## Summary

- Clamp warmup steps to the cycle length before constructing per-cycle schedule phases.
- Clamp decay steps to the remaining cycle budget, and fall back to a constant schedule when no decay steps remain.
- Add regression coverage for a long-warmup/short-run case to prevent end-of-cycle LR jumps.

## Testing

- `uv run --group test pytest tests/test_optimizer_config.py -q` (run in `lib/levanter`)
- `uv run --with ruff ruff check src/levanter/optim/config.py tests/test_optimizer_config.py` (run in `lib/levanter`)
1 parent 988ca19 commit fc595b7

File tree

2 files changed: +52 −29 lines

lib/levanter/src/levanter/optim/config.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -263,50 +263,56 @@ def lr_scheduler(self, num_train_steps, override_lr=None):
263263
warmup_steps = _convert_frac_or_steps(self.warmup, cycle_steps)
264264
else:
265265
warmup_steps = _convert_frac_or_steps(self.rewarmup, cycle_steps)
266+
warmup_steps = min(warmup_steps, cycle_steps)
266267

267268
if warmup_steps != 0:
268269
warmup = optax.linear_schedule(previous_end, learning_rate, warmup_steps)
269270
schedules.append(warmup)
270271
boundaries.append(start + warmup_steps)
271272

272-
lr_decay_steps = (
273-
_convert_frac_or_steps(self.decay, cycle_steps)
274-
if self.decay is not None
275-
else cycle_steps - warmup_steps
273+
max_decay_steps = max(cycle_steps - warmup_steps, 0)
274+
requested_decay_steps = (
275+
_convert_frac_or_steps(self.decay, cycle_steps) if self.decay is not None else max_decay_steps
276276
)
277+
lr_decay_steps = min(max(requested_decay_steps, 0), max_decay_steps)
277278
stable_steps = cycle_steps - warmup_steps - lr_decay_steps
278279

279-
if stable_steps != 0:
280+
if stable_steps > 0:
280281
stable = optax.constant_schedule(learning_rate)
281282
schedules.append(stable)
282283
boundaries.append(start + warmup_steps + stable_steps)
283284

284-
if isinstance(self.lr_schedule, str):
285-
match self.lr_schedule:
286-
case "constant":
287-
schedule = optax.constant_schedule(learning_rate)
288-
case "cosine":
289-
schedule = optax.cosine_decay_schedule(learning_rate, lr_decay_steps, self.min_lr_ratio)
290-
case "linear":
291-
schedule = optax.linear_schedule(learning_rate, min_lr, lr_decay_steps)
292-
case "inv_sqrt":
293-
schedule = _inv_sqrt_decay_schedule(learning_rate, min_lr, warmup_steps, 10000)
294-
case "inv":
295-
schedule = _inv_decay_schedule(learning_rate, min_lr, lr_decay_steps)
296-
case _:
297-
raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}")
298-
elif isinstance(self.lr_schedule, LrSchedule):
299-
schedule = self.lr_schedule.build(
300-
LrScheduleContext(
301-
warmup_steps=warmup_steps,
302-
decay_steps=lr_decay_steps,
303-
learning_rate=learning_rate,
304-
min_lr_ratio=self.min_lr_ratio,
305-
min_lr=min_lr,
285+
if lr_decay_steps > 0:
286+
if isinstance(self.lr_schedule, str):
287+
match self.lr_schedule:
288+
case "constant":
289+
schedule = optax.constant_schedule(learning_rate)
290+
case "cosine":
291+
schedule = optax.cosine_decay_schedule(learning_rate, lr_decay_steps, self.min_lr_ratio)
292+
case "linear":
293+
schedule = optax.linear_schedule(learning_rate, min_lr, lr_decay_steps)
294+
case "inv_sqrt":
295+
schedule = _inv_sqrt_decay_schedule(learning_rate, min_lr, warmup_steps, 10000)
296+
case "inv":
297+
schedule = _inv_decay_schedule(learning_rate, min_lr, lr_decay_steps)
298+
case _:
299+
raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}")
300+
elif isinstance(self.lr_schedule, LrSchedule):
301+
schedule = self.lr_schedule.build(
302+
LrScheduleContext(
303+
warmup_steps=warmup_steps,
304+
decay_steps=lr_decay_steps,
305+
learning_rate=learning_rate,
306+
min_lr_ratio=self.min_lr_ratio,
307+
min_lr=min_lr,
308+
)
309+
)
310+
else:
311+
raise ValueError(
312+
f"lr_schedule must be a string or an instance of LrSchedule, got {self.lr_schedule}"
306313
)
307-
)
308314
else:
309-
raise ValueError(f"lr_schedule must be a string or an instance of LrSchedule, got {self.lr_schedule}")
315+
schedule = optax.constant_schedule(learning_rate)
310316

311317
previous_end = schedule(lr_decay_steps)
312318

lib/levanter/tests/test_optimizer_config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,20 @@ def test_wsds_schedule_with_cycle_points():
246246
assert np.isclose(sched_fn(701), 1e-3)
247247
assert np.isclose(sched_fn(969), 1e-3)
248248
assert sched_fn(971) < 1e-3
249+
250+
251+
def test_warmup_longer_than_run_does_not_jump():
    """Regression test: a warmup longer than the run must not cause an LR jump.

    With warmup=1000 but only 200 train steps, the warmup phase is clamped to
    the cycle length, so the schedule ramps up toward the peak LR through the
    whole run instead of jumping at the end of the cycle.
    """
    cfg = AdamConfig(
        learning_rate=3e-3,
        weight_decay=0.0,
        warmup=1000,
        decay=0.2,
        min_lr_ratio=0.1,
        lr_schedule="cosine",
    )

    schedule = cfg.lr_scheduler(200)

    # 160/200 of the way through a linear ramp to 3e-3 -> 2.4e-3.
    assert np.isclose(schedule(160), 0.0024, atol=1e-6)
    # Still strictly increasing inside the (clamped) warmup phase.
    assert schedule(161) > schedule(160)
    # Peak learning rate is reached exactly at the final step.
    assert np.isclose(schedule(200), 3e-3, atol=1e-6)

0 commit comments

Comments (0)