@@ -70,15 +70,24 @@ def ppo_loss_fn(
7070 epsilon_low = 0.2 ,
7171 epsilon_high = 0.2 ,
7272 beta = 0.04 ,
73- loss_type = "dapo" , # ["grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo"]
73+ loss_type = "dapo" , # ["grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo" ]
7474 max_completion_length = None , # Required for dr_grpo
7575 importance_sampling_level = "token" , # ["token", "sequence"] - new parameter for GSPO
7676 sapo_temperature_pos = 1.0 , # Temperature for positive advantages in SAPO
7777 sapo_temperature_neg = 1.05 , # Temperature for negative advantages in SAPO
7878 vllm_is_ratio = None , # vLLM importance sampling ratio (chunk_size, seq_len) or (chunk_size, 1) or None
79+ delta = None , # Upper clamp for two-sided clipping (INTELLECT-2)
80+ use_bias_correction_kl = False , # Importance-sampling-corrected KL (DeepSeek-V3.2)
7981 ** kwargs ,
8082 ):
8183 """GRPO Loss Function matching GRPOTrainer implementation."""
84+ # Validate sequence-level + loss_type combinations
85+ if importance_sampling_level == "sequence" and loss_type in ("cispo" , "sapo" ):
86+ raise ValueError (
87+ f"Sequence-level importance sampling is not supported for loss_type='{ loss_type } '. "
88+ f"Use importance_sampling_level='token' instead."
89+ )
90+
8291 per_token_logps = log_probs .gather (dim = - 1 , index = selected_token_ids .unsqueeze (- 1 )).squeeze (
8392 - 1
8493 ) # (batch_size, seq_len)
@@ -135,6 +144,9 @@ def ppo_loss_fn(
135144 )
136145 per_token_loss = - per_token_loss * advantages_expanded
137146 else :
147+ # Apply delta (two-sided clipping from INTELLECT-2) to coef_1
148+ if delta is not None :
149+ coef_1 = torch .clamp (coef_1 , max = delta )
138150 per_token_loss1 = coef_1 * advantages .unsqueeze (1 )
139151 per_token_loss2 = coef_2 * advantages .unsqueeze (1 )
140152 per_token_loss = - torch .min (per_token_loss1 , per_token_loss2 )
@@ -146,6 +158,10 @@ def ppo_loss_fn(
146158 if beta != 0.0 :
147159 # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
148160 kl_div = k3_loss_fn (ref_per_token_logps , per_token_logps )
161+ if use_bias_correction_kl :
162+ # Importance-sampling-corrected KL (DeepSeek-V3.2): kl *= token-level coef_1
163+ token_coef_1 = torch .exp (per_token_logps - old_per_token_logps )
164+ kl_div = kl_div * token_coef_1
149165 # Combine losses
150166 per_token_loss = per_token_loss + beta * kl_div
151167
@@ -169,6 +185,19 @@ def ppo_loss_fn(
169185 elif loss_type == "dapo" or loss_type == "cispo" :
170186 loss_normalizer = LigerFusedLinearPPOBase ._compute_dapo_normalizer (full_attention_mask )
171187 loss = (per_token_loss * attention_mask ).sum () / loss_normalizer
188+ elif loss_type == "luspo" :
189+ # LUSPO: loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
190+ # Reformulated as: sum_i(sum_j(per_token_loss_ij) * seq_len_i) / numel
191+ # to avoid (B,T) * (B,1) broadcast which amplifies torch.compile differences.
192+ seq_lens = attention_mask .sum (- 1 ) # (chunk_B,)
193+ per_seq_sum = per_token_loss .sum (- 1 ) # (chunk_B,)
194+ weighted = per_seq_sum * seq_lens # (chunk_B,)
195+ if importance_sampling_level == "sequence" and beta == 0.0 :
196+ # per_token_loss stays (B, 1), so .mean() divides by B
197+ loss = weighted .sum () / full_attention_mask .shape [0 ]
198+ else :
199+ # per_token_loss is (B, T), .mean() divides by B*T
200+ loss = weighted .sum () / (full_attention_mask .shape [0 ] * full_attention_mask .shape [1 ])
172201 else :
173202 raise ValueError (f"Unknown loss type: { loss_type } " )
174203
@@ -220,6 +249,8 @@ def forward(
220249 use_ref_model = True ,
221250 chunk_size = 1 ,
222251 vllm_is_ratio = None ,
252+ delta = None ,
253+ use_bias_correction_kl = False ,
223254 ):
224255 """
225256 Fused linear layer with GRPO loss.
@@ -235,7 +266,7 @@ def forward(
235266 ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
236267 ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
237268 beta (float): Weight for the KL penalty
238- loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo").
269+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo" ).
239270 Defaults to "dapo".
240271 max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
241272 importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
@@ -250,6 +281,13 @@ def forward(
250281 Returns:
251282 torch.Tensor: Computed loss
252283 """
284+ # Validate before entering torch.compile boundary
285+ if importance_sampling_level == "sequence" and loss_type in ("cispo" , "sapo" ):
286+ raise ValueError (
287+ f"Sequence-level importance sampling is not supported for loss_type='{ loss_type } '. "
288+ f"Use importance_sampling_level='token' instead."
289+ )
290+
253291 return super ().forward (
254292 cls = cls ,
255293 ctx = ctx ,
@@ -277,6 +315,8 @@ def forward(
277315 sapo_temperature_pos = sapo_temperature_pos ,
278316 sapo_temperature_neg = sapo_temperature_neg ,
279317 vllm_is_ratio = vllm_is_ratio ,
318+ delta = delta ,
319+ use_bias_correction_kl = use_bias_correction_kl ,
280320 )
281321
282322 @staticmethod
@@ -310,6 +350,8 @@ def backward(ctx, grad_output, *grad_metrics):
310350 None , # grad_use_ref_model
311351 None , # grad_chunk_size
312352 None , # grad_vllm_is_ratio
353+ None , # grad_delta
354+ None , # grad_use_bias_correction_kl
313355 )
314356
315357
@@ -330,6 +372,8 @@ def __init__(
330372 sapo_temperature_pos : float = 1.0 ,
331373 sapo_temperature_neg : float = 1.05 ,
332374 temperature : float = 1.0 ,
375+ delta : Optional [float ] = None ,
376+ use_bias_correction_kl : bool = False ,
333377 ):
334378 """
335379 Args:
@@ -339,21 +383,25 @@ def __init__(
339383 chunk_size (int): Size of chunks for processing.
340384 epsilon_low (float): Lower bound for the importance sampling ratio.
341385 epsilon_high (float): Upper bound for the importance sampling ratio.
342- loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo").
386+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo" ).
343387 Defaults to "dapo". For "cispo", epsilon_high is typically larger (e.g. 5.0) and
344388 epsilon_low is unused. For "sapo", uses soft gating instead of hard clipping.
345389 max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
346390 importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
347391 sapo_temperature_pos (float): Temperature for positive advantages in SAPO. Defaults to 1.0.
348392 sapo_temperature_neg (float): Temperature for negative advantages in SAPO. Defaults to 1.05.
349393 temperature (float): Temperature for the logits.
394+ delta (float, optional): Upper clamp for two-sided clipping (INTELLECT-2). None means disabled.
395+ use_bias_correction_kl (bool): If True, multiply KL by importance sampling ratio (DeepSeek-V3.2).
350396 """
351397 super ().__init__ ()
352398 # Validate SAPO temperatures to prevent division by zero or numerical instability
353399 if sapo_temperature_pos <= 0 :
354400 raise ValueError (f"sapo_temperature_pos must be positive, got { sapo_temperature_pos } " )
355401 if sapo_temperature_neg <= 0 :
356402 raise ValueError (f"sapo_temperature_neg must be positive, got { sapo_temperature_neg } " )
403+ if delta is not None and delta <= 0 :
404+ raise ValueError (f"delta must be positive, got { delta } " )
357405 self .beta = beta
358406 self .compiled = compiled
359407 self .use_ref_model = use_ref_model
@@ -366,6 +414,8 @@ def __init__(
366414 self .sapo_temperature_pos = sapo_temperature_pos
367415 self .sapo_temperature_neg = sapo_temperature_neg
368416 self .temperature = temperature
417+ self .delta = delta
418+ self .use_bias_correction_kl = use_bias_correction_kl
369419
370420 def forward (
371421 self ,
@@ -407,4 +457,6 @@ def forward(
407457 self .use_ref_model ,
408458 self .chunk_size ,
409459 vllm_is_ratio ,
460+ self .delta ,
461+ self .use_bias_correction_kl ,
410462 )