Commit 10078c2

add selected token ids and functional tests

2 parents: c061ed7 + 8731c54

3 files changed: +148 −39 lines

src/liger_kernel/chunked_loss/functional.py

Lines changed: 2 additions & 0 deletions
@@ -4,10 +4,12 @@
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
+from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction

 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
+liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
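
With this alias registered, GRPO can be called through the same functional entry point as the other chunked losses. A minimal sketch (toy shapes; the positional argument order follows the functional test added below, and the trailing 1 is assumed to be the chunk size):

    import torch
    from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo

    B, T, H, V = 2, 8, 16, 32  # hypothetical toy sizes
    x = torch.randn(B, T, H, requires_grad=True)
    weight = torch.randn(V, H, requires_grad=True)
    selected_token_ids = torch.randint(0, V, (B, T))
    attention_mask = torch.ones(B, T)
    advantages = torch.rand(B)

    # bias, ref_input, ref_weight, ref_bias, old_per_token_logps left as None;
    # beta=0.0 so no reference model or KL term is needed.
    loss, metrics = liger_fused_linear_grpo(
        x, weight, selected_token_ids, attention_mask, advantages,
        None, None, None, None, None,
        0.0, 0.2, 0.2, 1.0, True, False, 1,
    )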

src/liger_kernel/chunked_loss/fused_linear_rlhf.py

Lines changed: 22 additions & 15 deletions
@@ -20,6 +20,7 @@ def forward(
         ctx,
         _input,
         weight,
+        selected_token_ids,
         attention_mask,
         advantages,
         bias=None,
@@ -29,7 +30,7 @@ def forward(
         old_per_token_logps=None,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        beta=0.1,
+        beta=0.04,
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -42,6 +43,7 @@ def forward(
            ctx: Context for backward
            _input: Input tensor
            weight: Weight tensor
+           selected_token_ids: Selected token ids tensor
            attention_mask: Attention mask tensor
            advantages: Advantages tensor
            bias: Bias tensor
@@ -78,22 +80,23 @@ def forward(
             rlhf_loss_fn=cls.rlhf_loss_fn,
         )

-        def fused_fwd_bwd(input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk):
+        def fused_fwd_bwd(input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk):
             """Fused forward and backward for a chunk."""
-            argnums = (0, 1, 4) if bias is not None else (0, 1)
+            argnums = (0, 1, 5) if bias is not None else (0, 1)
             return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
                 input_chunk,  # arg 0
                 weight,  # arg 1
-                attention_mask_chunk,  # arg 2
-                advantages_chunk,  # arg 3
-                bias,  # arg 4
-                ref_input_chunk=ref_input_chunk,  # arg 5
-                old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 6
+                selected_token_ids_chunk,  # arg 2
+                attention_mask_chunk,  # arg 3
+                advantages_chunk,  # arg 4
+                bias,  # arg 5
+                ref_input_chunk=ref_input_chunk,  # arg 6
+                old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
             )

-        def accumulate_chunk(input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk=None, old_per_token_logps_chunk=None):
+        def accumulate_chunk(input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk=None, old_per_token_logps_chunk=None):
             (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
-                input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk
+                input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk
             )
             if bias is not None:
                 grad_bias.add_(chunk_grad_bias[0])
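
Why the argnums shift: torch.func.grad_and_value differentiates with respect to positional indices, so inserting selected_token_ids_chunk at position 2 pushes bias from index 4 to index 5, while (0, 1) for input and weight are unchanged. A toy illustration of the mechanics (hypothetical function, not the kernel):

    import torch

    def f(x, w, ids, b):
        # ids is integer metadata (like selected_token_ids): used via gather, never differentiated
        loss = (x @ w + b).gather(1, ids.unsqueeze(1)).sum()
        return loss, loss.detach()  # (output, aux), as required by has_aux=True

    x, w, b = torch.randn(3, 4), torch.randn(4, 5), torch.randn(5)
    ids = torch.randint(0, 5, (3,))
    # Differentiate w.r.t. x (0), w (1), and b (3); ids (2) is skipped.
    (gx, gw, gb), (loss, aux) = torch.func.grad_and_value(f, argnums=(0, 1, 3), has_aux=True)(x, w, ids, b)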
@@ -102,7 +105,6 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, advantages_chunk, ref_in
             grad_weight.add_(chunk_grad_weight)
             grad_inputs.append(chunk_grad_input)
             loss_acc.add_(chunk_loss)
-
             # Initialize storage for metrics on first chunk
             if len(aggregated_metrics) == 0:
                 for metric in chunk_metrics:
@@ -126,16 +128,18 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, advantages_chunk, ref_in
         # Process input in chunks based on chunk_size
         chunks = max(1, _input.shape[0] // chunk_size)
         _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+        _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
         _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
         _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
         _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
         _old_per_token_logps_chunks = torch.chunk(old_per_token_logps, chunks=chunks, dim=0) if old_per_token_logps is not None else [None] * chunks

-        for input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk in zip(
-            _input_chunks, _attention_mask_chunks, _advantages_chunks, _ref_input_chunks, _old_per_token_logps_chunks
+        for input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk in zip(
+            _input_chunks, _selected_token_ids_chunks, _attention_mask_chunks, _advantages_chunks, _ref_input_chunks, _old_per_token_logps_chunks
         ):
             # Mark dynamic dimensions
             torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
             torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
             if use_ref_model:
                 torch._dynamo.mark_dynamic(ref_input_chunk, 1)
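
Chunking note: because every per-sample tensor is split with the same chunks count along dim=0, chunk i of selected_token_ids stays row-aligned with chunk i of the hidden states, even when torch.chunk produces unequal piece sizes. For example:

    import torch

    hidden = torch.randn(5, 7)
    ids = torch.arange(5)
    for h_chunk, id_chunk in zip(torch.chunk(hidden, chunks=2, dim=0), torch.chunk(ids, chunks=2, dim=0)):
        assert h_chunk.shape[0] == id_chunk.shape[0]  # sizes 3 and 2, aligned pairwise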
@@ -146,7 +150,7 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, advantages_chunk, ref_in
             else:
                 old_per_token_logps_chunk = None

-            accumulate_chunk(input_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk)
+            accumulate_chunk(input_chunk, selected_token_ids_chunk, attention_mask_chunk, advantages_chunk, ref_input_chunk, old_per_token_logps_chunk)

         # Combine gradients
         grad_input = torch.cat(grad_inputs, dim=0)
@@ -168,6 +172,7 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, advantages_chunk, ref_in
     def _compute_chunk_loss(
         input_chunk,
         weight,
+        selected_token_ids_chunk,
         attention_mask_chunk,
         advantages_chunk,
         bias=None,
@@ -178,7 +183,7 @@ def _compute_chunk_loss(
         full_attention_mask=None,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        beta=0.1,
+        beta=0.04,
         temperature=1.0,
         use_ref_model=False,
         rlhf_loss_fn=None,
@@ -196,6 +201,7 @@ def _compute_chunk_loss(
         # Compute chunk loss and metrics using the provided loss function
         chunk_loss, chunk_metrics = rlhf_loss_fn(
             log_probs=log_probs,
+            selected_token_ids=selected_token_ids_chunk,
             attention_mask=attention_mask_chunk,
             advantages=advantages_chunk,
             full_attention_mask=full_attention_mask,
@@ -236,6 +242,7 @@ def backward(ctx, grad_output, *grad_metrics):
         return (
             grad_input,
             grad_weight,
+            None,  # grad_selected_token_ids
             None,  # grad_attention_mask
             None,  # grad_advantages
             grad_bias,
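
Backward contract: autograd requires backward to return exactly one value per forward input, in order, with None for non-differentiable arguments, which is why the integer selected_token_ids gets its own None slot here. A minimal sketch of the pattern (hypothetical Function, not the kernel):

    import torch

    class ScaleByConst(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, token_ids, scale):
            ctx.scale = scale
            return x * scale  # token_ids is metadata; no gradient flows through it

        @staticmethod
        def backward(ctx, grad_out):
            # One slot per forward input: x, token_ids, scale
            return grad_out * ctx.scale, None, None

    x = torch.randn(4, requires_grad=True)
    y = ScaleByConst.apply(x, torch.randint(0, 10, (4,)), 2.0)
    y.sum().backward()  # x.grad == 2.0 everywhere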

test/chunked_loss/test_grpo_loss.py

Lines changed: 124 additions & 24 deletions
@@ -2,7 +2,9 @@
 import torch
 import torch.nn.functional as F

+from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
+from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo
 from liger_kernel.utils import infer_device
 from test.utils import assert_verbose_allclose
 from test.utils import set_seed
@@ -40,6 +42,7 @@ def __init__(
     def forward(
         self,
         x,  # Shape: [batch_size, seq_len, hidden_size]
+        selected_token_ids,  # Shape: [batch_size, seq_len]
         attention_mask,  # Shape: [batch_size, seq_len]
         advantages,  # Shape: [batch_size,]
         ref_input=None,  # Shape: [batch_size, seq_len, hidden_size]
@@ -54,8 +57,7 @@ def forward(
         log_probs = F.log_softmax(logits, dim=-1)

         # Get chosen token probabilities
-        chosen_tokens = log_probs.argmax(dim=-1)
-        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)

         # Get reference model probabilities
         if self.use_ref_model:
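
The substantive fix in this hunk: the old reference implementation scored the greedy (argmax) token, while GRPO must score the token that was actually sampled during rollout, hence gathering log-probs at selected_token_ids. Shape-wise:

    import torch
    import torch.nn.functional as F

    log_probs = F.log_softmax(torch.randn(2, 3, 5), dim=-1)   # [B=2, T=3, V=5]
    selected_token_ids = torch.randint(0, 5, (2, 3))          # [B, T]
    # index must be [B, T, 1] for gather along the vocab dim; squeeze restores [B, T]
    per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
    assert per_token_logps.shape == (2, 3)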
@@ -66,22 +68,21 @@ def forward(
             if self.temperature != 1.0:
                 ref_logits = ref_logits / self.temperature
             ref_log_probs = F.log_softmax(ref_logits, dim=-1)
-            ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+            ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
         else:
-            ref_token_logprobs = chosen_token_logprobs.detach()
-
+            ref_per_token_logps = per_token_logps.detach()

         # Compute policy gradient loss with importance sampling ratio
-        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else chosen_token_logprobs.detach()
-        coef_1 = torch.exp(chosen_token_logprobs - old_per_token_logps)
+        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
         coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if self.beta != 0.0:
             # Compute KL divergence between model and reference model
             kl_div = (
-                torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
+                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
             )
             per_token_loss = per_token_loss + self.beta * kl_div

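For reference, with importance ratio r_t = exp(per_token_logps - old_per_token_logps) and advantage A, the per-token objective computed above is the PPO-style clipped surrogate plus a KL penalty:

    loss_t = -min(r_t * A, clip(r_t, 1 - epsilon_low, 1 + epsilon_high) * A)
             + beta * (exp(ref_t - logp_t) - (ref_t - logp_t) - 1)

where logp_t and ref_t are the policy and reference per-token log-probs. The beta term is the so-called k3 KL estimator: it is non-negative for any pair of log-probs (since e^x - x - 1 >= 0) and is an unbiased estimate of KL(policy || reference).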
@@ -90,7 +91,7 @@ def forward(

         # Compute metrics
         metrics = [
-            chosen_token_logprobs.mean(),
+            per_token_logps.mean(),
             log_probs.mean(),
         ]
         if self.beta != 0.0:
@@ -118,16 +119,18 @@ def __init__(
         super().__init__()
         self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
         self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
-        self.grpo_loss = LigerFusedLinearGRPOFunction.apply
-        self.beta = beta
-        self.epsilon_low = epsilon_low
-        self.epsilon_high = epsilon_high
-        self.temperature = temperature
-        self.use_ref_model = use_ref_model
+        self.grpo_loss = LigerFusedLinearGRPOLoss(
+            beta=beta,
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            temperature=temperature,
+            use_ref_model=use_ref_model,
+        )

     def forward(
         self,
         x,
+        selected_token_ids,
         attention_mask,
         advantages,
         ref_input=None,
@@ -137,19 +140,14 @@ def forward(
         return self.grpo_loss(
             x,  # _input
             self.lin.weight,  # weight
+            selected_token_ids,  # selected_token_ids
             attention_mask,  # attention_mask
             advantages,  # advantages
             self.lin.bias,  # bias
             ref_input,  # ref_input
             self.ref_lin.weight,  # ref_weight
             self.ref_lin.bias,  # ref_bias
             old_per_token_logps,  # old_per_token_logps
-            self.beta,  # beta
-            self.epsilon_low,  # epsilon_low
-            self.epsilon_high,  # epsilon_high
-            self.temperature,  # temperature
-            True,  # compiled
-            self.use_ref_model,  # use_ref_model
         )
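
With hyperparameters bound at construction, the module-style wrapper only takes tensors at call time. A sketch of the pattern (toy shapes; optional reference arguments omitted on the assumption they default to None when use_ref_model=False):

    import torch
    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

    B, T, H, V = 2, 8, 16, 32  # hypothetical toy sizes
    loss_fn = LigerFusedLinearGRPOLoss(beta=0.0, epsilon_low=0.2, epsilon_high=0.2, temperature=1.0, use_ref_model=False)
    lin = torch.nn.Linear(H, V)
    x = torch.randn(B, T, H, requires_grad=True)
    loss, metrics = loss_fn(
        x,                            # _input
        lin.weight,                   # weight
        torch.randint(0, V, (B, T)),  # selected_token_ids
        torch.ones(B, T),             # attention_mask
        torch.rand(B),                # advantages
        lin.bias,                     # bias
    )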

@@ -173,7 +171,7 @@ def forward(
     "beta, epsilon_low, epsilon_high, temperature",
     [
         # Standard settings
-        (0.1, 0.2, 0.2, 1.0),
+        (0.1, 0.2, 0.2, 20.0),  # set temperature to 20.0 for better numerical stability
         (0.0, 0.1, 0.1, 2.0),
     ]
 )
@@ -240,6 +238,9 @@ def test_correctness(
     input1 = _input.detach().clone().requires_grad_(True)
     input2 = _input.detach().clone().requires_grad_(True)

+    # Create selected token ids with shape [B, T]
+    selected_token_ids = torch.randint(0, V, (B, T), device=device)
+
     # Create attention mask with random padding [B, T]
     attention_mask = torch.ones(B, T, device=device)
     num_elements_to_mask = torch.randint(1, B * T // 2, (1,)).item()
@@ -259,13 +260,15 @@ def test_correctness(

     # Forward pass with reference model
     loss1, aux1 = torch_lm_head_grpo(
-        input1, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
+        input1, selected_token_ids, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
     )
     loss2, aux2 = liger_lm_head_grpo(
-        input2, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
+        input2, selected_token_ids, attention_mask, advantages, ref_input=ref_input, old_per_token_logps=old_per_token_logps
     )

     # Check losses match
+    assert not torch.isnan(loss1).any()  # isnan is the correct check; `x != float('nan')` is always True
+    assert not torch.isnan(loss2).any()
     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)

     # Check metrics match
@@ -292,3 +295,100 @@ def test_correctness(
         atol=atol,
         rtol=rtol,
     )
+
+@pytest.mark.parametrize(
+    "B, T, H, V",
+    [
+        (8, 128, 1024, 4096),
+        (3, 47, 31, 123),  # random shape
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        (1.0, torch.bfloat16, 5e-2, 5e-2),
+        (1.0, torch.float32, 1e-4, 5e-3),
+    ],
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("ref_bias", [True, False])
+@pytest.mark.parametrize(
+    "beta, epsilon_low, epsilon_high, temperature",
+    [
+        # Standard settings
+        (0.1, 0.2, 0.2, 20.0),  # set temperature to 20.0 for better numerical stability
+        (0.0, 0.1, 0.1, 2.0),
+    ]
+)
+@pytest.mark.parametrize("use_ref_model", [True, False])
+@pytest.mark.parametrize("old_per_token_logps", [True, False])
+def test_functional_correctness(
+    B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, beta, epsilon_low, epsilon_high, temperature, use_ref_model, old_per_token_logps
+):
+    _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
+    input1 = _input.detach().clone().requires_grad_(True)
+    input2 = _input.detach().clone().requires_grad_(True)
+
+    _weight = torch.randn(V, H, device=device, dtype=dtype) * scalar
+    weight1 = _weight.detach().clone().requires_grad_(True)
+    weight2 = _weight.detach().clone().requires_grad_(True)
+
+    selected_token_ids = torch.randint(0, V, (B, T), device=device)
+
+    attention_mask = torch.ones(B, T, device=device)
+
+    advantages = torch.rand(B, device=device, dtype=dtype)
+
+    if bias:
+        _bias = torch.randn(V, device=device, dtype=dtype) * scalar
+        bias1 = _bias.detach().clone().requires_grad_(True)
+        bias2 = _bias.detach().clone().requires_grad_(True)
+    else:
+        bias1 = None
+        bias2 = None
+
+    ref_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
+
+    _ref_weight = torch.randn(V, H, device=device, dtype=dtype) * scalar
+    ref_weight1 = _ref_weight.detach().clone().requires_grad_(True)
+    ref_weight2 = _ref_weight.detach().clone().requires_grad_(True)
+
+    if ref_bias:
+        _ref_bias = torch.randn(V, device=device, dtype=dtype) * scalar
+        ref_bias1 = _ref_bias.detach().clone().requires_grad_(True)
+        ref_bias2 = _ref_bias.detach().clone().requires_grad_(True)
+    else:
+        ref_bias1 = None
+        ref_bias2 = None
+
+    if old_per_token_logps:
+        old_per_token_logps = torch.randn(B, T, device=device, dtype=dtype) * scalar
+    else:
+        old_per_token_logps = None
+
+    loss1, aux1 = liger_fused_linear_grpo(
+        input1, weight1, selected_token_ids, attention_mask, advantages, bias1, ref_input, ref_weight1, ref_bias1, old_per_token_logps, beta, epsilon_low, epsilon_high, temperature, True, use_ref_model, 1
+    )
+
+    loss2, aux2 = LigerFusedLinearGRPOFunction.apply(
+        input2, weight2, selected_token_ids, attention_mask, advantages, bias2, ref_input, ref_weight2, ref_bias2, old_per_token_logps, beta, epsilon_low, epsilon_high, temperature, True, use_ref_model, 1
+    )
+
+    assert not torch.isnan(loss1).any()  # isnan is the correct check; `x != float('nan')` is always True
+    assert not torch.isnan(loss2).any()
+    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
+
+    # Check metrics match
+    assert len(aux1) == len(aux2)
+    for metric1, metric2 in zip(aux1, aux2):
+        assert_verbose_allclose(metric1, metric2, atol=atol, rtol=rtol)
+
+    # Backward pass
+    loss1.backward()
+    loss2.backward()
+
+    # Check gradients match
+    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
+    assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
+    if bias:
+        assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
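
To run just the new functional test in isolation:

    python -m pytest test/chunked_loss/test_grpo_loss.py -k test_functional_correctness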
