
Add vLLM importance sampling ratio support for GRPO loss#1088

Merged
Tcc0403 merged 5 commits into linkedin:main from yukiu00:feat/add-vllm-is-ratio-grpo-loss
Feb 9, 2026

Conversation


@yukiu00 yukiu00 commented Feb 9, 2026

Summary

Fixes the primary cause (item 1) of #1082: LigerFusedLinearGRPOLoss produces ~100x larger grad_norm than TRL's non-Liger path when using vLLM.

Root cause: TRL's GRPOTrainer applies per_token_loss *= importance_sampling_ratio (source) to correct for distribution mismatch from vLLM's rejection/stratified sampling. Liger-Kernel had no mechanism to accept or apply this correction, so the IS ratio was silently ignored, resulting in uncorrected (and much larger) gradients.

This is a high-priority fix — any user running GRPOTrainer with use_vllm=True and use_liger_kernel=True is affected, and the resulting ~100x gradient mismatch can cause training instability or divergence.
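
The missing step is a single elementwise multiply. A minimal PyTorch sketch with toy tensors (`apply_is_correction` is a hypothetical helper for illustration, not Liger's kernel code):

```python
import torch

def apply_is_correction(per_token_loss, importance_sampling_ratio):
    # Elementwise multiply, as in TRL's GRPOTrainer:
    #   per_token_loss *= importance_sampling_ratio
    # Skipping it leaves the loss (and its gradients) uncorrected,
    # which is the ~100x grad_norm mismatch described above.
    return per_token_loss * importance_sampling_ratio

loss = torch.ones(2, 4)            # [B, T] per-token loss (toy values)
ratio = torch.full((2, 4), 0.01)   # [B, T] IS ratio from vLLM
corrected = apply_is_correction(loss, ratio)
```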

Changes

  • Add optional vllm_is_ratio parameter ([B, T] tensor or None) to both code paths:
    • Chunked loss path: LigerFusedLinearGRPOLoss, LigerFusedLinearGRPOFunction, ppo_loss_fn, and the base class LigerFusedLinearPPOBase chunking pipeline
    • Triton kernel path: triton_grpo_loss, GrpoLossFunction, and the Triton fwd/bwd kernels (_grpo_loss_fwd_kernel, _grpo_loss_bwd_kernel)
  • The IS correction is applied after PPO clipped loss computation and before KL penalty, matching TRL's behavior exactly
  • vllm_is_ratio=None (default) preserves existing behavior — no breaking changes
  • Works with all loss types: grpo, dapo, bnpo, dr_grpo, cispo, sapo
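
The ordering of the correction can be sketched in plain PyTorch. This is an illustrative reference (hypothetical helper, not Liger's fused API): PPO clipped loss, then the optional IS correction, then the KL penalty, so the KL term itself is not rescaled:

```python
import torch

def grpo_per_token_loss(logratio, advantages, kl, vllm_is_ratio=None,
                        eps=0.2, beta=0.04):
    ratio = torch.exp(logratio)
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
    # 1) PPO clipped surrogate loss
    per_token_loss = -torch.min(ratio * advantages, clipped * advantages)
    # 2) optional vLLM IS correction (default None: behavior unchanged)
    if vllm_is_ratio is not None:
        per_token_loss = per_token_loss * vllm_is_ratio
    # 3) KL penalty, added after the correction
    return per_token_loss + beta * kl
```

With this ordering, passing a tensor of ones is exactly equivalent to passing `None`, which is the identity property the tests below rely on.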

Verification

With IS_RATIO=0.01, the grad_norm ratio matches exactly:

Chunked loss path:
  grad_norm WITHOUT vllm_is_ratio: 1.052219e-01
  grad_norm WITH    vllm_is_ratio: 1.052219e-03
  ratio: 0.010000 ✓

Triton path:
  grad_norm WITHOUT vllm_is_ratio: 1.461673e-02
  grad_norm WITH    vllm_is_ratio: 1.461673e-04
  ratio: 0.010000 ✓
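
These ratios can be reproduced in miniature: gradients are linear in the loss, so a constant IS ratio c must scale grad_norm by exactly c. A toy check under that assumption (a hypothetical stand-in model, not the PR's verification script):

```python
import torch

def grad_norm(is_ratio):
    # Same seed both calls, so the only difference is the constant scale.
    torch.manual_seed(0)
    w = torch.randn(4, 4, requires_grad=True)
    x = torch.randn(2, 4)
    per_token_loss = (x @ w).pow(2) * is_ratio   # stand-in for the GRPO loss
    per_token_loss.mean().backward()
    return w.grad.norm().item()

# grad_norm(0.01) / grad_norm(1.0) comes out as 0.01, matching the
# ratio: 0.010000 lines above.
```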

Test plan

  • Extended existing test_correctness in test/chunked_loss/test_grpo_loss.py with use_vllm_is_ratio parametrize — covers all 6 loss types × 2 IS levels × 2 beta values × with/without vllm_is_ratio
  • Added test_grpo_loss_with_vllm_is_ratio in test/transformers/test_grpo_loss.py — compares Triton output against PyTorch reference with IS correction, plus vllm_is_ratio=None == vllm_is_ratio=ones identity check
  • All existing tests continue to pass (no regressions)
  • make checkstyle passes

Related

When TRL's GRPOTrainer uses vLLM for generation, it applies an importance
sampling correction (`per_token_loss *= importance_sampling_ratio`) to account
for distribution mismatch. Liger-Kernel's GRPO loss had no mechanism to accept
this correction, causing ~100x larger grad_norm vs the non-Liger path.

Add an optional `vllm_is_ratio` parameter ([B, T] tensor or None) to both
the chunked loss (`LigerFusedLinearGRPOLoss`) and Triton kernel
(`triton_grpo_loss`) paths. The correction is applied after PPO clipped loss
computation and before KL penalty, matching TRL's behavior.
@yukiu00 yukiu00 mentioned this pull request Feb 9, 2026
Use modulo indexing in Triton kernels so vllm_is_ratio can be either
(B, L) per-token or (B, 1) per-sequence. Add test verifying (B, 1)
produces the same result as (B, L) with uniform values.
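
The modulo trick lets one load serve both layouts. A plain-Python sketch of the flat, stride-style indexing (the Triton kernel applies the same arithmetic to pointer offsets; the names here are illustrative):

```python
def load_is_ratio(storage, stored_width, b, t):
    # Read logical element (b, t) from flat storage.
    #   stored_width == L -> t % stored_width == t   (per-token, (B, L))
    #   stored_width == 1 -> t % stored_width == 0   (per-sequence, (B, 1))
    return storage[b * stored_width + t % stored_width]

per_seq = [0.5, 2.0]                      # (B=2, 1): one ratio per sequence
per_tok = [0.5, 0.5, 0.5, 2.0, 2.0, 2.0]  # (B=2, L=3): uniform per row
```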

yukiu00 commented Feb 9, 2026

@Tcc0403 Could you please review this PR? This fixes a critical issue (#1082) where LigerFusedLinearGRPOLoss produces ~100x larger grad_norm than TRL's non-Liger path when using vLLM. Since many users rely on GRPOTrainer with vLLM + Liger in production, this bug can silently cause training instability. An early review and merge would help unblock affected users.

- Add shape validation assertions for vllm_is_ratio in both chunked
  and Triton paths, accepting (B, T), (B, 1), and (B,) shapes
- Unsqueeze 1D (B,) to (B, 1) in chunked path for correct broadcasting
- Extend Triton vllm_is_ratio tests to cover all loss types (cispo, sapo)
- Add 1D vs 2D equivalence tests for both paths
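
The shape handling described in this commit can be sketched as follows (an illustrative paraphrase, not Liger's exact code): accept (B, T), (B, 1), or (B,), and unsqueeze the 1D case so it broadcasts against a [B, T] per-token loss.

```python
import torch

def normalize_is_ratio(vllm_is_ratio, B, T):
    if vllm_is_ratio is None:
        return None
    if vllm_is_ratio.dim() == 1:
        assert vllm_is_ratio.shape == (B,), "1D ratio must have shape (B,)"
        vllm_is_ratio = vllm_is_ratio.unsqueeze(-1)   # (B,) -> (B, 1)
    assert vllm_is_ratio.shape in ((B, T), (B, 1)), "expected (B, T) or (B, 1)"
    return vllm_is_ratio
```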

kashif commented Feb 9, 2026

thanks!

Remove use_vllm_is_ratio from test_correctness parametrize (halving
test count from 1152 to 576) and add dedicated
test_correctness_with_vllm_is_ratio that verifies both torch reference
correctness and 1D/2D shape equivalence with focused parameters.
yukiu00 added a commit to yukiu00/trl that referenced this pull request Feb 9, 2026
When using `use_vllm=True` with `use_liger_kernel=True`, the vLLM
importance sampling correction (`importance_sampling_ratio`) was
computed and stored in the inputs dict but never passed to the Liger
GRPO loss. This caused ~100x larger grad_norm compared to the
non-Liger path, leading to training instability.

Pass `inputs.get("importance_sampling_ratio")` as `vllm_is_ratio` to
`self.liger_grpo_loss(...)` in `compute_liger_loss`, matching the
correction already applied in the standard `_compute_loss` path.

Requires linkedin/Liger-Kernel#1088 which adds `vllm_is_ratio`
parameter support to `LigerFusedLinearGRPOLoss`.
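
The TRL-side change amounts to forwarding the already-computed ratio. A sketch with stand-in names (not the real trainer code; `stub_loss` is a stub for illustration only):

```python
def compute_liger_loss_sketch(liger_grpo_loss, inputs, **loss_kwargs):
    # Forward the precomputed ratio from the inputs dict, or None when
    # no vLLM sampling correction was computed (preserving old behavior).
    return liger_grpo_loss(
        vllm_is_ratio=inputs.get("importance_sampling_ratio"),
        **loss_kwargs,
    )

def stub_loss(vllm_is_ratio=None, **kwargs):
    # Stub standing in for LigerFusedLinearGRPOLoss.__call__.
    return "corrected" if vllm_is_ratio is not None else "uncorrected"
```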
yukiu00 added a commit to yukiu00/trl that referenced this pull request Feb 9, 2026
…rnel

Use is_liger_kernel_available(min_version="0.6.6") instead of
inspecting LigerFusedLinearGRPOLoss.forward signature. The version
will be updated once linkedin/Liger-Kernel#1088 is released.

@Tcc0403 Tcc0403 left a comment


LGTM, thanks for these high quality PRs!

@Tcc0403 Tcc0403 merged commit cb8e408 into linkedin:main Feb 9, 2026
3 of 7 checks passed

yukiu00 commented Feb 9, 2026

@Tcc0403 Thanks for the quick review and merge!

Could you let me know which version this change will be released in? I have a downstream PR in TRL (huggingface/trl#5031) that depends on this, and I need to set the correct minimum version for the version check.


Tcc0403 commented Feb 9, 2026

@vaibhavjindal @Mecoli1219 @momochen Can we have a release for these fixes?

yukiu00 added a commit to yukiu00/trl that referenced this pull request Feb 12, 2026
liger-kernel v0.7.0 has been released with vllm_is_ratio support
(linkedin/Liger-Kernel#1088), so pin to this version as minimum and
remove the provisional xfail marker on the vllm_is_ratio test.
