
Commit cb8e408: Add vLLM importance sampling ratio support for GRPO loss (#1088)
## Summary

Fixes the **primary cause** (item 1) of #1082 — `LigerFusedLinearGRPOLoss` produces ~100x larger `grad_norm` than TRL's non-Liger path when using vLLM.

**Root cause:** TRL's `GRPOTrainer` applies `per_token_loss *= importance_sampling_ratio` ([source](https://github.com/huggingface/trl/blob/v0.27.2/trl/trainer/grpo_trainer.py#L2351-L2352)) to correct for the distribution mismatch introduced by vLLM's rejection/stratified sampling. Liger-Kernel had no mechanism to accept or apply this correction, so the IS ratio was silently ignored, resulting in uncorrected (and much larger) gradients.

**This is a high-priority fix** — any user running `GRPOTrainer` with `use_vllm=True` and `use_liger_kernel=True` is affected, and the resulting ~100x gradient mismatch can cause training instability or divergence.

### Changes

- Add an optional `vllm_is_ratio` parameter (`[B, T]` tensor or `None`) to both code paths:
  - **Chunked loss path**: `LigerFusedLinearGRPOLoss`, `LigerFusedLinearGRPOFunction`, `ppo_loss_fn`, and the base class `LigerFusedLinearPPOBase` chunking pipeline
  - **Triton kernel path**: `triton_grpo_loss`, `GrpoLossFunction`, and the Triton fwd/bwd kernels (`_grpo_loss_fwd_kernel`, `_grpo_loss_bwd_kernel`)
- The IS correction is applied **after** the PPO clipped loss computation and **before** the KL penalty, matching TRL's behavior exactly
- `vllm_is_ratio=None` (default) preserves existing behavior — no breaking changes
- Works with all loss types: `grpo`, `dapo`, `bnpo`, `dr_grpo`, `cispo`, `sapo`

### Verification

With `IS_RATIO=0.01`, the `grad_norm` ratio matches exactly:

```
Chunked loss path:
  grad_norm WITHOUT vllm_is_ratio: 1.052219e-01
  grad_norm WITH    vllm_is_ratio: 1.052219e-03
  ratio: 0.010000 ✓

Triton path:
  grad_norm WITHOUT vllm_is_ratio: 1.461673e-02
  grad_norm WITH    vllm_is_ratio: 1.461673e-04
  ratio: 0.010000 ✓
```

## Test plan

- [x] Extended the existing `test_correctness` in `test/chunked_loss/test_grpo_loss.py` with a `use_vllm_is_ratio` parametrize — covers all 6 loss types × 2 IS levels × 2 beta values × with/without `vllm_is_ratio`
- [x] Added `test_grpo_loss_with_vllm_is_ratio` in `test/transformers/test_grpo_loss.py` — compares Triton output against a PyTorch reference with IS correction, plus a `vllm_is_ratio=None` == `vllm_is_ratio=ones` identity check
- [x] All existing tests continue to pass (no regressions)
- [x] `make checkstyle` passes

## Related

- Reference implementation: #993
- Issue: #1082
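For orientation, here is a minimal PyTorch sketch of where the correction lands relative to the clipped PPO objective and the KL penalty. It is illustrative only, simplified from the `ppo_loss_fn` diff below (masking, loss-type variants, and metrics are omitted; the `grpo_per_token_loss` name is hypothetical):

```python
import torch

def grpo_per_token_loss(per_token_logps, old_per_token_logps, ref_per_token_logps,
                        advantages, eps_low=0.2, eps_high=0.2, beta=0.04,
                        vllm_is_ratio=None):
    # PPO-style clipped objective over per-token log-prob ratios
    coef_1 = torch.exp(per_token_logps - old_per_token_logps)
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
    per_token_loss1 = coef_1 * advantages.unsqueeze(1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(1)
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

    # vLLM IS correction: after clipping, before the KL penalty,
    # mirroring TRL's `per_token_loss *= importance_sampling_ratio`
    if vllm_is_ratio is not None:
        per_token_loss = per_token_loss * vllm_is_ratio

    if beta != 0.0:
        # k3 estimator of KL[per_token_logps || ref_per_token_logps]
        delta = ref_per_token_logps - per_token_logps
        per_token_loss = per_token_loss + beta * (torch.exp(delta) - delta - 1)
    return per_token_loss
```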
1 parent cc14537 · commit cb8e408

File tree: 6 files changed (+436, -3 lines)

src/liger_kernel/chunked_loss/fused_linear_ppo.py

Lines changed: 32 additions & 0 deletions
@@ -41,6 +41,7 @@ def forward(
         chunk_size=1,
         sapo_temperature_pos=1.0,
         sapo_temperature_neg=1.05,
+        vllm_is_ratio=None,
     ):
         # TODO: check torch compile matmul
         """Chunked forward pass for PPO loss computation.
@@ -71,6 +72,8 @@ def forward(
             chunk_size: Size of chunks for processing in other loss modules
             sapo_temperature_pos: Temperature for positive advantages in SAPO
             sapo_temperature_neg: Temperature for negative advantages in SAPO
+            vllm_is_ratio: vLLM importance sampling ratio tensor (batch_size, seq_len) or (batch_size, 1) or None.
+                Used to correct for distribution mismatch when using vLLM for generation.
         """
         if use_ref_model:
             assert ref_per_token_logps is not None or ref_input is not None, (
@@ -80,6 +83,20 @@ def forward(
             raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
         if loss_type == "dr_grpo":
             assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
+        if vllm_is_ratio is not None:
+            B, T = attention_mask.shape
+            assert vllm_is_ratio.dim() in (1, 2), (
+                f"vllm_is_ratio must be 1D (B,) or 2D (B, T) / (B, 1), got {vllm_is_ratio.dim()}D"
+            )
+            if vllm_is_ratio.dim() == 2:
+                assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, T), (
+                    f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {T}), got {tuple(vllm_is_ratio.shape)}"
+                )
+            else:
+                assert vllm_is_ratio.shape[0] == B, (
+                    f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}"
+                )
+                vllm_is_ratio = vllm_is_ratio.unsqueeze(-1)  # (B,) -> (B, 1) for broadcasting
         # Initialize accumulators
         loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
         grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -114,6 +131,7 @@ def fused_fwd_bwd(
             ref_per_token_logps_chunk,
             old_per_token_logps_chunk,
             ref_input_chunk,
+            vllm_is_ratio_chunk,
         ):
             """Fused forward and backward for a chunk."""
             argnums = (0, 1, 5) if bias is not None else (0, 1)
@@ -127,6 +145,7 @@ def fused_fwd_bwd(
                 ref_per_token_logps_chunk=ref_per_token_logps_chunk,  # arg 6
                 old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
                 ref_input_chunk=ref_input_chunk,  # arg 8
+                vllm_is_ratio_chunk=vllm_is_ratio_chunk,  # arg 9
             )

         def accumulate_chunk(
@@ -137,6 +156,7 @@ def accumulate_chunk(
             ref_per_token_logps_chunk=None,
             old_per_token_logps_chunk=None,
             ref_input_chunk=None,
+            vllm_is_ratio_chunk=None,
         ):
             (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
                 input_chunk,
@@ -146,6 +166,7 @@ def accumulate_chunk(
                 ref_per_token_logps_chunk,
                 old_per_token_logps_chunk,
                 ref_input_chunk,
+                vllm_is_ratio_chunk,
             )
             if bias is not None:
                 grad_bias.add_(chunk_grad_bias[0])
@@ -196,6 +217,9 @@ def accumulate_chunk(
             if use_ref_model and ref_per_token_logps is None
             else [None] * chunks
         )
+        _vllm_is_ratio_chunks = (
+            torch.chunk(vllm_is_ratio, chunks=chunks, dim=0) if vllm_is_ratio is not None else [None] * chunks
+        )

         for (
             input_chunk,
@@ -205,6 +229,7 @@ def accumulate_chunk(
             ref_per_token_logps_chunk,
             old_per_token_logps_chunk,
             ref_input_chunk,
+            vllm_is_ratio_chunk,
         ) in zip(
             _input_chunks,
             _selected_token_ids_chunks,
@@ -213,6 +238,7 @@ def accumulate_chunk(
             _ref_per_token_logps_chunks,
             _old_per_token_logps_chunks,
             _ref_input_chunks,
+            _vllm_is_ratio_chunks,
         ):
             # Mark dynamic dimensions
             torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -224,6 +250,8 @@ def accumulate_chunk(
                 torch._dynamo.mark_dynamic(ref_input_chunk, 1)
             if old_per_token_logps_chunk is not None:
                 torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
+            if vllm_is_ratio_chunk is not None:
+                torch._dynamo.mark_dynamic(vllm_is_ratio_chunk, 1)

             accumulate_chunk(
                 input_chunk,
@@ -233,6 +261,7 @@ def accumulate_chunk(
                 ref_per_token_logps_chunk,
                 old_per_token_logps_chunk,
                 ref_input_chunk,
+                vllm_is_ratio_chunk,
             )

         # Combine gradients
@@ -277,6 +306,7 @@ def _compute_chunk_loss(
         ref_per_token_logps_chunk=None,
         old_per_token_logps_chunk=None,
         ref_input_chunk=None,
+        vllm_is_ratio_chunk=None,
         ref_weight=None,
         ref_bias=None,
         full_attention_mask=None,
@@ -322,6 +352,7 @@ def _compute_chunk_loss(
             importance_sampling_level=importance_sampling_level,
             sapo_temperature_pos=sapo_temperature_pos,
             sapo_temperature_neg=sapo_temperature_neg,
+            vllm_is_ratio=vllm_is_ratio_chunk,
         )

         return chunk_loss, chunk_metrics
@@ -376,4 +407,5 @@ def backward(ctx, grad_output, *grad_metrics):
             None,  # grad_chunk_size
             None,  # grad_sapo_temperature_pos
             None,  # grad_sapo_temperature_neg
+            None,  # grad_vllm_is_ratio
         )
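The shape handling above reduces to ordinary broadcasting: a 1D `(B,)` ratio is unsqueezed to `(B, 1)` so it broadcasts across the sequence dimension, and chunking along `dim=0` keeps each ratio row aligned with its input chunk. A standalone sketch (shapes are made up for illustration):

```python
import torch

B, T, chunks = 4, 6, 2
per_token_loss = torch.randn(B, T)

# (B,) per-sequence ratio -> (B, 1): broadcasts over all T tokens of each row,
# which is what the unsqueeze(-1) in the hunk above achieves.
vllm_is_ratio = torch.rand(B).unsqueeze(-1)
corrected = per_token_loss * vllm_is_ratio  # still (B, T)

# torch.chunk along dim=0 keeps ratio rows paired with their input rows.
for loss_chunk, ratio_chunk in zip(
    torch.chunk(corrected, chunks, dim=0),
    torch.chunk(vllm_is_ratio, chunks, dim=0),
):
    assert loss_chunk.shape[0] == ratio_chunk.shape[0]  # rows stay aligned
```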

src/liger_kernel/chunked_loss/grpo_loss.py

Lines changed: 12 additions & 0 deletions
@@ -75,6 +75,7 @@ def ppo_loss_fn(
     importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
     sapo_temperature_pos=1.0,  # Temperature for positive advantages in SAPO
     sapo_temperature_neg=1.05,  # Temperature for negative advantages in SAPO
+    vllm_is_ratio=None,  # vLLM importance sampling ratio (chunk_size, seq_len) or (chunk_size, 1) or None
     **kwargs,
 ):
     """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -138,6 +139,10 @@ def ppo_loss_fn(
     per_token_loss2 = coef_2 * advantages.unsqueeze(1)
     per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

+    # Apply vLLM importance sampling correction BEFORE adding KL penalty
+    if vllm_is_ratio is not None:
+        per_token_loss = per_token_loss * vllm_is_ratio
+
     if beta != 0.0:
         # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
         kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
@@ -214,6 +219,7 @@ def forward(
         compiled=True,
         use_ref_model=True,
         chunk_size=1,
+        vllm_is_ratio=None,
     ):
         """
         Fused linear layer with GRPO loss.
@@ -239,6 +245,8 @@ def forward(
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
             chunk_size (int): Size of chunks for processing.
+            vllm_is_ratio (torch.Tensor, optional): vLLM importance sampling ratio (batch_size, seq_len) or (batch_size, 1) or None.
+                Used to correct for distribution mismatch when using vLLM for generation.
         Returns:
             torch.Tensor: Computed loss
         """
@@ -268,6 +276,7 @@ def forward(
             importance_sampling_level=importance_sampling_level,
             sapo_temperature_pos=sapo_temperature_pos,
             sapo_temperature_neg=sapo_temperature_neg,
+            vllm_is_ratio=vllm_is_ratio,
         )

     @staticmethod
@@ -300,6 +309,7 @@ def backward(ctx, grad_output, *grad_metrics):
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_chunk_size
+            None,  # grad_vllm_is_ratio
         )


@@ -370,6 +380,7 @@ def forward(
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
+        vllm_is_ratio=None,
     ):
         return LigerFusedLinearGRPOFunction.apply(
             _input,
@@ -395,4 +406,5 @@ def forward(
             self.compiled,
             self.use_ref_model,
             self.chunk_size,
+            vllm_is_ratio,
         )
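Putting the chunked path together, a hypothetical end-to-end call might look like the sketch below. The constructor arguments and the positional order (`_input`, `lin_weight`, `selected_token_ids`, `attention_mask`, `advantages`) are assumptions inferred from the surrounding code, not a verified API; shapes are made up for illustration:

```python
import torch
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

B, T, H, V = 2, 8, 32, 64
# beta=0.0 and use_ref_model=False (assumed ctor args) avoid needing a reference model.
loss_fn = LigerFusedLinearGRPOLoss(beta=0.0, use_ref_model=False)

_input = torch.randn(B, T, H, requires_grad=True)   # last hidden states
lin_weight = torch.randn(V, H, requires_grad=True)  # lm_head weight
selected_token_ids = torch.randint(0, V, (B, T))
attention_mask = torch.ones(B, T, dtype=torch.long)
advantages = torch.randn(B)
vllm_is_ratio = torch.rand(B, T)                    # or (B, 1), or None

out = loss_fn(
    _input, lin_weight, selected_token_ids, attention_mask, advantages,
    vllm_is_ratio=vllm_is_ratio,
)
loss = out[0] if isinstance(out, tuple) else out    # (loss, metrics) return assumed
loss.backward()
```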

src/liger_kernel/ops/grpo_loss.py

Lines changed: 51 additions & 2 deletions
@@ -90,6 +90,8 @@ def _grpo_loss_fwd_kernel(
     INPUT_IDS,
     COMPLETION_MASK,
     ADVANTAGES,
+    VLLM_IS_RATIO,
+    VLLM_IS_RATIO_STRIDE,
     LOSS,
     LSE,
     KL,
@@ -169,6 +171,14 @@ def _grpo_loss_fwd_kernel(
         per_token_loss = -sapo_coef * advantage
         is_clipped = 0.0  # SAPO has no clipping concept

+    # Apply vLLM importance sampling correction BEFORE adding KL penalty
+    if VLLM_IS_RATIO is not None:
+        # Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes
+        vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
+            tl.float32
+        )
+        per_token_loss = per_token_loss * vllm_is_ratio
+
     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
         KL += off_b * L + off_l
@@ -198,6 +208,8 @@ def _grpo_loss_bwd_kernel(
     ADVANTAGES,
     COMPLETION_MASK,
     LSE,
+    VLLM_IS_RATIO,
+    VLLM_IS_RATIO_STRIDE,
     TEMPERATURE,
     BETA: tl.constexpr,
     EPS_LOW,
@@ -271,6 +283,14 @@ def _grpo_loss_bwd_kernel(
         d_sapo_d_coef1 = 4.0 * sigmoid_val * (1.0 - sigmoid_val)
         dlogp = -advantage * d_sapo_d_coef1 * coef_1

+    # Apply vLLM IS ratio to PPO gradient (before KL gradient)
+    if VLLM_IS_RATIO is not None:
+        # Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes
+        vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
+            tl.float32
+        )
+        dlogp = dlogp * vllm_is_ratio
+
     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
         ref_logp = tl.load(REF_LOGP).to(tl.float32)
@@ -304,6 +324,7 @@ def forward(
         loss_type="grpo",
         sapo_temperature_pos=1.0,
         sapo_temperature_neg=1.05,
+        vllm_is_ratio=None,
     ):
         assert logits.is_contiguous() and completion_ids.is_contiguous()
         assert old_logp is None or old_logp.is_contiguous()
@@ -329,6 +350,25 @@ def forward(
         if completion_mask is not None:
             assert completion_mask.is_contiguous()

+        # Handle vLLM IS ratio
+        vllm_is_ratio_ptr = None
+        vllm_is_ratio_stride = L  # default to per-token (unused when ptr is None)
+        if vllm_is_ratio is not None:
+            assert vllm_is_ratio.dim() in (1, 2), (
+                f"vllm_is_ratio must be 1D (B,) or 2D (B, L) / (B, 1), got {vllm_is_ratio.dim()}D"
+            )
+            if vllm_is_ratio.dim() == 2:
+                assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, L), (
+                    f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {L}), got {tuple(vllm_is_ratio.shape)}"
+                )
+            else:
+                assert vllm_is_ratio.shape[0] == B, (
+                    f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}"
+                )
+            vllm_is_ratio = vllm_is_ratio.contiguous()
+            vllm_is_ratio_ptr = vllm_is_ratio
+            vllm_is_ratio_stride = vllm_is_ratio.shape[1] if vllm_is_ratio.dim() > 1 else 1
+
         loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
         lse = torch.zeros_like(loss)
         is_clipped = torch.zeros_like(loss)
@@ -341,6 +381,8 @@ def forward(
             completion_ids,
             completion_mask,
             advantages,
+            vllm_is_ratio_ptr,
+            vllm_is_ratio_stride,
             loss,
             lse,
             kl,
@@ -357,6 +399,8 @@ def forward(
             **kwargs,
         )
         ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
+        ctx.vllm_is_ratio = vllm_is_ratio_ptr
+        ctx.vllm_is_ratio_stride = vllm_is_ratio_stride
         ctx.infos = (
             temperature,
             beta,
@@ -376,6 +420,8 @@ def backward(ctx, *args):
         temperature, beta, eps_low, eps_high, inplace, loss_type_int, sapo_temperature_pos, sapo_temperature_neg = (
             ctx.infos
         )
+        vllm_is_ratio = ctx.vllm_is_ratio
+        vllm_is_ratio_stride = ctx.vllm_is_ratio_stride
         B, L_ADD_1, N = logits.shape
         L = L_ADD_1 - 1
         dlogits = logits.data if inplace else torch.empty_like(logits)
@@ -390,6 +436,8 @@ def backward(ctx, *args):
             advantages,
             completion_mask,
             lse,
+            vllm_is_ratio,
+            vllm_is_ratio_stride,
             temperature,
             beta,
             eps_low,
@@ -404,5 +452,6 @@ def backward(ctx, *args):
         )
         dlogits[:, -1, :] = 0
         # Return None for: old_logp, ref_logp, completion_ids, advantages, completion_mask,
-        # temperature, beta, eps_low, eps_high, inplace, loss_type, sapo_temperature_pos, sapo_temperature_neg
-        return dlogits, None, None, None, None, None, None, None, None, None, None, None, None, None
+        # temperature, beta, eps_low, eps_high, inplace, loss_type, sapo_temperature_pos, sapo_temperature_neg,
+        # vllm_is_ratio
+        return dlogits, None, None, None, None, None, None, None, None, None, None, None, None, None, None
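The modulo addressing used by both kernels collapses the two supported layouts into one load: for a `(B, L)` ratio the stride is `L`, so `off_l % L == off_l` and every token reads its own entry; for `(B, 1)` the stride is `1`, so `off_l % 1 == 0` and all tokens in a row read the row's single value. A plain-Python check of that address math (`is_ratio_offset` is a hypothetical helper mirroring `VLLM_IS_RATIO + off_b * STRIDE + off_l % STRIDE`):

```python
def is_ratio_offset(off_b: int, off_l: int, stride: int) -> int:
    """Flat offset the kernels compute for the vllm_is_ratio load."""
    return off_b * stride + off_l % stride

B, L = 2, 4
# (B, L) per-token layout, stride = L: row 1 reads offsets 4..7.
assert [is_ratio_offset(1, l, stride=L) for l in range(L)] == [4, 5, 6, 7]
# (B, 1) per-sequence layout, stride = 1: row 1 always reads offset 1.
assert [is_ratio_offset(1, l, stride=1) for l in range(L)] == [1, 1, 1, 1]
```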

src/liger_kernel/transformers/grpo_loss.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@ def triton_grpo_loss(
     reduce=False,
     sapo_temperature_pos=1.0,
     sapo_temperature_neg=1.05,
+    vllm_is_ratio=None,
 ):
     assert logits is not None and completion_ids is not None and advantages is not None, (
         "must provide logits, completion_ids and advantages"
@@ -46,6 +47,7 @@ def triton_grpo_loss(
         loss_type,
         sapo_temperature_pos,
         sapo_temperature_neg,
+        vllm_is_ratio,
     )
     if not reduce:
         return per_token_loss, per_token_kl, is_clipped
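A hypothetical call on the Triton path is sketched below. The logits carry one extra position, matching the backward pass's `dlogits[:, -1, :] = 0`; `beta=0.0` is passed so no reference log-probs are needed. Keyword names other than `logits`, `completion_ids`, `advantages`, and `vllm_is_ratio` are assumptions about this wrapper's signature:

```python
import torch
from liger_kernel.transformers.grpo_loss import triton_grpo_loss

B, L, V = 2, 8, 128
logits = torch.randn(B, L + 1, V, device="cuda")  # Triton kernels require a GPU
completion_ids = torch.randint(0, V, (B, L), device="cuda")
advantages = torch.randn(B, device="cuda")
vllm_is_ratio = torch.rand(B, 1, device="cuda")   # per-sequence; (B, L) also accepted

per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
    logits=logits,
    completion_ids=completion_ids,
    advantages=advantages,
    beta=0.0,                    # skip the KL term, so ref_logp can stay None
    vllm_is_ratio=vllm_is_ratio,
)
# Per the test plan, vllm_is_ratio=None produces the same result as an all-ones ratio.
```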
