Skip to content

Commit 5b955f3

Browse files
bleykherchanglan
authored andcommitted
Add step_offset parameter to ema_schedule (#1516)
* Add step_offset parameter to ema_schedule * Fix formatting GitOrigin-RevId: 00eaebd
1 parent f8fd186 commit 5b955f3

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

axlearn/common/schedule.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,14 +417,18 @@ def linear_schedule_with_warmup(
417417
)
418418

419419

420-
def ema_schedule(decay: float = 0.9999, *, warmup_steps: int = 1) -> ScheduleFn:
420+
def ema_schedule(
421+
decay: float = 0.9999, *, warmup_steps: int = 1, step_offset: int = 0
422+
) -> ScheduleFn:
421423
"""Ema decay schedule with warm-up.
422424
423425
The ema decay is 0, 1/2, 2/3, 3/4, 4/5, ... during warm-up, and then is constant at decay.
424426
425427
Args:
426428
decay: ema decay.
427429
warmup_steps: The number of steps of the warm-up schedule.
430+
step_offset: The initial step number.
431+
If `step` is less than or equal to `step_offset`, we clamp to `step - step_offset` to 0.
428432
429433
Returns:
430434
A ema decay schedule.
@@ -436,6 +440,7 @@ def ema_schedule(decay: float = 0.9999, *, warmup_steps: int = 1) -> ScheduleFn:
436440
raise ValueError("warmup_steps must be > 0.")
437441

438442
def fn(step):
443+
step = jnp.maximum(step - step_offset, 0)
439444
return step / (1.0 + step) * (step < warmup_steps) + decay * (step >= warmup_steps)
440445

441446
return fn

axlearn/common/schedule_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,19 +241,19 @@ def test_adafactor_decay_rate(self):
241241
self.assertAlmostEqual(fn(200), 1 - (100) ** (-0.8))
242242

243243
def test_ema_schedule(self):
244-
warmup_steps = 5
245-
s = jax.jit(
246-
schedule.ema_schedule(
247-
warmup_steps=warmup_steps,
248-
)
249-
)
244+
warmup_steps, step_offset = 5, 3
245+
s = jax.jit(schedule.ema_schedule(warmup_steps=warmup_steps, step_offset=step_offset))
246+
expected_offset = 0.0
250247
expected_warmup = [0.0, 1.0 / 2, 2.0 / 3, 3.0 / 4, 4.0 / 5]
251248
expected_decay = 0.9999
252249
for step in range(10):
253250
value = s(step)
254-
if step < warmup_steps:
251+
if step < step_offset:
252+
# Test offset.
253+
self.assertAlmostEqual(expected_offset, value)
254+
elif step < warmup_steps + step_offset:
255255
# Test warmup.
256-
self.assertAlmostEqual(expected_warmup[step], value)
256+
self.assertAlmostEqual(expected_warmup[step - step_offset], value)
257257
else:
258258
# Test inverse sqrt schedule.
259259
self.assertAlmostEqual(expected_decay, value)

0 commit comments

Comments
 (0)