@@ -3,8 +3,8 @@
 import torch.nn.functional as F
 
 from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
-from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
 from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo
+from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
 from liger_kernel.utils import infer_device
 from test.utils import assert_verbose_allclose
 from test.utils import set_seed
@@ -16,6 +16,7 @@
 # reset torch compiler cache
 torch.compiler.reset()
 
+
 class TorchLMHeadGRPO(torch.nn.Module):
     def __init__(
         self,
@@ -38,7 +39,7 @@ def __init__(
         self.epsilon_high = epsilon_high
         self.temperature = temperature
         self.use_ref_model = use_ref_model
-
+
     def forward(
         self,
         x,  # Shape: [batch_size, seq_len, hidden_size]
@@ -48,7 +49,7 @@ def forward(
         ref_input=None,  # Shape: [batch_size, seq_len, hidden_size]
         old_per_token_logps=None,
     ):
-        logits = (x @ self.lin.weight.t())
+        logits = x @ self.lin.weight.t()
         if self.lin.bias is not None:
             logits = logits + self.lin.bias
         if self.temperature != 1.0:
@@ -81,9 +82,7 @@ def forward(
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if self.beta != 0.0:
             # Compute KL divergence between model and reference model
-            kl_div = (
-                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
-            )
+            kl_div = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
             per_token_loss = per_token_loss + self.beta * kl_div
 
         # Apply masking and normalize
@@ -171,9 +170,9 @@ def forward(
171170 "beta, epsilon_low, epsilon_high, temperature" ,
172171 [
173172 # Standard settings
174- (0.1 , 0.2 , 0.2 , 20.0 ), # set temperature to 20.0 for better numerical stability
173+ (0.1 , 0.2 , 0.2 , 20.0 ), # set temperature to 20.0 for better numerical stability
175174 (0.0 , 0.1 , 0.1 , 2.0 ),
176- ]
175+ ],
177176)
178177@pytest .mark .parametrize ("use_ref_model" , [True , False ])
179178@pytest .mark .parametrize ("old_per_token_logps" , [True , False ])
@@ -231,7 +230,9 @@ def test_correctness(
             V, H, device=device, dtype=dtype
         )
         if ref_bias:
-            torch_lm_head_grpo.ref_lin.bias.data = liger_lm_head_grpo.ref_lin.bias.data = torch.randn(V, device=device, dtype=dtype)
+            torch_lm_head_grpo.ref_lin.bias.data = liger_lm_head_grpo.ref_lin.bias.data = torch.randn(
+                V, device=device, dtype=dtype
+            )
 
     # Create inputs with shape [B, T, H]
     _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
@@ -260,15 +261,25 @@ def test_correctness(
 
     # Forward pass with reference model
     loss1, aux1 = torch_lm_head_grpo(
-        input1, selected_token_ids, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
+        input1,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        ref_input=ref_input,
+        old_per_token_logps=old_per_token_logps,
     )
     loss2, aux2 = liger_lm_head_grpo(
-        input2, selected_token_ids, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
+        input2,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        ref_input=ref_input,
+        old_per_token_logps=old_per_token_logps,
     )
 
     # Check losses match
-    assert loss1 != float('nan')
-    assert loss2 != float('nan')
+    assert loss1 != float("nan")
+    assert loss2 != float("nan")
     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
 
     # Check metrics match
@@ -296,6 +307,7 @@ def test_correctness(
         rtol=rtol,
     )
 
+
 @pytest.mark.parametrize(
     "B, T, H, V",
     [
@@ -316,14 +328,29 @@ def test_correctness(
316328 "beta, epsilon_low, epsilon_high, temperature" ,
317329 [
318330 # Standard settings
319- (0.1 , 0.2 , 0.2 , 20.0 ), # set temperature to 20.0 for better numerical stability
331+ (0.1 , 0.2 , 0.2 , 20.0 ), # set temperature to 20.0 for better numerical stability
320332 (0.0 , 0.1 , 0.1 , 2.0 ),
321- ]
333+ ],
322334)
323335@pytest .mark .parametrize ("use_ref_model" , [True , False ])
324336@pytest .mark .parametrize ("old_per_token_logps" , [True , False ])
325337def test_functional_correctness (
326- B , T , H , V , scalar , dtype , atol , rtol , bias , ref_bias , beta , epsilon_low , epsilon_high , temperature , use_ref_model , old_per_token_logps
338+ B ,
339+ T ,
340+ H ,
341+ V ,
342+ scalar ,
343+ dtype ,
344+ atol ,
345+ rtol ,
346+ bias ,
347+ ref_bias ,
348+ beta ,
349+ epsilon_low ,
350+ epsilon_high ,
351+ temperature ,
352+ use_ref_model ,
353+ old_per_token_logps ,
327354):
328355 _input = torch .randn (B , T , H , device = device , dtype = dtype ) * scalar
329356 input1 = _input .detach ().clone ().requires_grad_ (True )
@@ -334,7 +361,7 @@ def test_functional_correctness(
     weight2 = _weight.detach().clone().requires_grad_(True)
 
     selected_token_ids = torch.randint(0, V, (B, T), device=device)
-
+
     attention_mask = torch.ones(B, T, device=device)
 
     advantages = torch.rand(B, device=device, dtype=dtype)
@@ -348,7 +375,7 @@ def test_functional_correctness(
         bias2 = None
 
     ref_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
-
+
     _ref_weight = torch.randn(V, H, device=device, dtype=dtype) * scalar
     ref_weight1 = _ref_weight.detach().clone().requires_grad_(True)
     ref_weight2 = _ref_weight.detach().clone().requires_grad_(True)
@@ -367,15 +394,47 @@ def test_functional_correctness(
         old_per_token_logps = None
 
     loss1, aux1 = liger_fused_linear_grpo(
-        input1, weight1, selected_token_ids, attention_mask, advantages, bias1, ref_input, ref_weight1, ref_bias1, old_per_token_logps, beta, epsilon_low, epsilon_high, temperature, True, use_ref_model, 1
+        input1,
+        weight1,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        bias1,
+        ref_input,
+        ref_weight1,
+        ref_bias1,
+        old_per_token_logps,
+        beta,
+        epsilon_low,
+        epsilon_high,
+        temperature,
+        True,
+        use_ref_model,
+        1,
     )
 
     loss2, aux2 = LigerFusedLinearGRPOFunction.apply(
-        input2, weight2, selected_token_ids, attention_mask, advantages, bias2, ref_input, ref_weight2, ref_bias2, old_per_token_logps, beta, epsilon_low, epsilon_high, temperature, True, use_ref_model, 1
+        input2,
+        weight2,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        bias2,
+        ref_input,
+        ref_weight2,
+        ref_bias2,
+        old_per_token_logps,
+        beta,
+        epsilon_low,
+        epsilon_high,
+        temperature,
+        True,
+        use_ref_model,
+        1,
     )
 
-    assert loss1 != float('nan')
-    assert loss2 != float('nan')
+    assert loss1 != float("nan")
+    assert loss2 != float("nan")
     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
 
     # Check metrics match
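
For orientation, a minimal sketch of the functional call exercised by test_functional_correctness above. The positional argument order is copied from the test; the two unnamed positional values (True and 1) are passed exactly as the test passes them, infer_device() is the same helper the test imports, and the tensor sizes are illustrative assumptions rather than the test's parametrization. This is a usage sketch, not documented API.

import torch

from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo
from liger_kernel.utils import infer_device

device = infer_device()
B, T, H, V = 2, 8, 16, 32  # illustrative sizes (assumption), not the test's parametrization
dtype = torch.float32

x = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=True)
weight = torch.randn(V, H, device=device, dtype=dtype, requires_grad=True)
selected_token_ids = torch.randint(0, V, (B, T), device=device)
attention_mask = torch.ones(B, T, device=device)
advantages = torch.rand(B, device=device, dtype=dtype)
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)
ref_weight = torch.randn(V, H, device=device, dtype=dtype)

# Same positional order as the call in test_functional_correctness.
loss, aux = liger_fused_linear_grpo(
    x,  # hidden states [B, T, H]
    weight,  # lm_head weight [V, H]
    selected_token_ids,  # [B, T]
    attention_mask,  # [B, T]
    advantages,  # [B]
    None,  # bias
    ref_input,  # reference-model hidden states
    ref_weight,  # reference-model lm_head weight
    None,  # ref_bias
    None,  # old_per_token_logps
    0.1,  # beta
    0.2,  # epsilon_low
    0.2,  # epsilon_high
    20.0,  # temperature
    True,  # passed positionally here, as in the test
    True,  # use_ref_model
    1,  # passed positionally here, as in the test
)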