Commit cb8e408
authored
Add vLLM importance sampling ratio support for GRPO loss (#1088)
## Summary
Fixes the **primary cause** (item 1) of #1082 —
`LigerFusedLinearGRPOLoss` produces ~100x larger `grad_norm` than TRL's
non-Liger path when using vLLM.
**Root cause:** TRL's `GRPOTrainer` applies `per_token_loss *=
importance_sampling_ratio`
([source](https://github.com/huggingface/trl/blob/v0.27.2/trl/trainer/grpo_trainer.py#L2351-L2352))
to correct for distribution mismatch from vLLM's rejection/stratified
sampling. Liger-Kernel had no mechanism to accept or apply this
correction, so the IS ratio was silently ignored, resulting in
uncorrected (and much larger) gradients.
**This is a high-priority fix** — any user running `GRPOTrainer` with
`use_vllm=True` and `use_liger_kernel=True` is affected, and the
resulting ~100x gradient mismatch can cause training instability or
divergence.
### Changes
- Add optional `vllm_is_ratio` parameter (`[B, T]` tensor or `None`) to
both code paths:
- **Chunked loss path**: `LigerFusedLinearGRPOLoss`,
`LigerFusedLinearGRPOFunction`, `ppo_loss_fn`, and the base class
`LigerFusedLinearPPOBase` chunking pipeline
- **Triton kernel path**: `triton_grpo_loss`, `GrpoLossFunction`, and
the Triton fwd/bwd kernels (`_grpo_loss_fwd_kernel`,
`_grpo_loss_bwd_kernel`)
- The IS correction is applied **after** PPO clipped loss computation
and **before** KL penalty, matching TRL's behavior exactly
- `vllm_is_ratio=None` (default) preserves existing behavior — no
breaking changes
- Works with all loss types: `grpo`, `dapo`, `bnpo`, `dr_grpo`, `cispo`,
`sapo`
### Verification
With `IS_RATIO=0.01`, the `grad_norm` ratio matches exactly:
```
Chunked loss path:
grad_norm WITHOUT vllm_is_ratio: 1.052219e-01
grad_norm WITH vllm_is_ratio: 1.052219e-03
ratio: 0.010000 ✓
Triton path:
grad_norm WITHOUT vllm_is_ratio: 1.461673e-02
grad_norm WITH vllm_is_ratio: 1.461673e-04
ratio: 0.010000 ✓
```
## Test plan
- [x] Extended existing `test_correctness` in
`test/chunked_loss/test_grpo_loss.py` with `use_vllm_is_ratio`
parametrize — covers all 6 loss types × 2 IS levels × 2 beta values ×
with/without vllm_is_ratio
- [x] Added `test_grpo_loss_with_vllm_is_ratio` in
`test/transformers/test_grpo_loss.py` — compares Triton output against
PyTorch reference with IS correction, plus `vllm_is_ratio=None` ==
`vllm_is_ratio=ones` identity check
- [x] All existing tests continue to pass (no regressions)
- [x] `make checkstyle` passes
## Related
- Reference implementation: #993
- Issue: #10821 parent cc14537 commit cb8e408
File tree
6 files changed
+436
-3
lines changed- src/liger_kernel
- chunked_loss
- ops
- transformers
- test
- chunked_loss
- transformers
6 files changed
+436
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
41 | 41 | | |
42 | 42 | | |
43 | 43 | | |
| 44 | + | |
44 | 45 | | |
45 | 46 | | |
46 | 47 | | |
| |||
71 | 72 | | |
72 | 73 | | |
73 | 74 | | |
| 75 | + | |
| 76 | + | |
74 | 77 | | |
75 | 78 | | |
76 | 79 | | |
| |||
80 | 83 | | |
81 | 84 | | |
82 | 85 | | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
83 | 100 | | |
84 | 101 | | |
85 | 102 | | |
| |||
114 | 131 | | |
115 | 132 | | |
116 | 133 | | |
| 134 | + | |
117 | 135 | | |
118 | 136 | | |
119 | 137 | | |
| |||
127 | 145 | | |
128 | 146 | | |
129 | 147 | | |
| 148 | + | |
130 | 149 | | |
131 | 150 | | |
132 | 151 | | |
| |||
137 | 156 | | |
138 | 157 | | |
139 | 158 | | |
| 159 | + | |
140 | 160 | | |
141 | 161 | | |
142 | 162 | | |
| |||
146 | 166 | | |
147 | 167 | | |
148 | 168 | | |
| 169 | + | |
149 | 170 | | |
150 | 171 | | |
151 | 172 | | |
| |||
196 | 217 | | |
197 | 218 | | |
198 | 219 | | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
199 | 223 | | |
200 | 224 | | |
201 | 225 | | |
| |||
205 | 229 | | |
206 | 230 | | |
207 | 231 | | |
| 232 | + | |
208 | 233 | | |
209 | 234 | | |
210 | 235 | | |
| |||
213 | 238 | | |
214 | 239 | | |
215 | 240 | | |
| 241 | + | |
216 | 242 | | |
217 | 243 | | |
218 | 244 | | |
| |||
224 | 250 | | |
225 | 251 | | |
226 | 252 | | |
| 253 | + | |
| 254 | + | |
227 | 255 | | |
228 | 256 | | |
229 | 257 | | |
| |||
233 | 261 | | |
234 | 262 | | |
235 | 263 | | |
| 264 | + | |
236 | 265 | | |
237 | 266 | | |
238 | 267 | | |
| |||
277 | 306 | | |
278 | 307 | | |
279 | 308 | | |
| 309 | + | |
280 | 310 | | |
281 | 311 | | |
282 | 312 | | |
| |||
322 | 352 | | |
323 | 353 | | |
324 | 354 | | |
| 355 | + | |
325 | 356 | | |
326 | 357 | | |
327 | 358 | | |
| |||
376 | 407 | | |
377 | 408 | | |
378 | 409 | | |
| 410 | + | |
379 | 411 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
75 | 75 | | |
76 | 76 | | |
77 | 77 | | |
| 78 | + | |
78 | 79 | | |
79 | 80 | | |
80 | 81 | | |
| |||
138 | 139 | | |
139 | 140 | | |
140 | 141 | | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
141 | 146 | | |
142 | 147 | | |
143 | 148 | | |
| |||
214 | 219 | | |
215 | 220 | | |
216 | 221 | | |
| 222 | + | |
217 | 223 | | |
218 | 224 | | |
219 | 225 | | |
| |||
239 | 245 | | |
240 | 246 | | |
241 | 247 | | |
| 248 | + | |
| 249 | + | |
242 | 250 | | |
243 | 251 | | |
244 | 252 | | |
| |||
268 | 276 | | |
269 | 277 | | |
270 | 278 | | |
| 279 | + | |
271 | 280 | | |
272 | 281 | | |
273 | 282 | | |
| |||
300 | 309 | | |
301 | 310 | | |
302 | 311 | | |
| 312 | + | |
303 | 313 | | |
304 | 314 | | |
305 | 315 | | |
| |||
370 | 380 | | |
371 | 381 | | |
372 | 382 | | |
| 383 | + | |
373 | 384 | | |
374 | 385 | | |
375 | 386 | | |
| |||
395 | 406 | | |
396 | 407 | | |
397 | 408 | | |
| 409 | + | |
398 | 410 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
90 | 90 | | |
91 | 91 | | |
92 | 92 | | |
| 93 | + | |
| 94 | + | |
93 | 95 | | |
94 | 96 | | |
95 | 97 | | |
| |||
169 | 171 | | |
170 | 172 | | |
171 | 173 | | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
172 | 182 | | |
173 | 183 | | |
174 | 184 | | |
| |||
198 | 208 | | |
199 | 209 | | |
200 | 210 | | |
| 211 | + | |
| 212 | + | |
201 | 213 | | |
202 | 214 | | |
203 | 215 | | |
| |||
271 | 283 | | |
272 | 284 | | |
273 | 285 | | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
274 | 294 | | |
275 | 295 | | |
276 | 296 | | |
| |||
304 | 324 | | |
305 | 325 | | |
306 | 326 | | |
| 327 | + | |
307 | 328 | | |
308 | 329 | | |
309 | 330 | | |
| |||
329 | 350 | | |
330 | 351 | | |
331 | 352 | | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
332 | 372 | | |
333 | 373 | | |
334 | 374 | | |
| |||
341 | 381 | | |
342 | 382 | | |
343 | 383 | | |
| 384 | + | |
| 385 | + | |
344 | 386 | | |
345 | 387 | | |
346 | 388 | | |
| |||
357 | 399 | | |
358 | 400 | | |
359 | 401 | | |
| 402 | + | |
| 403 | + | |
360 | 404 | | |
361 | 405 | | |
362 | 406 | | |
| |||
376 | 420 | | |
377 | 421 | | |
378 | 422 | | |
| 423 | + | |
| 424 | + | |
379 | 425 | | |
380 | 426 | | |
381 | 427 | | |
| |||
390 | 436 | | |
391 | 437 | | |
392 | 438 | | |
| 439 | + | |
| 440 | + | |
393 | 441 | | |
394 | 442 | | |
395 | 443 | | |
| |||
404 | 452 | | |
405 | 453 | | |
406 | 454 | | |
407 | | - | |
408 | | - | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
25 | 26 | | |
26 | 27 | | |
27 | 28 | | |
| |||
46 | 47 | | |
47 | 48 | | |
48 | 49 | | |
| 50 | + | |
49 | 51 | | |
50 | 52 | | |
51 | 53 | | |
| |||
0 commit comments