|
7 | 7 | # Licensed under the Apache License, Version 2.0 (the "License").
|
8 | 8 |
|
9 | 9 | """Optimizer schedules."""
|
| 10 | + |
10 | 11 | import math
|
11 | 12 | from typing import Callable, Optional, Union
|
12 | 13 |
|
@@ -45,22 +46,27 @@ def polynomial(
|
45 | 46 | Args:
|
46 | 47 | begin_step: The first step of polynomial schedule.
|
47 | 48 | begin_value: The begin value of polynomial schedule.
|
48 |
| - end_step: The end step of polynomial schedule. Must be > begin_step. |
| 49 | + end_step: The end step of polynomial schedule. Must be >= begin_step. |
| 50 | + If equal to begin_step, the schedule will always return `begin_value`. |
49 | 51 | end_value: The end value of polynomial schedule.
|
50 | 52 | power: The polynomial power.
|
51 | 53 |
|
52 | 54 | Returns:
|
53 | 55 | A ScheduleFn according to the spec.
|
54 | 56 |
|
55 | 57 | Raises:
|
56 |
| - ValueError: If begin_step >= end_step. |
| 58 | + ValueError: If begin_step > end_step. |
57 | 59 | """
|
58 |
| - if begin_step >= end_step: |
59 |
| - raise ValueError(f"begin_step {begin_step} must be < end_step {end_step}.") |
| 60 | + if begin_step > end_step: |
| 61 | + raise ValueError(f"begin_step ({begin_step}) must be <= end_step ({end_step}).") |
| 62 | + |
| 63 | + if begin_step == end_step: |
| 64 | + # For a zero-duration schedule, always return the starting value. |
| 65 | + return lambda step: jnp.array(begin_value, dtype=jnp.float32) |
60 | 66 |
|
61 | 67 | def fn(step: Tensor) -> Tensor:
|
62 | 68 | frac = (step - begin_step) / (end_step - begin_step)
|
63 |
| - frac = jnp.minimum(1.0, jnp.maximum(0.0, frac)) |
| 69 | + frac = jnp.minimum(1.0, jnp.maximum(0.0, frac)) # Clamp progress to [0, 1]. |
64 | 70 | return begin_value + (frac**power) * (end_value - begin_value)
|
65 | 71 |
|
66 | 72 | return fn
|
@@ -348,6 +354,61 @@ def cosine_with_linear_warmup(
|
348 | 354 | return segment_wise(segments=segments, segment_steps=segment_steps)
|
349 | 355 |
|
350 | 356 |
|
def warmup_stable_decay(
    peak_lr: float,
    *,
    max_step: int,
    decay_begin_step: int,
    warmup_steps: int = 500,
    begin_value: float = 0.0,
    alpha: float = 0.0,
) -> ScheduleFn:
    """Builds a warmup-stable-decay (WSD) learning rate schedule.

    The schedule is composed of three consecutive phases: a linear ramp from `begin_value` to
    `peak_lr`, a plateau at `peak_lr`, and a linear ramp from `peak_lr` down to `peak_lr * alpha`.

    Args:
        peak_lr: The learning rate held during the stable (plateau) phase.
        max_step: The total number of steps across warmup, stable, and decay phases.
        decay_begin_step: The step at which the linear decay starts. The learning rate stays at
            `peak_lr` in [warmup_steps, decay_begin_step).
        warmup_steps: The length of the linear warm-up phase. Set to 0 to skip warm-up.
        begin_value: The value the linear warm-up starts from.
        alpha: The fraction of `peak_lr` that the decay phase ends at.

    Returns:
        A composite schedule built with `segment_wise`.

    Raises:
        ValueError: If decay_begin_step < warmup_steps, or if max_step < decay_begin_step.
    """
    if warmup_steps > decay_begin_step:
        raise ValueError(
            f"decay_begin_step ({decay_begin_step}) must be >= warmup_steps ({warmup_steps})."
        )
    if decay_begin_step > max_step:
        raise ValueError(f"max_step ({max_step}) must be >= decay_begin_step ({decay_begin_step}).")

    # Lengths of the plateau and decay phases; both are >= 0 after the checks above.
    stable_steps = decay_begin_step - warmup_steps
    decay_steps = max_step - decay_begin_step

    # Phase 1: linear ramp (polynomial with its default power) from begin_value to peak_lr.
    warmup_cfg = config_for_function(polynomial).set(
        begin_step=0,
        begin_value=begin_value,
        end_step=warmup_steps,
        end_value=peak_lr,
    )
    # Phase 2: hold the peak learning rate.
    stable_cfg = config_for_function(constant_schedule).set(value=peak_lr)
    # Phase 3: linear ramp from peak_lr to alpha * peak_lr.
    # NOTE(review): begin_step=0 presumes segment_wise feeds each segment a step counted from
    # that segment's own start -- confirm against segment_wise.
    decay_cfg = config_for_function(polynomial).set(
        begin_step=0,
        begin_value=peak_lr,
        end_step=decay_steps,
        end_value=peak_lr * alpha,
    )

    return segment_wise(
        segments=[warmup_cfg, stable_cfg, decay_cfg],
        segment_steps=[warmup_steps, stable_steps],
    )
| 410 | + |
| 411 | + |
351 | 412 | def constant_with_linear_warmup(
|
352 | 413 | peak_lr: float,
|
353 | 414 | *,
|
|
0 commit comments