 import torch
 import torch.nn.functional as F

+from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
+from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo
 from liger_kernel.utils import infer_device
 from test.utils import assert_verbose_allclose
 from test.utils import set_seed
@@ -40,6 +42,7 @@ def __init__(
     def forward(
         self,
         x,  # Shape: [batch_size, seq_len, hidden_size]
+        selected_token_ids,  # Shape: [batch_size, seq_len]
         attention_mask,  # Shape: [batch_size, seq_len]
         advantages,  # Shape: [batch_size,]
         ref_input=None,  # Shape: [batch_size, seq_len, hidden_size]
@@ -54,8 +57,7 @@ def forward(
         log_probs = F.log_softmax(logits, dim=-1)

         # Get chosen token probabilities
-        chosen_tokens = log_probs.argmax(dim=-1)
-        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)

         # Get reference model probabilities
         if self.use_ref_model:
@@ -66,22 +68,21 @@ def forward(
             if self.temperature != 1.0:
                 ref_logits = ref_logits / self.temperature
             ref_log_probs = F.log_softmax(ref_logits, dim=-1)
-            ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+            ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
         else:
-            ref_token_logprobs = chosen_token_logprobs.detach()
-
+            ref_per_token_logps = per_token_logps.detach()

         # Compute policy gradient loss with importance sampling ratio
-        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else chosen_token_logprobs.detach()
-        coef_1 = torch.exp(chosen_token_logprobs - old_per_token_logps)
+        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
         coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if self.beta != 0.0:
             # Compute KL divergence between model and reference model
             kl_div = (
-                torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
+                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
             )
             per_token_loss = per_token_loss + self.beta * kl_div

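For readers skimming the diff: the loss computed above is the standard GRPO clipped-surrogate objective with an optional KL penalty against the reference model (the k3 estimator). A minimal standalone sketch of the same math in plain PyTorch, with illustrative names, not the fused Liger kernel:

    import torch

    def grpo_per_token_loss(per_token_logps, old_per_token_logps, ref_per_token_logps,
                            advantages, epsilon_low, epsilon_high, beta):
        # Importance ratio between current and old policy, per token: [B, T]
        ratio = torch.exp(per_token_logps - old_per_token_logps)
        clipped = torch.clamp(ratio, 1 - epsilon_low, 1 + epsilon_high)
        adv = advantages.unsqueeze(1)  # [B, 1], broadcast over the sequence
        loss = -torch.min(ratio * adv, clipped * adv)
        if beta != 0.0:
            # k3 KL estimator: exp(d) - d - 1 with d = ref_logp - cur_logp
            d = ref_per_token_logps - per_token_logps
            loss = loss + beta * (torch.exp(d) - d - 1.0)
        return loss  # [B, T]; caller applies attention_mask and averages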
@@ -90,7 +91,7 @@ def forward(

         # Compute metrics
         metrics = [
-            chosen_token_logprobs.mean(),
+            per_token_logps.mean(),
             log_probs.mean(),
         ]
         if self.beta != 0.0:
@@ -118,16 +119,18 @@ def __init__(
         super().__init__()
         self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
         self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
-        self.grpo_loss = LigerFusedLinearGRPOFunction.apply
-        self.beta = beta
-        self.epsilon_low = epsilon_low
-        self.epsilon_high = epsilon_high
-        self.temperature = temperature
-        self.use_ref_model = use_ref_model
+        self.grpo_loss = LigerFusedLinearGRPOLoss(
+            beta=beta,
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            temperature=temperature,
+            use_ref_model=use_ref_model,
+        )

     def forward(
         self,
         x,
+        selected_token_ids,
         attention_mask,
         advantages,
         ref_input=None,
@@ -137,19 +140,14 @@ def forward(
         return self.grpo_loss(
             x,  # _input
             self.lin.weight,  # weight
+            selected_token_ids,  # selected_token_ids
             attention_mask,  # attention_mask
             advantages,  # advantages
             self.lin.bias,  # bias
             ref_input,  # ref_input
             self.ref_lin.weight,  # ref_weight
             self.ref_lin.bias,  # ref_bias
             old_per_token_logps,  # old_per_token_logps
-            self.beta,  # beta
-            self.epsilon_low,  # epsilon_low
-            self.epsilon_high,  # epsilon_high
-            self.temperature,  # temperature
-            True,  # compiled
-            self.use_ref_model,  # use_ref_model
         )


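With this change the hyperparameters live on the LigerFusedLinearGRPOLoss module instead of being threaded through Function.apply on every call. A minimal usage sketch, assuming the positional order shown in forward() above and that the trailing bias/reference arguments are optional; hidden_states and lm_head_weight are placeholder names:

    loss_fn = LigerFusedLinearGRPOLoss(beta=0.1, epsilon_low=0.2, epsilon_high=0.2,
                                       temperature=1.0, use_ref_model=False)
    loss, metrics = loss_fn(
        hidden_states,       # [B, T, H] last hidden states of the policy model
        lm_head_weight,      # [V, H] lm_head projection, fused with the loss
        selected_token_ids,  # [B, T] sampled completion token ids
        attention_mask,      # [B, T] 1 for real tokens, 0 for padding
        advantages,          # [B] per-sequence advantages
    )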
@@ -173,7 +171,7 @@ def forward(
173171 "beta, epsilon_low, epsilon_high, temperature" ,
174172 [
175173 # Standard settings
176- (0.1 , 0.2 , 0.2 , 1 .0 ),
174+ (0.1 , 0.2 , 0.2 , 20 .0 ), # set temperature to 20.0 for better numerical stability
177175 (0.0 , 0.1 , 0.1 , 2.0 ),
178176 ]
179177)
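A note on the temperature bump: logits are divided by the temperature before log_softmax, so a large temperature flattens the distribution and keeps per-token log-probs away from large negative values; exp() of log-prob differences then stays near 1, which avoids overflow and NaN in the ratio and KL terms under bfloat16. A quick illustration:

    logits = torch.randn(4096) * 10
    logits.log_softmax(-1).min()           # can be strongly negative
    (logits / 20.0).log_softmax(-1).min()  # near -log(4096): much flatter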
@@ -240,6 +238,9 @@ def test_correctness(
     input1 = _input.detach().clone().requires_grad_(True)
     input2 = _input.detach().clone().requires_grad_(True)

+    # Create selected token ids with shape [B, T]
+    selected_token_ids = torch.randint(0, V, (B, T), device=device)
+
     # Create attention mask with random padding [B, T]
     attention_mask = torch.ones(B, T, device=device)
     num_elements_to_mask = torch.randint(1, B * T // 2, (1,)).item()
@@ -259,13 +260,15 @@ def test_correctness(

     # Forward pass with reference model
     loss1, aux1 = torch_lm_head_grpo(
-        input1, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
+        input1, selected_token_ids, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
     )
     loss2, aux2 = liger_lm_head_grpo(
-        input2, selected_token_ids, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
     )

     # Check losses match
+    assert not torch.isnan(loss1).any()
+    assert not torch.isnan(loss2).any()
     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)

     # Check metrics match
@@ -292,3 +295,100 @@ def test_correctness(
         atol=atol,
         rtol=rtol,
     )
+
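Both tests guard against NaN losses, and a bare comparison cannot do that: NaN compares unequal to everything, itself included, so "assert loss != float('nan')" always passes. The asserts therefore use torch.isnan. For example:

    nan = float('nan')
    nan != nan                      # True, so the naive assert never fires
    torch.isnan(torch.tensor(nan))  # tensor(True), the reliable check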
+@pytest.mark.parametrize(
+    "B, T, H, V",
+    [
+        (8, 128, 1024, 4096),
+        (3, 47, 31, 123),  # random shape
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        (1.0, torch.bfloat16, 5e-2, 5e-2),
+        (1.0, torch.float32, 1e-4, 5e-3),
+    ],
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("ref_bias", [True, False])
+@pytest.mark.parametrize(
+    "beta, epsilon_low, epsilon_high, temperature",
+    [
+        # Standard settings
+        (0.1, 0.2, 0.2, 20.0),  # set temperature to 20.0 for better numerical stability
+        (0.0, 0.1, 0.1, 2.0),
+    ]
+)
+@pytest.mark.parametrize("use_ref_model", [True, False])
+@pytest.mark.parametrize("old_per_token_logps", [True, False])
+def test_functional_correctness(
+    B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, beta, epsilon_low, epsilon_high, temperature, use_ref_model, old_per_token_logps
+):
+    _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
+    input1 = _input.detach().clone().requires_grad_(True)
+    input2 = _input.detach().clone().requires_grad_(True)
+
+    _weight = torch.randn(V, H, device=device, dtype=dtype) * scalar
+    weight1 = _weight.detach().clone().requires_grad_(True)
+    weight2 = _weight.detach().clone().requires_grad_(True)
+
+    selected_token_ids = torch.randint(0, V, (B, T), device=device)
+
+    attention_mask = torch.ones(B, T, device=device)
+
+    advantages = torch.rand(B, device=device, dtype=dtype)
+
+    if bias:
+        _bias = torch.randn(V, device=device, dtype=dtype) * scalar
+        bias1 = _bias.detach().clone().requires_grad_(True)
+        bias2 = _bias.detach().clone().requires_grad_(True)
+    else:
+        bias1 = None
+        bias2 = None
+
+    ref_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
+
+    _ref_weight = torch.randn(V, H, device=device, dtype=dtype) * scalar
+    ref_weight1 = _ref_weight.detach().clone().requires_grad_(True)
+    ref_weight2 = _ref_weight.detach().clone().requires_grad_(True)
+
+    if ref_bias:
+        _ref_bias = torch.randn(V, device=device, dtype=dtype) * scalar
+        ref_bias1 = _ref_bias.detach().clone().requires_grad_(True)
+        ref_bias2 = _ref_bias.detach().clone().requires_grad_(True)
+    else:
+        ref_bias1 = None
+        ref_bias2 = None
+
+    if old_per_token_logps:
+        old_per_token_logps = torch.randn(B, T, device=device, dtype=dtype) * scalar
+    else:
+        old_per_token_logps = None
+
+    loss1, aux1 = liger_fused_linear_grpo(
+        input1, weight1, selected_token_ids, attention_mask, advantages, bias1, ref_input, ref_weight1, ref_bias1, old_per_token_logps, beta, epsilon_low, epsilon_high, temperature, True, use_ref_model, 1
+    )
+
+    loss2, aux2 = LigerFusedLinearGRPOFunction.apply(
+        input2, weight2, selected_token_ids, attention_mask, advantages, bias2, ref_input, ref_weight2, ref_bias2, old_per_token_logps, beta, epsilon_low, epsilon_high, temperature, True, use_ref_model, 1
+    )
+
+    assert not torch.isnan(loss1).any()
+    assert not torch.isnan(loss2).any()
+    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
+
+    # Check metrics match
+    assert len(aux1) == len(aux2)
+    for metric1, metric2 in zip(aux1, aux2):
+        assert_verbose_allclose(metric1, metric2, atol=atol, rtol=rtol)
+
+    # Backward pass
+    loss1.backward()
+    loss2.backward()
+
+    # Check gradients match
+    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
+    assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
+    if bias:
+        assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
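One last contextual note: the test draws advantages from torch.rand, but in actual GRPO training the advantages fed to this loss would typically come from group-relative reward normalization over several completions of the same prompt. A sketch of that convention, with illustrative names:

    # rewards: [num_prompts, group_size], one sampled group per prompt
    def group_relative_advantages(rewards, eps=1e-4):
        mean = rewards.mean(dim=1, keepdim=True)
        std = rewards.std(dim=1, keepdim=True)
        # flatten to [num_prompts * group_size] to match advantages: [B]
        return ((rewards - mean) / (std + eps)).flatten()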