@@ -26,6 +26,167 @@ def step(self, closure: Any) -> None:
         pass  # Override NotImplementedError.
 
 
+class TestGetMultiplier(unittest.TestCase):
+    """Tests for the _get_multiplier function with TRANSFORMER policy."""
+
+    def test_transformer_warmup_at_step_one(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warmup_steps=4000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=4000,
+            lr_scale=1.0,
+            warmup_steps=4000,
+        )
+
+        # Execute: Get multiplier at iteration 0 (step 1 internally)
+        from torchrec.optim.warmup import _get_multiplier
+
+        multiplier = _get_multiplier(stage, iter=0)
+
+        # Assert: At step 1, multiplier should be min(1, 4000^(-1.5)) ≈ 3.95e-06
+        # step^(-0.5) = 1^(-0.5) = 1.0
+        # step * warmup_steps^(-1.5) = 1 * 4000^(-1.5) ≈ 3.95e-06
+        expected = min(1.0, 1 * (4000 ** (-1.5)))
+        self.assertAlmostEqual(multiplier, expected, places=8)
+        self.assertLess(multiplier, 0.00002)
+
+    def test_transformer_warmup_at_warmup_steps(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warmup_steps=4000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=4000,
+            lr_scale=1.0,
+            warmup_steps=4000,
+        )
+
+        # Execute: Get multiplier at iteration 3999 (step 4000 internally)
+        from torchrec.optim.warmup import _get_multiplier
+
+        multiplier = _get_multiplier(stage, iter=3999)
+
+        # Assert: At step=warmup_steps, both terms are equal
+        # step^(-0.5) = 4000^(-0.5) ≈ 0.0158
+        # step * warmup_steps^(-1.5) = 4000 * 4000^(-1.5) ≈ 0.0158
+        step = 4000
+        expected = min(step ** (-0.5), step * (4000 ** (-1.5)))
+        self.assertAlmostEqual(multiplier, expected, places=8)
+        self.assertAlmostEqual(multiplier, 0.0158114, places=6)
+
+    def test_transformer_warmup_after_warmup_steps(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warmup_steps=4000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=4000,
+            lr_scale=1.0,
+            warmup_steps=4000,
+        )
+
+        # Execute: Get multiplier at iteration 7999 (step 8000 internally)
+        from torchrec.optim.warmup import _get_multiplier
+
+        multiplier = _get_multiplier(stage, iter=7999)
+
+        # Assert: After warmup, step^(-0.5) dominates (is smaller)
+        # step^(-0.5) = 8000^(-0.5) ≈ 0.0112
+        # step * warmup_steps^(-1.5) = 8000 * 4000^(-1.5) ≈ 0.0316
+        step = 8000
+        inv_sqrt = step ** (-0.5)
+        warmup_term = step * (4000 ** (-1.5))
+        self.assertAlmostEqual(multiplier, inv_sqrt, places=8)
+        self.assertLess(inv_sqrt, warmup_term)
+        self.assertAlmostEqual(multiplier, 0.0111803, places=6)
+
+    def test_transformer_warmup_with_lr_scale(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with lr_scale=2.0
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=4000,
+            lr_scale=2.0,
+            warmup_steps=4000,
+        )
+
+        # Execute: Get multiplier at iteration 3999 (step 4000 internally)
+        from torchrec.optim.warmup import _get_multiplier
+
+        multiplier = _get_multiplier(stage, iter=3999)
+
+        # Assert: lr_scale is applied as a multiplier
+        step = 4000
+        base_multiplier = min(step ** (-0.5), step * (4000 ** (-1.5)))
+        expected = base_multiplier * 2.0
+        self.assertAlmostEqual(multiplier, expected, places=8)
+
+    def test_transformer_warmup_formula_correctness(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warmup_steps=1000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=1000,
+            lr_scale=1.0,
+            warmup_steps=1000,
+        )
+
+        # Execute: Test multiple iterations to verify the formula
+        from torchrec.optim.warmup import _get_multiplier
+
+        test_iters = [0, 99, 499, 999, 1999]  # steps 1, 100, 500, 1000, 2000
+        for iter_val in test_iters:
+            multiplier = _get_multiplier(stage, iter=iter_val)
+            step = iter_val + 1
+
+            # Assert: Multiplier matches the Transformer formula
+            expected = min(step ** (-0.5), step * (1000 ** (-1.5)))
+            self.assertAlmostEqual(
+                multiplier,
+                expected,
+                places=8,
+                msg=f"Failed at iteration {iter_val} (step {step})",
+            )
+
+    def test_transformer_warmup_monotonic_increase_during_warmup(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warmup_steps=1000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=1000,
+            lr_scale=1.0,
+            warmup_steps=1000,
+        )
+
+        # Execute: Get multipliers during the warmup phase
+        from torchrec.optim.warmup import _get_multiplier
+
+        multipliers = [_get_multiplier(stage, iter=i) for i in range(0, 1000)]
+
+        # Assert: Multipliers should increase monotonically during warmup
+        for idx in range(len(multipliers) - 1):
+            self.assertLess(
+                multipliers[idx],
+                multipliers[idx + 1],
+                msg=f"Multiplier should increase at iteration {idx}",
+            )
+
+    def test_transformer_warmup_monotonic_decrease_after_warmup(self) -> None:
+        # Setup: Create TRANSFORMER warmup stage with warmup_steps=1000
+        stage = WarmupStage(
+            policy=WarmupPolicy.TRANSFORMER,
+            max_iters=1000,
+            lr_scale=1.0,
+            warmup_steps=1000,
+        )
+
+        # Execute: Get multipliers after the warmup phase
+        from torchrec.optim.warmup import _get_multiplier
+
+        multipliers = [_get_multiplier(stage, iter=i) for i in range(1000, 2000)]
+
+        # Assert: Multipliers should decrease monotonically after warmup
+        for i in range(len(multipliers) - 1):
+            self.assertGreater(
+                multipliers[i],
+                multipliers[i + 1],
+                msg=f"Multiplier should decrease at iteration {i + 1000}",
+            )
+
+
 class TestWarmupOptimizer(unittest.TestCase):
     def test_load_state_dict(self) -> None:
         def get_optimizer() -> WarmupOptimizer:
@@ -72,3 +233,157 @@ def get_optimizer() -> WarmupOptimizer:
             warmup_optimizer_1.state_dict()["state"]["__warmup"],
             warmup_optimizer_2.state_dict()["state"]["__warmup"],
         )
+
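+    # The integration tests below read the scaled rate from param_groups; they expect
+    # param_group["lr"] == lr * stage multiplier at each step.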
+    def test_transformer_warmup_integration(self) -> None:
+        # Setup: Create optimizer with TRANSFORMER warmup policy
+        param = Variable(torch.tensor([1.0, 2.0]))
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"param": param}, defaultdict(dict), [{"params": [param]}]
+        )
+
+        base_lr = 0.001
+        warm_steps = 100
+
+        warmup_optimizer = WarmupOptimizer(
+            keyed_optimizer,
+            stages=[
+                WarmupStage(
+                    policy=WarmupPolicy.TRANSFORMER,
+                    max_iters=100,  # Stage ends at iteration 100
+                    lr_scale=1.0,
+                    warmup_steps=100,
+                ),
+            ],
+            lr=base_lr,
+        )
+
+        # Execute: Run the optimizer through the warmup steps
+        learning_rates = []
+        current_lr = 0.0
+        for _ in range(100):  # Only iterate through the TRANSFORMER stage
+            for param_group in warmup_optimizer.param_groups:
+                current_lr = param_group["lr"]
+            learning_rates.append(current_lr)
+            warmup_optimizer.step()
+
+        # Assert: Verify the learning rate follows the Transformer schedule during warmup
+        # At step 1 (iteration 0)
+        step_1 = 1
+        expected_lr_1 = base_lr * min(step_1 ** (-0.5), step_1 * (warm_steps ** (-1.5)))
+        self.assertAlmostEqual(learning_rates[0], expected_lr_1, places=10)
+
+        # At step 50 (iteration 49) - mid-warmup
+        step_50 = 50
+        expected_lr_50 = base_lr * min(
+            step_50 ** (-0.5), step_50 * (warm_steps ** (-1.5))
+        )
+        self.assertAlmostEqual(learning_rates[49], expected_lr_50, places=10)
+
+        # At step 100 (iteration 99) - warmup completion
+        step_100 = 100
+        expected_lr_100 = base_lr * min(
+            step_100 ** (-0.5), step_100 * (warm_steps ** (-1.5))
+        )
+        self.assertAlmostEqual(learning_rates[99], expected_lr_100, places=10)
+
+        # Verify the learning rate increases monotonically during warmup
+        for idx in range(warm_steps - 1):
+            self.assertLess(
+                learning_rates[idx],
+                learning_rates[idx + 1],
+                msg=f"LR should increase during warmup at step {idx + 1}",
+            )
+            # Verify formula correctness at this step
+            step = idx + 1
+            expected_lr_at_idx = base_lr * min(
+                step ** (-0.5), step * (warm_steps ** (-1.5))
+            )
+            self.assertAlmostEqual(
+                learning_rates[idx],
+                expected_lr_at_idx,
+                places=10,
+                msg=f"LR mismatch at step {step}",
+            )
+
+    def test_transformer_warmup_with_extended_stage(self) -> None:
+        # Setup: Create optimizer with a long TRANSFORMER stage whose warmup spans the whole stage
+        param = Variable(torch.tensor([1.0, 2.0]))
+        keyed_optimizer = DummyKeyedOptimizer(
+            {"param": param}, defaultdict(dict), [{"params": [param]}]
+        )
+
+        base_lr = 0.001
+        # warmup_steps is set equal to max_iters, so max_iters acts as the warmup length in the formula
+        max_iters = 8000  # Stage runs for 8000 iterations
+
+        warmup_optimizer = WarmupOptimizer(
+            keyed_optimizer,
+            stages=[
+                WarmupStage(
+                    policy=WarmupPolicy.TRANSFORMER,
+                    max_iters=max_iters,  # Stage runs for 8000 iterations
+                    lr_scale=1.0,
+                    warmup_steps=max_iters,
+                ),
+            ],
+            lr=base_lr,
+        )
+
+        # Execute: Run the optimizer through the full warmup phase
+        current_lr = 0.0
+        learning_rates = []
+        for _ in range(max_iters):
+            for param_group in warmup_optimizer.param_groups:
+                current_lr = param_group["lr"]
+            learning_rates.append(current_lr)
+            warmup_optimizer.step()
+
+        # Assert: Verify the formula uses max_iters as the warmup length
+        # At step 1, verify the formula: min(step^(-0.5), step * max_iters^(-1.5))
+        step_1 = 1
+        expected_lr_1 = base_lr * min(step_1 ** (-0.5), step_1 * (max_iters ** (-1.5)))
+        self.assertAlmostEqual(
+            learning_rates[0],
+            expected_lr_1,
+            places=10,
+            msg=f"LR at step 1 should match formula with warmup_steps={max_iters}",
+        )
+
+        # At step 4000, verify with max_iters=8000
+        step_4000 = 4000
+        expected_lr_4000 = base_lr * min(
+            step_4000 ** (-0.5), step_4000 * (max_iters ** (-1.5))
+        )
+        self.assertAlmostEqual(
+            learning_rates[3999],
+            expected_lr_4000,
+            places=10,
+            msg=f"LR at step 4000 should match formula with warmup_steps={max_iters}",
+        )
+
+        # At step max_iters (8000), both terms should be equal
+        step_max = max_iters
+        inv_sqrt = step_max ** (-0.5)
+        warmup_term = step_max * (max_iters ** (-1.5))
+        self.assertAlmostEqual(
+            inv_sqrt,
+            warmup_term,
+            places=10,
+            msg=f"At step={max_iters}, both formula terms should be equal",
+        )
+
+        expected_lr_max = base_lr * min(inv_sqrt, warmup_term)
+        self.assertAlmostEqual(
+            learning_rates[max_iters - 1],
+            expected_lr_max,
+            places=10,
+            msg=f"LR at step {max_iters} should match formula",
+        )
+
+        # Verify the learning rate increases at every step before max_iters
+        for idx in range(max_iters - 1):
+            self.assertLess(
+                learning_rates[idx],
+                learning_rates[idx + 1],
+                msg=f"LR should increase at step {idx + 1} (before max_iters={max_iters})",
+            )