Commit 148efed

add selected token ids

1 parent 9ef2cda commit 148efed

1 file changed (+20, -19)


src/liger_kernel/chunked_loss/grpo_loss.py

Lines changed: 20 additions & 19 deletions
@@ -7,61 +7,57 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
     @staticmethod
     def rlhf_loss_fn(
         log_probs,
+        selected_token_ids,
         attention_mask,
         advantages,
         full_attention_mask,
         ref_log_probs=None,
         old_per_token_logps=None,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        beta=0.1,
+        beta=0.04,
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
-        # Get chosen token probabilities
-        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
-        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+
+        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
             -1
         )  # (batch_size, seq_len)

         # Get reference model probabilities
         if ref_log_probs is not None:
             with torch.no_grad():
-                ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+                ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
         else:
-            ref_token_logprobs = chosen_token_logprobs.detach()
+            ref_per_token_logps = per_token_logps.detach()

         # Compute policy gradient loss with importance sampling ratio
-        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else chosen_token_logprobs.detach()
-        coef_1 = torch.exp(chosen_token_logprobs - old_per_token_logps)
+        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
         coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if beta != 0.0:
             # Compute KL penalty
             kl_div = (
-                torch.exp(ref_token_logprobs - chosen_token_logprobs)
-                - (ref_token_logprobs - chosen_token_logprobs)
-                - 1.0
+                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
             )
             # Combine losses
             per_token_loss = per_token_loss + beta * kl_div
-
         # Apply mask and compute average loss
         loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)

         # Calculate metrics
         full_batch_size, seq_len = full_attention_mask.shape
         vocab_size = log_probs.shape[2]
         metrics = [
-            chosen_token_logprobs.sum() / (full_batch_size * seq_len),  # mean log prob
+            per_token_logps.sum() / (full_batch_size * seq_len),  # mean log prob
             log_probs.sum() / (full_batch_size * seq_len * vocab_size),  # mean all log probs
         ]
         if beta != 0.0:
             metrics.append(
-                ((kl_div * attention_mask).sum(dim=1) / torch.clamp(attention_mask.sum(dim=1), min=1.0)).sum()
-                / full_batch_size
+                ((kl_div * attention_mask).sum(dim=1) / torch.clamp(attention_mask.sum(dim=1), min=1.0)).sum() / full_batch_size
             )
         return loss, metrics
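Note on what this hunk changes: the loss now gathers the log-probability of each token the policy actually sampled (selected_token_ids) instead of the argmax token, and the default KL weight beta drops from 0.1 to 0.04. Below is a minimal standalone sketch of the same math on toy tensors; the shapes, the random inputs, and taking the old policy equal to the current one are illustrative assumptions and do not use the Liger fused kernel.

import torch

# Toy shapes; inside the kernel, log_probs comes from the fused linear + log_softmax.
B, T, V = 2, 4, 10
log_probs = torch.log_softmax(torch.randn(B, T, V), dim=-1)
ref_log_probs = torch.log_softmax(torch.randn(B, T, V), dim=-1)
selected_token_ids = torch.randint(V, (B, T))  # tokens actually generated during rollout
advantages = torch.randn(B)
beta, eps_low, eps_high = 0.04, 0.2, 0.2

# New behaviour: gather the log-prob of the sampled token ...
per_token_logps = log_probs.gather(-1, selected_token_ids.unsqueeze(-1)).squeeze(-1)
# ... rather than the argmax token, which need not be the token that was sampled.
ref_per_token_logps = ref_log_probs.gather(-1, selected_token_ids.unsqueeze(-1)).squeeze(-1)

# Clipped importance-sampling surrogate (old policy taken equal to the current one here).
coef_1 = torch.exp(per_token_logps - per_token_logps.detach())
coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
per_token_loss = -torch.min(coef_1 * advantages.unsqueeze(1), coef_2 * advantages.unsqueeze(1))

# k3 KL estimator, exp(r) - r - 1 with r = ref_logp - logp, weighted by beta.
kl_div = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
per_token_loss = per_token_loss + beta * kl_div  # (B, T)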

@@ -71,14 +67,15 @@ def forward(
         ctx,
         _input,
         weight,
+        selected_token_ids,
         attention_mask,
         advantages,
         bias=None,
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
         old_per_token_logps=None,
-        beta=0.1,
+        beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
         temperature=1.0,
@@ -91,6 +88,7 @@ def forward(
         Args:
             _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
             attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
             advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
@@ -110,6 +108,7 @@ def forward(
             ctx=ctx,
             _input=_input,
             weight=weight,
+            selected_token_ids=selected_token_ids,
             attention_mask=attention_mask,
             advantages=advantages,
             bias=bias,
@@ -136,7 +135,7 @@ def backward(ctx, grad_output, *grad_metrics):
         """
         grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
         return (
-            *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_advantages, grad_bias
+            *grads[:6],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
             None,  # grad_ref_input
             None,  # grad_ref_weight
             None,  # grad_ref_bias
@@ -156,7 +155,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):

     def __init__(
         self,
-        beta: float = 0.1,
+        beta: float = 0.04,
         compiled: bool = True,
         use_ref_model: bool = True,
         chunk_size: int = 1,
@@ -187,6 +186,7 @@ def forward(
         self,
         _input,
         lin_weight,
+        selected_token_ids,
         attention_mask,
         advantages,
         bias=None,
@@ -198,16 +198,17 @@ def forward(
         return LigerFusedLinearGRPOFunction.apply(
             _input,
             lin_weight,
+            selected_token_ids,
             attention_mask,
             advantages,
             bias,
             ref_input,
             ref_weight,
             ref_bias,
             old_per_token_logps,
+            self.beta,
             self.epsilon_low,
             self.epsilon_high,
-            self.beta,
             self.temperature,
             self.compiled,
             self.use_ref_model,
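The arguments to LigerFusedLinearGRPOFunction.apply are passed positionally, and before this change self.beta came after self.epsilon_high even though the forward signature above declares beta before epsilon_low; moving it up realigns the call with the signature. Below is a hedged usage sketch of the module after this commit; the sizes are illustrative, and it assumes the loss can be evaluated without a reference model when use_ref_model=False, which is not verified against the repository.

import torch
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

B, T, H, V = 2, 8, 64, 128  # illustrative sizes
hidden = torch.randn(B * T, H, requires_grad=True)      # _input
lm_head_weight = torch.randn(V, H, requires_grad=True)  # lin_weight
selected_token_ids = torch.randint(V, (B, T))           # sampled completion token ids
attention_mask = torch.ones(B, T)
advantages = torch.randn(B)

grpo_loss = LigerFusedLinearGRPOLoss(beta=0.04, use_ref_model=False)
loss, metrics = grpo_loss(hidden, lm_head_weight, selected_token_ids, attention_mask, advantages)
loss.backward()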
