Add vLLM importance sampling ratio support for GRPO loss #1088
Tcc0403 merged 5 commits into linkedin:main from
Conversation
When TRL's GRPOTrainer uses vLLM for generation, it applies an importance sampling correction (`per_token_loss *= importance_sampling_ratio`) to account for distribution mismatch. Liger-Kernel's GRPO loss had no mechanism to accept this correction, causing ~100x larger grad_norm vs the non-Liger path. Add an optional `vllm_is_ratio` parameter ([B, T] tensor or None) to both the chunked loss (`LigerFusedLinearGRPOLoss`) and Triton kernel (`triton_grpo_loss`) paths. The correction is applied after PPO clipped loss computation and before KL penalty, matching TRL's behavior.
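For reference, a minimal PyTorch sketch of where the correction slots in (illustrative only; function and argument names are placeholders, and the actual chunked/Triton implementations differ):

```python
import torch

def per_token_grpo_loss(per_token_logps, old_per_token_logps, advantages,
                        per_token_kl, beta=0.04, eps_low=0.2, eps_high=0.2,
                        vllm_is_ratio=None):
    # PPO-style clipped objective over per-token log-prob ratios.
    coef_1 = torch.exp(per_token_logps - old_per_token_logps)
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
    adv = advantages.unsqueeze(1)  # (B,) -> (B, 1)
    per_token_loss = -torch.min(coef_1 * adv, coef_2 * adv)

    # vLLM importance sampling correction: applied after the clipped loss
    # and before the KL penalty, matching TRL's GRPOTrainer behavior.
    if vllm_is_ratio is not None:
        per_token_loss = per_token_loss * vllm_is_ratio  # (B, T) or (B, 1) broadcasts

    if beta != 0.0:
        per_token_loss = per_token_loss + beta * per_token_kl
    return per_token_loss
```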
Use modulo indexing in Triton kernels so vllm_is_ratio can be either (B, L) per-token or (B, 1) per-sequence. Add test verifying (B, 1) produces the same result as (B, L) with uniform values.
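A rough Triton sketch of the modulo-indexing trick (not the actual kernel; with a single column, `cols % 1 == 0`, so every token in a row reads the same per-sequence value):

```python
import triton
import triton.language as tl

@triton.jit
def _scale_by_is_ratio_kernel(loss_ptr, ratio_ptr, ratio_row_stride, ratio_cols,
                              L, BLOCK_SIZE: tl.constexpr):
    # One program per row; assumes L <= BLOCK_SIZE for brevity.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < L
    # Modulo indexing: a (B, L) ratio gives a per-token load, a (B, 1) ratio
    # maps every column to index 0 (per-sequence broadcast).
    ratio = tl.load(ratio_ptr + row * ratio_row_stride + cols % ratio_cols, mask=mask)
    loss = tl.load(loss_ptr + row * L + cols, mask=mask)
    tl.store(loss_ptr + row * L + cols, loss * ratio, mask=mask)
```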
@Tcc0403 Could you please review this PR? This fixes a critical issue (#1082) where the Liger GRPO loss produces ~100x larger grad_norm than TRL's non-Liger path when vLLM generation is used.
- Add shape validation assertions for vllm_is_ratio in both chunked and Triton paths, accepting (B, T), (B, 1), and (B,) shapes
- Unsqueeze 1D (B,) to (B, 1) in chunked path for correct broadcasting
- Extend Triton vllm_is_ratio tests to cover all loss types (cispo, sapo)
- Add 1D vs 2D equivalence tests for both paths
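A minimal sketch of that shape handling (the helper name `_normalize_vllm_is_ratio` is hypothetical; the real code uses inline assertions):

```python
import torch

def _normalize_vllm_is_ratio(vllm_is_ratio, B, T):
    # Hypothetical helper mirroring the validation described above.
    if vllm_is_ratio is None:
        return None
    if vllm_is_ratio.dim() == 1:
        # (B,) -> (B, 1) so it broadcasts over the token dimension.
        vllm_is_ratio = vllm_is_ratio.unsqueeze(-1)
    assert vllm_is_ratio.dim() == 2 and vllm_is_ratio.shape[0] == B, (
        "vllm_is_ratio batch dim must match the input batch"
    )
    assert vllm_is_ratio.shape[1] in (1, T), (
        "vllm_is_ratio must have shape (B, T), (B, 1), or (B,)"
    )
    return vllm_is_ratio
```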
thanks!
Remove use_vllm_is_ratio from test_correctness parametrize (halving test count from 1152 to 576) and add dedicated test_correctness_with_vllm_is_ratio that verifies both torch reference correctness and 1D/2D shape equivalence with focused parameters.
When using `use_vllm=True` with `use_liger_kernel=True`, the vLLM
importance sampling correction (`importance_sampling_ratio`) was
computed and stored in the inputs dict but never passed to the Liger
GRPO loss. This caused ~100x larger grad_norm compared to the
non-Liger path, leading to training instability.
Pass `inputs.get("importance_sampling_ratio")` as `vllm_is_ratio` to
`self.liger_grpo_loss(...)` in `compute_liger_loss`, matching the
correction already applied in the standard `_compute_loss` path.
Requires linkedin/Liger-Kernel#1088 which adds `vllm_is_ratio`
parameter support to `LigerFusedLinearGRPOLoss`.
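Sketch of the corresponding call site in `compute_liger_loss` (surrounding arguments are elided; only the new keyword is shown, and the local variable name is illustrative):

```python
# Inside GRPOTrainer.compute_liger_loss (sketch):
# TRL already stores the correction in the inputs dict when use_vllm=True;
# it just needs to be forwarded to the Liger loss.
vllm_is_ratio = inputs.get("importance_sampling_ratio")

loss = self.liger_grpo_loss(
    ...,  # existing arguments unchanged
    vllm_is_ratio=vllm_is_ratio,
)
```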
…rnel Use is_liger_kernel_available(min_version="0.6.6") instead of inspecting LigerFusedLinearGRPOLoss.forward signature. The version will be updated once linkedin/Liger-Kernel#1088 is released.
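A sketch of the resulting gate (assumes `is_liger_kernel_available` is importable from `trl.import_utils`; the minimum version was later bumped to 0.7.0, see below):

```python
from trl.import_utils import is_liger_kernel_available

# Gate the new keyword on an installed Liger-Kernel version that supports it,
# instead of inspecting LigerFusedLinearGRPOLoss.forward's signature.
liger_supports_vllm_is_ratio = is_liger_kernel_available(min_version="0.6.6")
```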
Tcc0403 left a comment
LGTM, thanks for these high-quality PRs!
@Tcc0403 Thanks for the quick review and merge! Could you let me know which version this change will be released in? I have a downstream PR in TRL (huggingface/trl#5031) that depends on this, and I need to set the correct minimum version for the version check.
@vaibhavjindal @Mecoli1219 @momochen Can we have a release for these fixes?
liger-kernel v0.7.0 has been released with vllm_is_ratio support (linkedin/Liger-Kernel#1088), so pin to this version as minimum and remove the provisional xfail marker on the vllm_is_ratio test.
Summary
Fixes the primary cause (item 1) of #1082 — `LigerFusedLinearGRPOLoss` produces ~100x larger `grad_norm` than TRL's non-Liger path when using vLLM.

Root cause: TRL's `GRPOTrainer` applies `per_token_loss *= importance_sampling_ratio` (source) to correct for distribution mismatch from vLLM's rejection/stratified sampling. Liger-Kernel had no mechanism to accept or apply this correction, so the IS ratio was silently ignored, resulting in uncorrected (and much larger) gradients.

This is a high-priority fix — any user running `GRPOTrainer` with `use_vllm=True` and `use_liger_kernel=True` is affected, and the resulting ~100x gradient mismatch can cause training instability or divergence.

Changes
- Add an optional `vllm_is_ratio` parameter (`[B, T]` tensor or `None`) to both code paths:
  - `LigerFusedLinearGRPOLoss`, `LigerFusedLinearGRPOFunction`, `ppo_loss_fn`, and the base class `LigerFusedLinearPPOBase` chunking pipeline
  - `triton_grpo_loss`, `GrpoLossFunction`, and the Triton fwd/bwd kernels (`_grpo_loss_fwd_kernel`, `_grpo_loss_bwd_kernel`)
- `vllm_is_ratio=None` (default) preserves existing behavior — no breaking changes
- Supported for all six loss types: `grpo`, `dapo`, `bnpo`, `dr_grpo`, `cispo`, `sapo`
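Hypothetical end-user call at the chunked-loss level (a sketch only; keyword and constructor argument names other than `vllm_is_ratio` are assumptions about the current `LigerFusedLinearGRPOLoss` interface):

```python
import torch
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

B, T, H, V = 2, 16, 64, 128
loss_fn = LigerFusedLinearGRPOLoss(beta=0.0, loss_type="grpo", use_ref_model=False)

hidden = torch.randn(B, T, H, requires_grad=True)       # last hidden states
lm_head_weight = torch.randn(V, H, requires_grad=True)  # lm_head weight
token_ids = torch.randint(0, V, (B, T))
mask = torch.ones(B, T, dtype=torch.long)
advantages = torch.randn(B)
is_ratio = torch.rand(B, T)  # vLLM importance sampling ratio, (B, T) or (B, 1)

out = loss_fn(hidden, lm_head_weight, token_ids, mask, advantages,
              vllm_is_ratio=is_ratio)  # None (default) keeps old behavior
```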
Verification

With `IS_RATIO=0.01`, the `grad_norm` ratio matches exactly:

Test plan
- `test_correctness` in `test/chunked_loss/test_grpo_loss.py` with `use_vllm_is_ratio` parametrize — covers all 6 loss types × 2 IS levels × 2 beta values × with/without `vllm_is_ratio`
- `test_grpo_loss_with_vllm_is_ratio` in `test/transformers/test_grpo_loss.py` — compares Triton output against PyTorch reference with IS correction, plus `vllm_is_ratio=None` == `vllm_is_ratio=ones` identity check
- `make checkstyle` passes

Related