Skip to content

Commit dd2d3a4

Browse files
kashif, claude, and vaibhavjindal
authored
## Summary Add various GRPO loss types in chunked and triton grpo loss. <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Vaibhav Jindal <vaibhav.jndl@gmail.com>
1 parent adb2238 commit dd2d3a4

File tree

6 files changed

+1484
-116
lines changed

6 files changed

+1484
-116
lines changed

src/liger_kernel/chunked_loss/fused_linear_ppo.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def forward(
4242
sapo_temperature_pos=1.0,
4343
sapo_temperature_neg=1.05,
4444
vllm_is_ratio=None,
45+
delta=None,
46+
use_bias_correction_kl=False,
4547
):
4648
# TODO: check torch compile matmul
4749
"""Chunked forward pass for PPO loss computation.
@@ -121,6 +123,8 @@ def forward(
121123
ppo_loss_fn=cls.ppo_loss_fn,
122124
sapo_temperature_pos=sapo_temperature_pos,
123125
sapo_temperature_neg=sapo_temperature_neg,
126+
delta=delta,
127+
use_bias_correction_kl=use_bias_correction_kl,
124128
)
125129

126130
def fused_fwd_bwd(
@@ -321,6 +325,8 @@ def _compute_chunk_loss(
321325
ppo_loss_fn=None,
322326
sapo_temperature_pos=1.0,
323327
sapo_temperature_neg=1.05,
328+
delta=None,
329+
use_bias_correction_kl=False,
324330
):
325331
"""Compute loss for a single chunk."""
326332
# Get policy log probabilities using chunk_forward
@@ -353,6 +359,8 @@ def _compute_chunk_loss(
353359
sapo_temperature_pos=sapo_temperature_pos,
354360
sapo_temperature_neg=sapo_temperature_neg,
355361
vllm_is_ratio=vllm_is_ratio_chunk,
362+
delta=delta,
363+
use_bias_correction_kl=use_bias_correction_kl,
356364
)
357365

358366
return chunk_loss, chunk_metrics
@@ -408,4 +416,6 @@ def backward(ctx, grad_output, *grad_metrics):
408416
None, # grad_sapo_temperature_pos
409417
None, # grad_sapo_temperature_neg
410418
None, # grad_vllm_is_ratio
419+
None, # grad_delta
420+
None, # grad_use_bias_correction_kl
411421
)

src/liger_kernel/chunked_loss/grpo_loss.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,24 @@ def ppo_loss_fn(
7070
epsilon_low=0.2,
7171
epsilon_high=0.2,
7272
beta=0.04,
73-
loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo"]
73+
loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo"]
7474
max_completion_length=None, # Required for dr_grpo
7575
importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
7676
sapo_temperature_pos=1.0, # Temperature for positive advantages in SAPO
7777
sapo_temperature_neg=1.05, # Temperature for negative advantages in SAPO
7878
vllm_is_ratio=None, # vLLM importance sampling ratio (chunk_size, seq_len) or (chunk_size, 1) or None
79+
delta=None, # Upper clamp for two-sided clipping (INTELLECT-2)
80+
use_bias_correction_kl=False, # Importance-sampling-corrected KL (DeepSeek-V3.2)
7981
**kwargs,
8082
):
8183
"""GRPO Loss Function matching GRPOTrainer implementation."""
84+
# Validate sequence-level + loss_type combinations
85+
if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"):
86+
raise ValueError(
87+
f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. "
88+
f"Use importance_sampling_level='token' instead."
89+
)
90+
8291
per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
8392
-1
8493
) # (batch_size, seq_len)
@@ -135,6 +144,9 @@ def ppo_loss_fn(
135144
)
136145
per_token_loss = -per_token_loss * advantages_expanded
137146
else:
147+
# Apply delta (two-sided clipping from INTELLECT-2) to coef_1
148+
if delta is not None:
149+
coef_1 = torch.clamp(coef_1, max=delta)
138150
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
139151
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
140152
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
@@ -146,6 +158,10 @@ def ppo_loss_fn(
146158
if beta != 0.0:
147159
# Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
148160
kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
161+
if use_bias_correction_kl:
162+
# Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1
163+
token_coef_1 = torch.exp(per_token_logps - old_per_token_logps)
164+
kl_div = kl_div * token_coef_1
149165
# Combine losses
150166
per_token_loss = per_token_loss + beta * kl_div
151167

@@ -169,6 +185,19 @@ def ppo_loss_fn(
169185
elif loss_type == "dapo" or loss_type == "cispo":
170186
loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
171187
loss = (per_token_loss * attention_mask).sum() / loss_normalizer
188+
elif loss_type == "luspo":
189+
# LUSPO: loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
190+
# Reformulated as: sum_i(sum_j(per_token_loss_ij) * seq_len_i) / numel
191+
# to avoid (B,T) * (B,1) broadcast which amplifies torch.compile differences.
192+
seq_lens = attention_mask.sum(-1) # (chunk_B,)
193+
per_seq_sum = per_token_loss.sum(-1) # (chunk_B,)
194+
weighted = per_seq_sum * seq_lens # (chunk_B,)
195+
if importance_sampling_level == "sequence" and beta == 0.0:
196+
# per_token_loss stays (B, 1), so .mean() divides by B
197+
loss = weighted.sum() / full_attention_mask.shape[0]
198+
else:
199+
# per_token_loss is (B, T), .mean() divides by B*T
200+
loss = weighted.sum() / (full_attention_mask.shape[0] * full_attention_mask.shape[1])
172201
else:
173202
raise ValueError(f"Unknown loss type: {loss_type}")
174203

@@ -220,6 +249,8 @@ def forward(
220249
use_ref_model=True,
221250
chunk_size=1,
222251
vllm_is_ratio=None,
252+
delta=None,
253+
use_bias_correction_kl=False,
223254
):
224255
"""
225256
Fused linear layer with GRPO loss.
@@ -235,7 +266,7 @@ def forward(
235266
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
236267
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
237268
beta (float): Weight for the KL penalty
238-
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo").
269+
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo").
239270
Defaults to "dapo".
240271
max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
241272
importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
@@ -250,6 +281,13 @@ def forward(
250281
Returns:
251282
torch.Tensor: Computed loss
252283
"""
284+
# Validate before entering torch.compile boundary
285+
if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"):
286+
raise ValueError(
287+
f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'. "
288+
f"Use importance_sampling_level='token' instead."
289+
)
290+
253291
return super().forward(
254292
cls=cls,
255293
ctx=ctx,
@@ -277,6 +315,8 @@ def forward(
277315
sapo_temperature_pos=sapo_temperature_pos,
278316
sapo_temperature_neg=sapo_temperature_neg,
279317
vllm_is_ratio=vllm_is_ratio,
318+
delta=delta,
319+
use_bias_correction_kl=use_bias_correction_kl,
280320
)
281321

282322
@staticmethod
@@ -310,6 +350,8 @@ def backward(ctx, grad_output, *grad_metrics):
310350
None, # grad_use_ref_model
311351
None, # grad_chunk_size
312352
None, # grad_vllm_is_ratio
353+
None, # grad_delta
354+
None, # grad_use_bias_correction_kl
313355
)
314356

315357

@@ -330,6 +372,8 @@ def __init__(
330372
sapo_temperature_pos: float = 1.0,
331373
sapo_temperature_neg: float = 1.05,
332374
temperature: float = 1.0,
375+
delta: Optional[float] = None,
376+
use_bias_correction_kl: bool = False,
333377
):
334378
"""
335379
Args:
@@ -339,21 +383,25 @@ def __init__(
339383
chunk_size (int): Size of chunks for processing.
340384
epsilon_low (float): Lower bound for the importance sampling ratio.
341385
epsilon_high (float): Upper bound for the importance sampling ratio.
342-
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo").
386+
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo").
343387
Defaults to "dapo". For "cispo", epsilon_high is typically larger (e.g. 5.0) and
344388
epsilon_low is unused. For "sapo", uses soft gating instead of hard clipping.
345389
max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
346390
importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
347391
sapo_temperature_pos (float): Temperature for positive advantages in SAPO. Defaults to 1.0.
348392
sapo_temperature_neg (float): Temperature for negative advantages in SAPO. Defaults to 1.05.
349393
temperature (float): Temperature for the logits.
394+
delta (float, optional): Upper clamp for two-sided clipping (INTELLECT-2). None means disabled.
395+
use_bias_correction_kl (bool): If True, multiply KL by importance sampling ratio (DeepSeek-V3.2).
350396
"""
351397
super().__init__()
352398
# Validate SAPO temperatures to prevent division by zero or numerical instability
353399
if sapo_temperature_pos <= 0:
354400
raise ValueError(f"sapo_temperature_pos must be positive, got {sapo_temperature_pos}")
355401
if sapo_temperature_neg <= 0:
356402
raise ValueError(f"sapo_temperature_neg must be positive, got {sapo_temperature_neg}")
403+
if delta is not None and delta <= 0:
404+
raise ValueError(f"delta must be positive, got {delta}")
357405
self.beta = beta
358406
self.compiled = compiled
359407
self.use_ref_model = use_ref_model
@@ -366,6 +414,8 @@ def __init__(
366414
self.sapo_temperature_pos = sapo_temperature_pos
367415
self.sapo_temperature_neg = sapo_temperature_neg
368416
self.temperature = temperature
417+
self.delta = delta
418+
self.use_bias_correction_kl = use_bias_correction_kl
369419

370420
def forward(
371421
self,
@@ -407,4 +457,6 @@ def forward(
407457
self.use_ref_model,
408458
self.chunk_size,
409459
vllm_is_ratio,
460+
self.delta,
461+
self.use_bias_correction_kl,
410462
)

0 commit comments

Comments (0)