@@ -7,61 +7,57 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
     @staticmethod
     def rlhf_loss_fn(
         log_probs,
+        selected_token_ids,
         attention_mask,
         advantages,
         full_attention_mask,
         ref_log_probs=None,
         old_per_token_logps=None,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        beta=0.1,
+        beta=0.04,
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
-        # Get chosen token probabilities
-        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
-        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+
+        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
             -1
         )  # (batch_size, seq_len)
 
         # Get reference model probabilities
         if ref_log_probs is not None:
             with torch.no_grad():
-                ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+                ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
         else:
-            ref_token_logprobs = chosen_token_logprobs.detach()
+            ref_per_token_logps = per_token_logps.detach()
 
         # Compute policy gradient loss with importance sampling ratio
-        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else chosen_token_logprobs.detach()
-        coef_1 = torch.exp(chosen_token_logprobs - old_per_token_logps)
+        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
         coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if beta != 0.0:
             # Compute KL penalty
             kl_div = (
-                torch.exp(ref_token_logprobs - chosen_token_logprobs)
-                - (ref_token_logprobs - chosen_token_logprobs)
-                - 1.0
+                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
             )
             # Combine losses
             per_token_loss = per_token_loss + beta * kl_div
-
         # Apply mask and compute average loss
         loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
 
         # Calculate metrics
         full_batch_size, seq_len = full_attention_mask.shape
         vocab_size = log_probs.shape[2]
         metrics = [
-            chosen_token_logprobs.sum() / (full_batch_size * seq_len),  # mean log prob
+            per_token_logps.sum() / (full_batch_size * seq_len),  # mean log prob
             log_probs.sum() / (full_batch_size * seq_len * vocab_size),  # mean all log probs
         ]
         if beta != 0.0:
             metrics.append(
-                ((kl_div * attention_mask).sum(dim=1) / torch.clamp(attention_mask.sum(dim=1), min=1.0)).sum()
-                / full_batch_size
+                ((kl_div * attention_mask).sum(dim=1) / torch.clamp(attention_mask.sum(dim=1), min=1.0)).sum() / full_batch_size
             )
         return loss, metrics
 
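For reference, the per-token objective assembled above is the clipped PPO-style surrogate plus a k3-style KL estimator weighted by beta, matching the GRPOTrainer behavior the docstring names. Below is a minimal standalone sketch of the same math on toy shapes (illustrative only, not the fused chunked path; in the kernel the normalizer comes from full_attention_mask because the batch is processed in chunks):

import torch

# Toy tensors: batch_size=2, seq_len=3 (hypothetical values, illustration only).
per_token_logps = torch.randn(2, 3)             # log-probs of the sampled tokens under the policy
old_per_token_logps = per_token_logps.detach()  # no old policy cached -> ratio is exactly 1
ref_per_token_logps = torch.randn(2, 3)         # log-probs under the frozen reference model
advantages = torch.tensor([0.5, -1.0])
attention_mask = torch.ones(2, 3)
beta, epsilon_low, epsilon_high = 0.04, 0.2, 0.2

ratio = torch.exp(per_token_logps - old_per_token_logps)
clipped_ratio = torch.clamp(ratio, 1 - epsilon_low, 1 + epsilon_high)
surrogate = -torch.min(ratio * advantages.unsqueeze(1), clipped_ratio * advantages.unsqueeze(1))

# k3-style KL estimator: exp(ref - cur) - (ref - cur) - 1, always >= 0.
kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0

loss = ((surrogate + beta * kl) * attention_mask).sum() / attention_mask.sum().clamp(min=1.0)
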
@@ -71,14 +67,15 @@ def forward(
         ctx,
         _input,
         weight,
+        selected_token_ids,
         attention_mask,
         advantages,
         bias=None,
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
         old_per_token_logps=None,
-        beta=0.1,
+        beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
         temperature=1.0,
@@ -91,6 +88,7 @@ def forward(
         Args:
             _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
             attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
             advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
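A quick shape check for the newly documented argument (toy sizes, purely illustrative): the logits are viewed as (batch_size, seq_len, vocab_size) and selected_token_ids indexes the vocabulary dimension.

import torch

batch_size, seq_len, hidden_size, vocab_size = 2, 4, 8, 16  # toy sizes

_input = torch.randn(batch_size * seq_len, hidden_size)      # flattened hidden states
weight = torch.randn(vocab_size, hidden_size)                 # lm_head weight
selected_token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

log_probs = torch.log_softmax(_input @ weight.T, dim=-1).view(batch_size, seq_len, vocab_size)
per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
assert per_token_logps.shape == (batch_size, seq_len)
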
@@ -110,6 +108,7 @@ def forward(
             ctx=ctx,
             _input=_input,
             weight=weight,
+            selected_token_ids=selected_token_ids,
             attention_mask=attention_mask,
             advantages=advantages,
             bias=bias,
@@ -136,7 +135,7 @@ def backward(ctx, grad_output, *grad_metrics):
         """
         grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
         return (
-            *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_advantages, grad_bias
+            *grads[:6],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
             None,  # grad_ref_input
             None,  # grad_ref_weight
             None,  # grad_ref_bias
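The widened slice follows from torch.autograd.Function's contract: backward must return one value per positional argument of forward, in order, with None for non-differentiable inputs such as integer token ids. A toy illustration of that contract (unrelated example, not the actual RLHF base class):

import torch

class ScaleOnly(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, token_ids, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # One slot per forward input (x, token_ids, scale); the integer ids and
        # the Python float carry no gradient, so those slots are None.
        return grad_output * ctx.scale, None, None

x = torch.randn(2, 3, requires_grad=True)
token_ids = torch.randint(0, 10, (2, 3))
ScaleOnly.apply(x, token_ids, 2.0).sum().backward()
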
@@ -156,7 +155,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 
     def __init__(
         self,
-        beta: float = 0.1,
+        beta: float = 0.04,
         compiled: bool = True,
         use_ref_model: bool = True,
         chunk_size: int = 1,
@@ -187,6 +186,7 @@ def forward(
         self,
         _input,
         lin_weight,
+        selected_token_ids,
         attention_mask,
         advantages,
         bias=None,
@@ -198,16 +198,17 @@ def forward(
198198 return LigerFusedLinearGRPOFunction .apply (
199199 _input ,
200200 lin_weight ,
201+ selected_token_ids ,
201202 attention_mask ,
202203 advantages ,
203204 bias ,
204205 ref_input ,
205206 ref_weight ,
206207 ref_bias ,
207208 old_per_token_logps ,
209+ self .beta ,
208210 self .epsilon_low ,
209211 self .epsilon_high ,
210- self .beta ,
211212 self .temperature ,
212213 self .compiled ,
213214 self .use_ref_model ,
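Call sites now pass the generated token ids explicitly, and the positional order into .apply matches the Function.forward signature (beta before the clipping epsilons). A hedged usage sketch with toy shapes; the import path and the (loss, metrics) return structure are assumptions based on the code above, not a confirmed API:

import torch
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss  # assumed import path

batch_size, seq_len, hidden_size, vocab_size = 2, 4, 8, 16
grpo_loss = LigerFusedLinearGRPOLoss(beta=0.04, use_ref_model=False)

hidden_states = torch.randn(batch_size * seq_len, hidden_size, requires_grad=True)
lm_head_weight = torch.randn(vocab_size, hidden_size, requires_grad=True)
selected_token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len)
advantages = torch.randn(batch_size)

loss, metrics = grpo_loss(hidden_states, lm_head_weight, selected_token_ids, attention_mask, advantages)
loss.backward()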