Skip to content

Commit debbdac

Browse files
Element Pan authored and
facebook-github-bot committed
Add TRANSFORMER warmup policy for learning rate scheduling (#3548)
Summary: This diff implements the TRANSFORMER warmup policy from "Attention is All You Need" (Vaswani et al., 2017) for learning rate scheduling in torchrec, and updates a model configuration to use it. ## Implementation Added `WarmupPolicy.TRANSFORMER` to `fbcode/torchrec/optim/warmup.py` which implements the formula: ``` lr = base_lr * min(step^(-0.5), step * warm_steps^(-1.5)) * lr_scale ``` This schedule provides: - **Warmup phase**: LR increases from near-zero to peak at `warm_steps` - **Decay phase**: LR decreases via inverse square root after `warm_steps` The `max_iters` parameter serves as `warm_steps` in the formula. The schedule converges at step = warm_steps where both terms in the min() function become equal. ## Testing Added comprehensive unit tests in `fbcode/torchrec/optim/tests/test_warmup.py`: - Formula correctness at key milestones (step 1, warmup completion, post-warmup) - Monotonic increase during warmup phase - Monotonic decrease during decay phase - Proper application of lr_scale multiplier - Integration tests with WarmupOptimizer - Uses `none_throws()` from `pyre_extensions` for type-safe Optional handling Updated `fbcode/torchrec/optim/tests/BUCK` to include `pyre-extensions` dependency. ## Configuration Update Updated `fbcode/minimal_viable_ai/models/gysj/gysj_esr_roo/conf/model_roo_config.py` to use TRANSFORMER warmup: - Changed both sparse and dense optimizers from LINEAR warmup to TRANSFORMER - Set warm_steps to 80,000 for both optimizers - Updated optimizer hyperparameters (lr=0.001, eps=1e-6, beta values) to work with TRANSFORMER schedule - Sparse optimizer changed from ROWWISE_ADAGRAD to ADAM for consistency with dense optimizer Differential Revision: D87127589
1 parent 217889e commit debbdac

File tree

2 files changed

+327
-0
lines changed

2 files changed

+327
-0
lines changed

torchrec/optim/tests/test_warmup.py

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,167 @@ def step(self, closure: Any) -> None:
2626
pass # Override NotImplementedError.
2727

2828

29+
# Hoisted: _get_multiplier is private to torchrec.optim.warmup; import it once
# here instead of repeating the import inside every test method below.
from torchrec.optim.warmup import _get_multiplier


class TestGetMultiplier(unittest.TestCase):
    """Tests for the _get_multiplier function with TRANSFORMER policy.

    The TRANSFORMER policy ("Attention is All You Need", Vaswani et al., 2017)
    computes:

        multiplier = min(step ** -0.5, step * warm_steps ** -1.5) * lr_scale

    where step = iter + 1 (1-indexed) and warm_steps = stage.warmup_steps.
    """

    def _make_stage(self, warm_steps: int, lr_scale: float = 1.0) -> WarmupStage:
        """Build a TRANSFORMER stage; max_iters mirrors warm_steps as the tests
        exercise the convention that the stage length equals the warmup length."""
        return WarmupStage(
            policy=WarmupPolicy.TRANSFORMER,
            max_iters=warm_steps,
            lr_scale=lr_scale,
            warmup_steps=warm_steps,
        )

    def test_transformer_warmup_at_step_one(self) -> None:
        # At step 1 the warmup term dominates:
        #   step ** -0.5 = 1.0,  step * warm_steps ** -1.5 = 4000 ** -1.5 ≈ 0.0000158
        stage = self._make_stage(4000)
        multiplier = _get_multiplier(stage, iter=0)
        expected = min(1.0, 1 * (4000 ** (-1.5)))
        self.assertAlmostEqual(multiplier, expected, places=8)
        self.assertLess(multiplier, 0.00002)

    def test_transformer_warmup_at_warmup_steps(self) -> None:
        # At step == warm_steps both terms of the min() coincide (≈ 0.0158).
        stage = self._make_stage(4000)
        multiplier = _get_multiplier(stage, iter=3999)
        step = 4000
        expected = min(step ** (-0.5), step * (4000 ** (-1.5)))
        self.assertAlmostEqual(multiplier, expected, places=8)
        self.assertAlmostEqual(multiplier, 0.0158114, places=6)

    def test_transformer_warmup_after_warmup_steps(self) -> None:
        # Past warmup the inverse-square-root term is the smaller one:
        #   8000 ** -0.5 ≈ 0.0112  <  8000 * 4000 ** -1.5 ≈ 0.0316
        stage = self._make_stage(4000)
        multiplier = _get_multiplier(stage, iter=7999)
        step = 8000
        inv_sqrt = step ** (-0.5)
        warmup_term = step * (4000 ** (-1.5))
        self.assertAlmostEqual(multiplier, inv_sqrt, places=8)
        self.assertLess(inv_sqrt, warmup_term)
        self.assertAlmostEqual(multiplier, 0.0111803, places=6)

    def test_transformer_warmup_with_lr_scale(self) -> None:
        # lr_scale is applied multiplicatively on top of the schedule.
        stage = self._make_stage(4000, lr_scale=2.0)
        multiplier = _get_multiplier(stage, iter=3999)
        step = 4000
        base_multiplier = min(step ** (-0.5), step * (4000 ** (-1.5)))
        expected = base_multiplier * 2.0
        self.assertAlmostEqual(multiplier, expected, places=8)

    def test_transformer_warmup_formula_correctness(self) -> None:
        # Spot-check the formula at steps 1, 100, 500, 1000, 2000.
        stage = self._make_stage(1000)
        test_iters = [0, 99, 499, 999, 1999]
        for iter_val in test_iters:
            multiplier = _get_multiplier(stage, iter=iter_val)
            step = iter_val + 1
            expected = min(step ** (-0.5), step * (1000 ** (-1.5)))
            self.assertAlmostEqual(
                multiplier,
                expected,
                places=8,
                msg=f"Failed at iteration {iter_val} (step {step})",
            )

    def test_transformer_warmup_monotonic_increase_during_warmup(self) -> None:
        # The multiplier must strictly increase at every step of the warmup phase.
        stage = self._make_stage(1000)
        multipliers = [_get_multiplier(stage, iter=i) for i in range(0, 1000)]
        for idx in range(len(multipliers) - 1):
            self.assertLess(
                multipliers[idx],
                multipliers[idx + 1],
                msg=f"Multiplier should increase at iteration {idx}",
            )

    def test_transformer_warmup_monotonic_decrease_after_warmup(self) -> None:
        # The multiplier must strictly decrease once warmup is complete.
        stage = self._make_stage(1000)
        multipliers = [_get_multiplier(stage, iter=i) for i in range(1000, 2000)]
        for i in range(len(multipliers) - 1):
            self.assertGreater(
                multipliers[i],
                multipliers[i + 1],
                msg=f"Multiplier should decrease at iteration {i + 1000}",
            )
29190
class TestWarmupOptimizer(unittest.TestCase):
30191
def test_load_state_dict(self) -> None:
31192
def get_optimizer() -> WarmupOptimizer:
@@ -72,3 +233,157 @@ def get_optimizer() -> WarmupOptimizer:
72233
warmup_optimizer_1.state_dict()["state"]["__warmup"],
73234
warmup_optimizer_2.state_dict()["state"]["__warmup"],
74235
)
236+
237+
def test_transformer_warmup_integration(self) -> None:
238+
# Setup: Create optimizer with TRANSFORMER warmup policy
239+
param = Variable(torch.tensor([1.0, 2.0]))
240+
keyed_optimizer = DummyKeyedOptimizer(
241+
{"param": param}, defaultdict(dict), [{"params": [param]}]
242+
)
243+
244+
base_lr = 0.001
245+
warm_steps = 100
246+
247+
warmup_optimizer = WarmupOptimizer(
248+
keyed_optimizer,
249+
stages=[
250+
WarmupStage(
251+
policy=WarmupPolicy.TRANSFORMER,
252+
max_iters=100, # Stage ends at iteration 100
253+
lr_scale=1.0,
254+
warmup_steps=100,
255+
),
256+
],
257+
lr=base_lr,
258+
)
259+
260+
# Execute: Run optimizer through warmup steps
261+
learning_rates = []
262+
current_lr = 0.0
263+
for _ in range(100): # Only iterate through the TRANSFORMER stage
264+
for param_group in warmup_optimizer.param_groups:
265+
current_lr = param_group["lr"]
266+
learning_rates.append(current_lr)
267+
warmup_optimizer.step()
268+
269+
# Assert: Verify learning rate follows Transformer schedule during warmup
270+
# At step 1 (iteration 0)
271+
step_1 = 1
272+
expected_lr_1 = base_lr * min(step_1 ** (-0.5), step_1 * (warm_steps ** (-1.5)))
273+
self.assertAlmostEqual(learning_rates[0], expected_lr_1, places=10)
274+
275+
# At step 50 (iteration 49) - mid-warmup
276+
step_50 = 50
277+
expected_lr_50 = base_lr * min(
278+
step_50 ** (-0.5), step_50 * (warm_steps ** (-1.5))
279+
)
280+
self.assertAlmostEqual(learning_rates[49], expected_lr_50, places=10)
281+
282+
# At step 100 (iteration 99) - warmup completion
283+
step_100 = 100
284+
expected_lr_100 = base_lr * min(
285+
step_100 ** (-0.5), step_100 * (warm_steps ** (-1.5))
286+
)
287+
self.assertAlmostEqual(learning_rates[99], expected_lr_100, places=10)
288+
289+
# Verify learning rate increases monotonically during warmup
290+
for idx in range(warm_steps - 1):
291+
self.assertLess(
292+
learning_rates[idx],
293+
learning_rates[idx + 1],
294+
msg=f"LR should increase during warmup at step {idx + 1}",
295+
)
296+
# Verify formula correctness at this step
297+
step = idx + 1
298+
expected_lr_at_idx = base_lr * min(
299+
step ** (-0.5), step * (warm_steps ** (-1.5))
300+
)
301+
self.assertAlmostEqual(
302+
learning_rates[idx],
303+
expected_lr_at_idx,
304+
places=10,
305+
msg=f"LR mismatch at step {step}",
306+
)
307+
308+
def test_transformer_warmup_with_extended_stage(self) -> None:
309+
# Setup: Create optimizer with TRANSFORMER stage to test warmup and decay
310+
param = Variable(torch.tensor([1.0, 2.0]))
311+
keyed_optimizer = DummyKeyedOptimizer(
312+
{"param": param}, defaultdict(dict), [{"params": [param]}]
313+
)
314+
315+
base_lr = 0.001
316+
# In the TRANSFORMER policy, max_iters acts as warm_steps in the formula
317+
max_iters = 8000 # Stage runs for 8000 iterations
318+
319+
warmup_optimizer = WarmupOptimizer(
320+
keyed_optimizer,
321+
stages=[
322+
WarmupStage(
323+
policy=WarmupPolicy.TRANSFORMER,
324+
max_iters=max_iters, # Stage runs for 8000 iterations
325+
lr_scale=1.0,
326+
warmup_steps=max_iters,
327+
),
328+
],
329+
lr=base_lr,
330+
)
331+
332+
# Execute: Run optimizer through warmup and decay phases
333+
current_lr = 0.0
334+
learning_rates = []
335+
for _ in range(max_iters):
336+
for param_group in warmup_optimizer.param_groups:
337+
current_lr = param_group["lr"]
338+
learning_rates.append(current_lr)
339+
warmup_optimizer.step()
340+
341+
# Assert: Verify the formula uses max_iters as warm_steps
342+
# At step 1, verify the formula: min(step^(-0.5), step * max_iters^(-1.5))
343+
step_1 = 1
344+
expected_lr_1 = base_lr * min(step_1 ** (-0.5), step_1 * (max_iters ** (-1.5)))
345+
self.assertAlmostEqual(
346+
learning_rates[0],
347+
expected_lr_1,
348+
places=10,
349+
msg=f"LR at step 1 should match formula with warm_steps={max_iters}",
350+
)
351+
352+
# At step 4000, verify with max_iters=8000
353+
step_4000 = 4000
354+
expected_lr_4000 = base_lr * min(
355+
step_4000 ** (-0.5), step_4000 * (max_iters ** (-1.5))
356+
)
357+
self.assertAlmostEqual(
358+
learning_rates[3999],
359+
expected_lr_4000,
360+
places=10,
361+
msg=f"LR at step 4000 should match formula with warm_steps={max_iters}",
362+
)
363+
364+
# At step max_iters (8000), both terms should be equal
365+
step_max = max_iters
366+
inv_sqrt = step_max ** (-0.5)
367+
warmup_term = step_max * (max_iters ** (-1.5))
368+
self.assertAlmostEqual(
369+
inv_sqrt,
370+
warmup_term,
371+
places=10,
372+
msg=f"At step={max_iters}, both formula terms should be equal",
373+
)
374+
375+
expected_lr_max = base_lr * min(inv_sqrt, warmup_term)
376+
self.assertAlmostEqual(
377+
learning_rates[max_iters - 1],
378+
expected_lr_max,
379+
places=10,
380+
msg=f"LR at step {max_iters} should match formula",
381+
)
382+
383+
# Verify learning rate increases before max_iters
384+
for idx in range(max_iters - 1):
385+
self.assertLess(
386+
learning_rates[idx],
387+
learning_rates[idx + 1],
388+
msg=f"LR should increase at step {idx + 1} (before max_iters={max_iters})",
389+
)

torchrec/optim/warmup.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ class WarmupPolicy(Enum):
2828
STEP = "step"
2929
INVSQRT = "inv_sqrt" # inverse square root
3030
COSINE_ANNEALING_WARM_RESTARTS = "cosine_annealing_warm_restarts"
31+
TRANSFORMER = (
32+
"transformer" # Transformer warmup: min(step^(-0.5), step * warm_steps^(-1.5))
33+
)
3134

3235

3336
@dataclass
@@ -42,6 +45,8 @@ class WarmupStage:
4245
# default to 1 if not set to value > 0
4346
decay_iters: int = -1
4447
sgdr_period: int = 1
48+
# used as warmup_steps in transformer decay
49+
warmup_steps: int = 1
4550

4651

4752
def _lr_stages(stages: List[WarmupStage]) -> List[WarmupStage]:
@@ -86,6 +91,13 @@ def _get_multiplier(stage: WarmupStage, iter: int) -> float:
8691
t_cur = iter % t_0
8792
cos_iter = 0.5 * (1 + math.cos(math.pi * t_cur / t_0))
8893
multiplier = eta_min + (1.0 - eta_min) * cos_iter
94+
elif stage.policy == WarmupPolicy.TRANSFORMER:
95+
# Transformer warmup from "Attention is All You Need" (Vaswani et al., 2017)
96+
# Formula: lr_scale = min(step^(-0.5), step * warm_steps^(-1.5))
97+
# where warm_steps = max_iters
98+
# Add 1 to iter to make it 1-indexed and avoid division by zero
99+
step = iter + 1
100+
multiplier = min(step ** (-0.5), step * stage.warmup_steps ** (-1.5))
89101
return multiplier * stage.lr_scale
90102

91103

0 commit comments

Comments
 (0)