diff --git a/src/liger_kernel/chunked_loss/fused_linear_ppo.py b/src/liger_kernel/chunked_loss/fused_linear_ppo.py
index 84070624d..13ac46793 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_ppo.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -39,6 +39,7 @@ def forward(
         compiled=True,
         use_ref_model=False,
         chunk_size=1,
+        vllm_is_ratio=None,
     ):
         # TODO: check torch compile matmul
         """Chunked forward pass for PPO loss computation.
@@ -107,6 +108,7 @@ def fused_fwd_bwd(
             ref_per_token_logps_chunk,
             old_per_token_logps_chunk,
             ref_input_chunk,
+            vllm_is_ratio_chunk,
         ):
             """Fused forward and backward for a chunk."""
             argnums = (0, 1, 5) if bias is not None else (0, 1)
@@ -120,6 +122,7 @@ def fused_fwd_bwd(
                 ref_per_token_logps_chunk=ref_per_token_logps_chunk,  # arg 6
                 old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
                 ref_input_chunk=ref_input_chunk,  # arg 8
+                vllm_is_ratio_chunk=vllm_is_ratio_chunk,  # arg 9
             )

         def accumulate_chunk(
@@ -130,6 +133,7 @@ def accumulate_chunk(
             ref_per_token_logps_chunk=None,
             old_per_token_logps_chunk=None,
             ref_input_chunk=None,
+            vllm_is_ratio_chunk=None,
         ):
             (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
                 input_chunk,
@@ -139,6 +143,7 @@ def accumulate_chunk(
                 ref_per_token_logps_chunk,
                 old_per_token_logps_chunk,
                 ref_input_chunk,
+                vllm_is_ratio_chunk,
             )
             if bias is not None:
                 grad_bias.add_(chunk_grad_bias[0])
@@ -189,6 +194,9 @@ def accumulate_chunk(
             if use_ref_model and ref_per_token_logps is None
             else [None] * chunks
         )
+        _vllm_is_ratio_chunks = (
+            torch.chunk(vllm_is_ratio, chunks=chunks, dim=0) if vllm_is_ratio is not None else [None] * chunks
+        )

         for (
             input_chunk,
@@ -198,6 +206,7 @@ def accumulate_chunk(
             ref_per_token_logps_chunk,
             old_per_token_logps_chunk,
             ref_input_chunk,
+            vllm_is_ratio_chunk,
         ) in zip(
             _input_chunks,
             _selected_token_ids_chunks,
@@ -206,6 +215,7 @@ def accumulate_chunk(
             _ref_per_token_logps_chunks,
             _old_per_token_logps_chunks,
             _ref_input_chunks,
+            _vllm_is_ratio_chunks,
         ):
             # Mark dynamic dimensions
             torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -217,6 +227,8 @@ def accumulate_chunk(
                 torch._dynamo.mark_dynamic(ref_input_chunk, 1)
             if old_per_token_logps_chunk is not None:
                 torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
+            if vllm_is_ratio_chunk is not None:
+                torch._dynamo.mark_dynamic(vllm_is_ratio_chunk, 1)

             accumulate_chunk(
                 input_chunk,
@@ -226,6 +238,7 @@ def accumulate_chunk(
                 ref_per_token_logps_chunk,
                 old_per_token_logps_chunk,
                 ref_input_chunk,
+                vllm_is_ratio_chunk,
             )

         # Combine gradients
@@ -270,6 +283,7 @@ def _compute_chunk_loss(
         ref_per_token_logps_chunk=None,
         old_per_token_logps_chunk=None,
         ref_input_chunk=None,
+        vllm_is_ratio_chunk=None,
         ref_weight=None,
         ref_bias=None,
         full_attention_mask=None,
@@ -311,6 +325,7 @@ def _compute_chunk_loss(
             loss_type=loss_type,
             max_completion_length=max_completion_length,
             importance_sampling_level=importance_sampling_level,
+            vllm_is_ratio=vllm_is_ratio_chunk,
         )
         return chunk_loss, chunk_metrics

@@ -363,4 +378,5 @@ def backward(ctx, grad_output, *grad_metrics):
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_chunk_size
+            None,  # grad_vllm_is_ratio
         )
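Everything in this first file is plumbing, and it all follows one pattern: each per-batch tensor, including the new optional IS ratio, is split along dim=0 with torch.chunk so that chunk i of the loss computation sees the matching slice of vllm_is_ratio. A minimal standalone sketch of that pattern (illustrative shapes only, not the library API):

import torch

# Every per-batch tensor is chunked along dim=0; an absent ratio becomes a
# list of Nones so the zip below stays aligned either way.
B, T, H = 8, 16, 32
chunks = 4
inputs = torch.randn(B, T, H)
vllm_is_ratio = torch.rand(B, T) + 0.5  # hypothetical (B, T) correction

input_chunks = torch.chunk(inputs, chunks=chunks, dim=0)
ratio_chunks = (
    torch.chunk(vllm_is_ratio, chunks=chunks, dim=0) if vllm_is_ratio is not None else [None] * chunks
)
for x_chunk, r_chunk in zip(input_chunks, ratio_chunks):
    assert r_chunk is None or r_chunk.shape[0] == x_chunk.shape[0]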
diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py
index 9c985c695..5581f642c 100644
--- a/src/liger_kernel/chunked_loss/grpo_loss.py
+++ b/src/liger_kernel/chunked_loss/grpo_loss.py
@@ -30,8 +30,9 @@ def ppo_loss_fn(
     epsilon_high=0.2,
     beta=0.04,
     loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
-    max_completion_length=None,  # Required for dr_grpo
+    max_completion_length=None,  # Optional for dr_grpo (defaults to sequence length)
     importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
+    vllm_is_ratio=None,  # vLLM importance sampling ratio (chunk_size, seq_len) or None
     **kwargs,
 ):
     """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -71,6 +72,11 @@ def ppo_loss_fn(
     per_token_loss1 = coef_1 * advantages.unsqueeze(1)
     per_token_loss2 = coef_2 * advantages.unsqueeze(1)
     per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+    # Apply vLLM importance sampling correction BEFORE adding KL penalty
+    if vllm_is_ratio is not None:
+        per_token_loss = per_token_loss * vllm_is_ratio
+
     if beta != 0.0:
         # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
         kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
@@ -91,9 +97,8 @@ def ppo_loss_fn(
         loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
     elif loss_type == "dr_grpo":
         # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
-        if max_completion_length is None:
-            raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
-        loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        max_len = max_completion_length if max_completion_length is not None else attention_mask.shape[1]
+        loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_len)
     elif loss_type == "dapo":
         loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
         loss = (per_token_loss * attention_mask).sum() / loss_normalizer
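The ordering in the hunk above is the load-bearing detail: the vLLM ratio rescales only the clipped PPO term, and the KL penalty added afterwards stays unscaled. A minimal PyTorch sketch of that ordering, with coef_1/coef_2 and kl assumed precomputed upstream:

import torch

def ppo_loss_with_vllm_is(coef_1, coef_2, advantages, kl, beta, vllm_is_ratio=None):
    # Clipped PPO surrogate, as in ppo_loss_fn above
    loss1 = coef_1 * advantages.unsqueeze(1)
    loss2 = coef_2 * advantages.unsqueeze(1)
    per_token_loss = -torch.min(loss1, loss2)
    if vllm_is_ratio is not None:  # correction applied BEFORE the KL penalty
        per_token_loss = per_token_loss * vllm_is_ratio
    if beta != 0.0:  # KL term is intentionally left unscaled
        per_token_loss = per_token_loss + beta * kl
    return per_token_loss

B, T = 2, 4
out = ppo_loss_with_vllm_is(
    coef_1=torch.ones(B, T),
    coef_2=torch.ones(B, T),
    advantages=torch.randn(B),
    kl=torch.zeros(B, T),
    beta=0.04,
    vllm_is_ratio=torch.full((B, T), 0.8),
)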
@@ -145,6 +150,7 @@ def forward(
         compiled=True,
         use_ref_model=True,
         chunk_size=1,
+        vllm_is_ratio=None,
     ):
         """
         Fused linear layer with GRPO loss.
@@ -161,12 +167,14 @@ def forward(
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
-            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            max_completion_length (int, optional): Maximum completion length; if None, defaults to sequence length.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
             chunk_size (int): Size of chunks for processing.
+            vllm_is_ratio (torch.Tensor, optional): vLLM importance sampling ratio (batch_size, seq_len) or None.
+                Used to correct for distribution mismatch when using vLLM for generation.
         Returns:
             torch.Tensor: Computed loss
         """
@@ -194,6 +202,7 @@ def forward(
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
             importance_sampling_level=importance_sampling_level,
+            vllm_is_ratio=vllm_is_ratio,
         )

     @staticmethod
@@ -224,6 +233,7 @@ def backward(ctx, grad_output, *grad_metrics):
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_chunk_size
+            None,  # grad_vllm_is_ratio
         )


@@ -252,7 +262,7 @@ def __init__(
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
-            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            max_completion_length (int, optional): Maximum completion length; if None, defaults to sequence length.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
         """
@@ -281,6 +291,7 @@ def forward(
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
+        vllm_is_ratio=None,
     ):
         return LigerFusedLinearGRPOFunction.apply(
             _input,
@@ -304,4 +315,5 @@ def forward(
             self.compiled,
             self.use_ref_model,
             self.chunk_size,
+            vllm_is_ratio,
         )
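A hedged usage sketch of the module-level entry point. The positional arguments before the new trailing vllm_is_ratio keyword follow the pre-existing forward signature (hidden states, lm_head weight, selected token ids, attention mask, advantages), and the constructor arguments shown here are assumptions about defaults, not part of this diff:

import torch
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

B, T, H, V = 2, 16, 64, 128
# beta=0.0 / use_ref_model=False are assumed here so no reference-model inputs are needed
loss_fn = LigerFusedLinearGRPOLoss(beta=0.0, loss_type="grpo", use_ref_model=False)

hidden = torch.randn(B, T, H, requires_grad=True)
lm_head_weight = torch.randn(V, H, requires_grad=True)
token_ids = torch.randint(0, V, (B, T))
mask = torch.ones(B, T)
advantages = torch.randn(B)
vllm_is_ratio = torch.rand(B, T) + 0.5  # per-token correction from vLLM

# Returns the loss together with auxiliary metrics
out = loss_fn(hidden, lm_head_weight, token_ids, mask, advantages, vllm_is_ratio=vllm_is_ratio)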
diff --git a/src/liger_kernel/ops/grpo_loss.py b/src/liger_kernel/ops/grpo_loss.py
index cfc06e0d6..00d4e6424 100644
--- a/src/liger_kernel/ops/grpo_loss.py
+++ b/src/liger_kernel/ops/grpo_loss.py
@@ -75,6 +75,8 @@ def _grpo_loss_fwd_kernel(
     INPUT_IDS,
     COMPLETION_MASK,
     ADVANTAGES,
+    VLLM_IS_RATIO,  # vLLM importance sampling ratio (B, L) or None
+    VLLM_IS_RATIO_STRIDE,  # stride for VLLM_IS_RATIO (L for per-token, 1 for per-sequence)
     LOSS,
     LSE,
     KL,
@@ -132,6 +134,14 @@ def _grpo_loss_fwd_kernel(
     is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
     is_clipped = is_low_clipped | is_high_clipped

+    # Apply vLLM importance sampling correction BEFORE adding KL
+    if VLLM_IS_RATIO is not None:
+        # VLLM_IS_RATIO_STRIDE is L for per-token, 1 for per-sequence
+        vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
+            tl.float32
+        )
+        per_token_loss = per_token_loss * vllm_is_ratio
+
     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
         KL += off_b * L + off_l
@@ -145,11 +155,175 @@
     tl.store(IS_CLIPPED, is_clipped)


-# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
-#                   for BLOCK_N in [2048, 4096, 8192]
-#                   for ns in [1, 2, 4]
-#                   for nw in [1, 2, 4, 8, 16]],
-#                  key=['N'])
+# Sequence-level forward kernel: uses pre-computed coef_1 per sequence
+@triton.jit
+def _grpo_loss_fwd_kernel_seq(
+    LOGITS,
+    REF_LOGP,
+    INPUT_IDS,
+    COMPLETION_MASK,
+    ADVANTAGES,
+    COEF_1,  # Pre-computed sequence-level importance weight (B,)
+    COEF_2,  # Pre-computed clipped coef (B,)
+    IS_CLIPPED_SEQ,  # Pre-computed clipping indicator (B,)
+    VLLM_IS_RATIO,  # vLLM importance sampling ratio (B, L) or (B, 1) or None
+    VLLM_IS_RATIO_STRIDE,  # stride for VLLM_IS_RATIO (L for per-token, 1 for per-sequence)
+    LOSS,
+    LSE,
+    KL,
+    IS_CLIPPED,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    COEF_1 += off_b
+    COEF_2 += off_b
+    IS_CLIPPED_SEQ += off_b
+    LOSS += off_b * L + off_l
+    LSE += off_b * L + off_l
+    IS_CLIPPED += off_b * L + off_l
+
+    # Compute log softmax
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+
+    # Load pre-computed sequence-level coefficients
+    coef_1 = tl.load(COEF_1).to(tl.float32)
+    coef_2 = tl.load(COEF_2).to(tl.float32)
+    is_clipped_seq = tl.load(IS_CLIPPED_SEQ)
+
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
+
+    # Apply vLLM importance sampling correction BEFORE adding KL
+    if VLLM_IS_RATIO is not None:
+        vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
+            tl.float32
+        )
+        per_token_loss = per_token_loss * vllm_is_ratio
+
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        KL += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
+        per_token_loss += BETA * kl
+        tl.store(KL, kl)
+
+    tl.store(LOSS, per_token_loss)
+    tl.store(LSE, lse)
+    tl.store(IS_CLIPPED, is_clipped_seq)  # Same for all tokens in sequence
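The single VLLM_IS_RATIO_STRIDE argument is what lets one kernel accept either a (B, L) or a (B, 1) ratio without a second code path: with stride L, off_b * L + off_l % L addresses element (b, l); with stride 1, off_b * 1 + off_l % 1 always addresses the sequence's single entry. A pure-Python model of that flat-index trick, for intuition only:

def ratio_index(off_b, off_l, stride):
    # stride == L -> per-token ratio at the flat offset of (b, l)
    # stride == 1 -> per-sequence ratio at the flat offset of (b, 0)
    return off_b * stride + off_l % stride

L = 4
assert ratio_index(2, 3, stride=L) == 2 * L + 3  # element (2, 3) of a (B, L) buffer
assert ratio_index(2, 3, stride=1) == 2          # element (2, 0) of a (B, 1) buffer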
+
+
+# Sequence-level backward kernel
+@triton.jit
+def _grpo_loss_bwd_kernel_seq(
+    DLOSS,
+    DLOSS_SUM,
+    DLOGITS,
+    LOGITS,
+    REF_LOGP,
+    INPUT_IDS,
+    ADVANTAGES,
+    COMPLETION_MASK,
+    LSE,
+    COEF_1,  # Pre-computed sequence-level importance weight (B,)
+    SEQ_LEN,  # Number of valid tokens per sequence (B,)
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    loss_stride0,
+    loss_stride1,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    DLOGITS += off_b * (L + 1) * N + off_l * N
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            for start in range(0, N, BLOCK_N):
+                cols = tl.arange(0, BLOCK_N) + start
+                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
+    DLOSS_SUM += off_b
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LSE += off_b * L + off_l
+    COEF_1 += off_b
+    SEQ_LEN += off_b
+
+    dloss = tl.load(DLOSS).to(tl.float32)
+    dloss_sum = tl.load(DLOSS_SUM).to(tl.float32)
+    lse = tl.load(LSE).to(tl.float32)
+    coef_1 = tl.load(COEF_1).to(tl.float32)
+    seq_len = tl.load(SEQ_LEN).to(tl.float32)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    is_unclipped = per_token_loss2 >= per_token_loss1
+
+    # For sequence-level: gradient flows through mean, so scale by coef_1/seq_len
+    # d(loss)/d(logp) = -advantage * coef_1 / seq_len (when unclipped)
+    dlogp = -per_token_loss1 / seq_len * is_unclipped * dloss_sum
+
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        dlogp += BETA * (1 - tl.exp(ref_logp - logp)) * dloss
+
+    dlogp = dlogp / TEMPERATURE
+    tl.debug_barrier()
+    for start_n in tl.range(0, N, BLOCK_N):
+        cols = start_n + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
+        probs = tl.exp(logits - lse)
+        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
+        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
+
+
 @triton.jit
 def _grpo_loss_bwd_kernel(
     DLOSS,
@@ -161,6 +335,8 @@ def _grpo_loss_bwd_kernel(
     ADVANTAGES,
     COMPLETION_MASK,
     LSE,
+    VLLM_IS_RATIO,
+    VLLM_IS_RATIO_STRIDE,
     TEMPERATURE,
     BETA: tl.constexpr,
     EPS_LOW,
@@ -209,6 +385,14 @@ def _grpo_loss_bwd_kernel(
     mask = per_token_loss2 >= per_token_loss1
     dlogp = -per_token_loss1 * mask

+    # Apply vLLM IS ratio to PPO gradient (before KL gradient)
+    if VLLM_IS_RATIO is not None:
+        vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
+            tl.float32
+        )
+        dlogp = dlogp * vllm_is_ratio
+
     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
         ref_logp = tl.load(REF_LOGP).to(tl.float32)
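Why the sequence-level backward scales by coef_1 / seq_len: with s_b the masked mean of per-token log-ratios, coef_1 = exp(s_b), and the unclipped sequence loss -A_b * coef_1, the chain rule gives d(loss_b)/d(logp_t) = -A_b * coef_1 / |seq_b| for every valid token t, which is exactly the kernel's -per_token_loss1 / seq_len term. A small autograd cross-check of that identity (a sanity sketch, not one of the PR's tests):

import torch

torch.manual_seed(0)
T = 5
logp = torch.randn(T, requires_grad=True)
old_logp = torch.randn(T)
advantage = torch.tensor(0.7)

coef_1 = torch.exp((logp - old_logp).mean())  # sequence-level IS weight
loss = -(advantage * coef_1)                  # unclipped branch of the objective
loss.backward()

expected = -(advantage * coef_1.detach()) / T  # identical for every token
assert torch.allclose(logp.grad, expected.expand(T))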
@@ -224,6 +408,32 @@ def _grpo_loss_bwd_kernel(
     tl.store(DLOGITS + cols, dlogits, mask=cols < N)


+def _compute_dapo_normalizer(completion_mask):
+    """Global active tokens averaged per process (for distributed DAPO loss)."""
+    normalizer = completion_mask.to(torch.float32).sum()
+    world_size = 1
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        normalizer = normalizer.clone()
+        torch.distributed.all_reduce(normalizer, op=torch.distributed.ReduceOp.SUM)
+        world_size = torch.distributed.get_world_size()
+    normalizer = normalizer / world_size
+    return torch.clamp(normalizer, min=1.0)
+
+
+def _reduce_loss(per_token_loss, mask, loss_type, max_completion_length, B, L):
+    """Apply loss reduction based on loss_type."""
+    if loss_type == "grpo":
+        return ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
+    elif loss_type == "bnpo":
+        return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
+    elif loss_type == "dr_grpo":
+        max_len = max_completion_length if max_completion_length is not None else L
+        return (per_token_loss * mask).sum() / (B * max_len)
+    elif loss_type == "dapo":
+        return (per_token_loss * mask).sum() / _compute_dapo_normalizer(mask)
+    raise ValueError(f"Unknown loss_type: {loss_type}. Expected one of: grpo, bnpo, dr_grpo, dapo")
+
+
 class GrpoLossFunction(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -239,10 +449,18 @@ def forward(
         eps_low,
         eps_high,
         inplace,
+        loss_type="grpo",
+        max_completion_length=None,
+        reduce=True,
+        importance_sampling_level="token",
+        vllm_is_ratio=None,  # vLLM importance sampling ratio (B, L) or (B, 1) or None
     ):
         assert logits.is_contiguous() and completion_ids.is_contiguous()
         assert old_logp is None or old_logp.is_contiguous()
         assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
+        assert importance_sampling_level in ("token", "sequence"), (
+            f"importance_sampling_level must be 'token' or 'sequence', got {importance_sampling_level}"
+        )

         B, L_ADD_1, N = logits.shape
         L = L_ADD_1 - 1
@@ -250,63 +468,252 @@ def forward(
         if completion_mask is not None:
             assert completion_mask.is_contiguous()

+        mask = completion_mask.float() if completion_mask is not None else torch.ones(B, L, device=logits.device)
+
+        # Handle vLLM IS ratio
+        vllm_is_ratio_ptr = None
+        vllm_is_ratio_stride = L  # default to per-token
+        if vllm_is_ratio is not None:
+            vllm_is_ratio = vllm_is_ratio.contiguous()
+            vllm_is_ratio_ptr = vllm_is_ratio
+            # Determine stride: L for per-token (B, L), 1 for per-sequence (B, 1)
+            vllm_is_ratio_stride = vllm_is_ratio.shape[1] if vllm_is_ratio.dim() > 1 else 1
+
+        # Allocate outputs
         loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
         lse = torch.zeros_like(loss)
         is_clipped = torch.zeros_like(loss)
         kl = torch.zeros_like(loss) if beta != 0.0 else None
-        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
-        _grpo_loss_fwd_kernel[(B, L)](
-            logits,
-            old_logp,
-            ref_logp,
-            completion_ids,
-            completion_mask,
-            advantages,
-            loss,
-            lse,
-            kl,
-            is_clipped,
-            temperature,
-            beta,
-            eps_low,
-            eps_high,
-            L,
-            N,
-            **kwargs,
-        )
-        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
-        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
-        # return loss
-        return loss, kl, is_clipped
+
+        if importance_sampling_level == "sequence":
+            # Sequence-level: pre-compute sequence importance weights, then use Triton kernel
+            # Step 1: Get per-token log probs using existing Triton kernel
+            per_token_logps = fused_selective_log_softmax(logits, completion_ids, temperature, completion_mask)
+
+            # Step 2: Compute sequence-level importance weights
+            if old_logp is None:
+                log_ratio = torch.zeros_like(per_token_logps)
+            else:
+                log_ratio = per_token_logps - old_logp
+
+            seq_lens = mask.sum(-1).clamp(min=1.0)  # (B,)
+            seq_log_importance = (log_ratio * mask).sum(-1) / seq_lens  # (B,)
+            coef_1 = torch.exp(seq_log_importance)  # (B,)
+            coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)  # (B,)
+
+            # Compute is_clipped at sequence level
+            is_clipped_seq = ((coef_1 < 1 - eps_low) & (advantages < 0)) | ((coef_1 > 1 + eps_high) & (advantages > 0))
+            is_clipped_seq = is_clipped_seq.float()  # (B,)
+
+            # Step 3: Run Triton kernel with pre-computed coefficients
+            kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
+            _grpo_loss_fwd_kernel_seq[(B, L)](
+                logits,
+                ref_logp,
+                completion_ids,
+                completion_mask,
+                advantages,
+                coef_1.contiguous(),
+                coef_2.contiguous(),
+                is_clipped_seq.contiguous(),
+                vllm_is_ratio_ptr,
+                vllm_is_ratio_stride,
+                loss,
+                lse,
+                kl,
+                is_clipped,
+                temperature,
+                beta,
+                L,
+                N,
+                **kwargs,
+            )
+
+            # Save extra tensors for backward
+            ctx.save_for_backward(
+                logits,
+                old_logp,
+                ref_logp,
+                completion_ids,
+                advantages,
+                completion_mask,
+                lse,
+                mask,
+                coef_1,
+                seq_lens,
+                vllm_is_ratio_ptr,
+            )
+        else:
+            # Token-level: use optimized Triton kernel
+            kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
+            _grpo_loss_fwd_kernel[(B, L)](
+                logits,
+                old_logp,
+                ref_logp,
+                completion_ids,
+                completion_mask,
+                advantages,
+                vllm_is_ratio_ptr,
+                vllm_is_ratio_stride,
+                loss,
+                lse,
+                kl,
+                is_clipped,
+                temperature,
+                beta,
+                eps_low,
+                eps_high,
+                L,
+                N,
+                **kwargs,
+            )
+            ctx.save_for_backward(
+                logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio_ptr
+            )
+
+        ctx.infos = (
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            inplace,
+            loss_type,
+            max_completion_length,
+            B,
+            L,
+            importance_sampling_level,
+            vllm_is_ratio_stride,
+            reduce,
+        )
+
+        # Compute metrics before reduction
+        mask_sum = mask.sum().clamp(min=1.0)
+        kl_mean = (kl * mask).sum() / mask_sum if kl is not None else None
+        clip_ratio = (is_clipped.float() * mask).sum() / mask_sum
+
+        if not reduce:
+            loss_out = loss * mask
+            kl_out = kl * mask if kl is not None else None
+            is_clipped_out = is_clipped * mask
+            return loss_out, kl_out, is_clipped_out
+
+        reduced_loss = _reduce_loss(loss, mask, loss_type, max_completion_length, B, L)
+        return reduced_loss, kl_mean, clip_ratio
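The four reductions in _reduce_loss differ only in the denominator. A toy comparison on one masked batch; a single process is assumed, so the distributed DAPO normalizer degenerates to the plain active-token count:

import torch

per_token_loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])
B, L = per_token_loss.shape

grpo = ((per_token_loss * mask).sum(-1) / mask.sum(-1)).mean()  # per-sequence mean, then batch mean
bnpo = (per_token_loss * mask).sum() / mask.sum()               # global token mean
dr_grpo = (per_token_loss * mask).sum() / (B * L)               # fixed B * max_len denominator
dapo = (per_token_loss * mask).sum() / mask.sum()               # == bnpo when world_size == 1

print(grpo.item(), bnpo.item(), dr_grpo.item(), dapo.item())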

     @staticmethod
     def backward(ctx, *args):
-        dloss = args[0]
-        # print(dloss.shape)
-        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
-        temperature, beta, eps_low, eps_high, inplace = ctx.infos
-        B, L_ADD_1, N = logits.shape
-        L = L_ADD_1 - 1
-        dlogits = logits.data if inplace else torch.empty_like(logits)
-        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
-        _grpo_loss_bwd_kernel[(B, L)](
-            dloss,
-            dlogits,
-            logits,
-            old_logp,
-            ref_logp,
-            completion_ids,
-            advantages,
-            completion_mask,
-            lse,
-            temperature,
-            beta,
-            eps_low,
-            eps_high,
-            *dloss.stride(),
-            L,
-            N,
-            **kwargs,
-        )
+        dloss_input = args[0]
+        saved_tensors = ctx.saved_tensors
+        (
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            inplace,
+            loss_type,
+            max_completion_length,
+            B,
+            L,
+            importance_sampling_level,
+            vllm_is_ratio_stride,
+            reduce,
+        ) = ctx.infos
+
+        if importance_sampling_level == "sequence":
+            (
+                logits,
+                old_logp,
+                ref_logp,
+                completion_ids,
+                advantages,
+                completion_mask,
+                lse,
+                mask,
+                coef_1,
+                seq_lens,
+                vllm_is_ratio,
+            ) = saved_tensors
+        else:
+            (logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse, mask, vllm_is_ratio) = (
+                saved_tensors
+            )
+
+        _, L_ADD_1, N = logits.shape
+
+        # Compute per-token gradient scaling based on loss_type
+        if not reduce:
+            dloss = dloss_input
+        elif loss_type == "grpo":
+            seq_lens_bwd = mask.sum(-1, keepdim=True).clamp(min=1.0)
+            dloss = dloss_input * mask / (seq_lens_bwd * B)
+        elif loss_type == "bnpo":
+            dloss = dloss_input * mask / mask.sum().clamp(min=1.0)
+        elif loss_type == "dr_grpo":
+            max_len = max_completion_length if max_completion_length is not None else L
+            dloss = dloss_input * mask / (B * max_len)
+        elif loss_type == "dapo":
+            dloss = dloss_input * mask / _compute_dapo_normalizer(mask)
+        else:
+            raise ValueError(f"Unknown loss_type: {loss_type}")
+
+        dlogits = logits.data if inplace else torch.empty_like(logits)
+        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
+
+        if importance_sampling_level == "sequence":
+            if vllm_is_ratio is None:
+                dloss_sum = dloss.sum(-1).contiguous()
+            else:
+                if vllm_is_ratio.dim() == 1:
+                    ratio = vllm_is_ratio.unsqueeze(-1)
+                else:
+                    ratio = vllm_is_ratio
+                dloss_sum = (dloss * ratio).sum(-1).contiguous()
+            # Sequence-level backward kernel
+            _grpo_loss_bwd_kernel_seq[(B, L)](
+                dloss,
+                dloss_sum,
+                dlogits,
+                logits,
+                ref_logp,
+                completion_ids,
+                advantages,
+                completion_mask,
+                lse,
+                coef_1,
+                seq_lens,
+                temperature,
+                beta,
+                eps_low,
+                eps_high,
+                *dloss.stride(),
+                L,
+                N,
+                **kwargs,
+            )
+        else:
+            # Token-level backward kernel
+            _grpo_loss_bwd_kernel[(B, L)](
+                dloss,
+                dlogits,
+                logits,
+                old_logp,
+                ref_logp,
+                completion_ids,
+                advantages,
+                completion_mask,
+                lse,
+                vllm_is_ratio,
+                vllm_is_ratio_stride,
+                temperature,
+                beta,
+                eps_low,
+                eps_high,
+                *dloss.stride(),
+                L,
+                N,
+                **kwargs,
+            )
+
         dlogits[:, -1, :] = 0
-        return dlogits, None, None, None, None, None, None, None, None, None, None
+        # Return gradients for all forward inputs: dlogits + 15 None for non-differentiable params
+        return dlogits, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
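Because both kernels emit per-token gradients, backward folds the reduction's chain rule into dloss before the launch; for bnpo, for example, the derivative of the reduced loss with respect to each per-token loss is mask / mask.sum(), exactly the scaling above. A quick autograd cross-check (illustrative, not from the PR):

import torch

per_token_loss = torch.randn(2, 4, requires_grad=True)
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]])

reduced = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
reduced.backward()  # dloss_input == 1 for a scalar loss

manual = mask / mask.sum().clamp(min=1.0)  # the bnpo scaling from backward above
assert torch.allclose(per_token_loss.grad, manual)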
diff --git a/src/liger_kernel/transformers/grpo_loss.py b/src/liger_kernel/transformers/grpo_loss.py
index 7c37acbb0..a5244b6b6 100644
--- a/src/liger_kernel/transformers/grpo_loss.py
+++ b/src/liger_kernel/transformers/grpo_loss.py
@@ -20,16 +20,43 @@ def triton_grpo_loss(
     max_completion_length=None,
     importance_sampling_level="token",
     reduce=False,
+    vllm_is_ratio=None,
 ):
+    """
+    Triton-optimized GRPO loss function.
+
+    Args:
+        logits: Model logits (B, L+1, V)
+        old_logp: Old policy log probabilities (B, L) or None
+        ref_logp: Reference model log probabilities (B, L) or None (required if beta != 0)
+        completion_ids: Token IDs for completions (B, L)
+        advantages: Per-sequence advantages (B,)
+        completion_mask: Mask for valid tokens (B, L) or None
+        temperature: Temperature for log softmax
+        beta: KL penalty coefficient
+        eps_low: Lower clipping bound for importance ratio
+        eps_high: Upper clipping bound for importance ratio
+        inplace: Whether to modify logits in-place during backward
+        loss_type: Loss reduction type ("grpo", "bnpo", "dr_grpo", "dapo")
+        max_completion_length: Max completion length for dr_grpo loss type; defaults to sequence length if None
+        importance_sampling_level: "token" or "sequence" importance sampling
+        reduce: If True, return reduced loss; if False, return per-token loss
+        vllm_is_ratio: vLLM importance sampling ratio (B, L) or (B, 1) or None.
+            Used to correct for distribution mismatch when using vLLM for generation.
+            Applied to PPO loss BEFORE adding KL penalty.
+
+    Returns:
+        If reduce=True: (loss, metrics) where metrics = [kl_mean, clip_ratio] or [clip_ratio]
+        If reduce=False: (per_token_loss, per_token_kl, is_clipped)
+    """
     assert logits is not None and completion_ids is not None and advantages is not None, (
-        "must provide logits、completion_ids and advantages"
+        "must provide logits, completion_ids and advantages"
     )
-    if importance_sampling_level != "token":
-        raise ValueError(
-            f"Triton GRPO loss only supports token-level importance sampling. Got {importance_sampling_level}."
-        )
+    assert importance_sampling_level in ("token", "sequence"), (
+        f"importance_sampling_level must be 'token' or 'sequence', got {importance_sampling_level}"
+    )

-    per_token_loss, per_token_kl, is_clipped = GrpoLossFunction.apply(
+    result = GrpoLossFunction.apply(
         logits,
         old_logp,
         ref_logp,
@@ -41,22 +68,24 @@ def triton_grpo_loss(
         eps_low,
         eps_high,
         inplace,
+        loss_type,
+        max_completion_length,
+        reduce,
+        importance_sampling_level,
+        vllm_is_ratio,
     )

-    if not reduce:
-        return per_token_loss, per_token_kl, is_clipped
-
-    loss = _reduce_grpo_loss(
-        per_token_loss,
-        completion_mask,
-        loss_type=loss_type,
-        max_completion_length=max_completion_length,
-    )
+    if not reduce:
+        # Returns (per_token_loss, per_token_kl, is_clipped) - all (B, L) tensors
+        return result

+    # reduce=True: Returns (reduced_loss, kl_mean, clip_ratio) - all scalars
+    reduced_loss, kl_mean, clip_ratio = result
     metrics = []
-    if beta != 0.0 and per_token_kl is not None:
-        metrics.append(_masked_mean(per_token_kl, completion_mask))
-    metrics.append(_masked_mean(is_clipped.float(), completion_mask))
-    return loss, metrics
+    if beta != 0.0 and kl_mean is not None:
+        metrics.append(kl_mean)
+    metrics.append(clip_ratio)
+    return reduced_loss, metrics


@@ -71,10 +100,9 @@ def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
     if loss_type == "bnpo":
         return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
     if loss_type == "dr_grpo":
-        if max_completion_length is None:
-            raise ValueError("max_completion_length must be provided when using loss_type='dr_grpo'")
         batch = per_token_loss.shape[0]
-        return (per_token_loss * mask).sum() / (batch * max_completion_length)
+        max_len = max_completion_length if max_completion_length is not None else per_token_loss.shape[1]
+        return (per_token_loss * mask).sum() / (batch * max_len)
     if loss_type == "dapo":
         normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask)
         return (per_token_loss * mask).sum() / normalizer
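With the reduction moved inside GrpoLossFunction, the wrapper now has two return contracts depending on reduce. A hedged sketch of both call styles (a CUDA device and Triton are assumed available; keyword defaults not shown here are taken from the signature above):

import torch
from liger_kernel.transformers.grpo_loss import triton_grpo_loss

B, T, V = 2, 8, 64
logits = torch.randn(B, T + 1, V, device="cuda", requires_grad=True)
ids = torch.randint(0, V, (B, T), device="cuda")
adv = torch.randn(B, device="cuda")

# reduce=False -> per-token tensors (loss, kl, is_clipped), each (B, T)
per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
    logits, None, None, ids, adv, None, beta=0.0, reduce=False
)

# reduce=True -> (scalar loss, metrics); metrics == [clip_ratio] when beta == 0
loss, metrics = triton_grpo_loss(
    logits, None, None, ids, adv, None, beta=0.0, loss_type="grpo", reduce=True
)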
@@ -88,11 +116,12 @@ def _masked_mean(values, mask):
     return (values * mask).sum() / mask.sum().clamp(min=1.0)


-# This is a demo how to use grpo_loss in GRPOTrainer. The Trl version must be 0.16
+# This is a demo of how to use grpo_loss in GRPOTrainer. Requires trl version 0.26.2+.
 """
 import torch
 import trl
-assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
+from packaging.version import Version
+assert Version(trl.__version__) >= Version("0.26.2"), "please pip install trl>=0.26.2"
 from trl.extras.profiling import profiling_decorator

 @profiling_decorator
@@ -117,18 +146,24 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
     ref_per_token_logps = inputs["ref_per_token_logps"]
     advantages = inputs["advantages"]
     old_per_token_logps = inputs["old_per_token_logps"]
-
-
-    per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits,
-                                                                old_per_token_logps,
-                                                                ref_per_token_logps,
-                                                                completion_ids,
-                                                                advantages,
-                                                                completion_mask,
-                                                                self.temperature,
-                                                                self.beta,
-                                                                self.epsilon_low,
-                                                                self.epsilon_high,)
+
+    # Get vLLM importance sampling ratio if using vLLM with importance sampling correction
+    vllm_is_ratio = inputs.get("importance_sampling_ratio", None)
+
+    per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
+        logits,
+        old_per_token_logps,
+        ref_per_token_logps,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature=self.temperature,
+        beta=self.beta,
+        eps_low=self.epsilon_low,
+        eps_high=self.epsilon_high,
+        importance_sampling_level=self.importance_sampling_level,  # "token" or "sequence"
+        vllm_is_ratio=vllm_is_ratio,  # vLLM distribution correction
+    )

     loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

     # Log the metrics
diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py
index 8259a22db..f272d1cb8 100644
--- a/test/chunked_loss/test_grpo_loss.py
+++ b/test/chunked_loss/test_grpo_loss.py
@@ -4,6 +4,7 @@
 from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
 from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo
+from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
 from liger_kernel.transformers.grpo_loss import _reduce_grpo_loss
 from liger_kernel.transformers.grpo_loss import triton_grpo_loss
@@ -158,7 +159,7 @@ def forward(
         elif self.loss_type == "dr_grpo":
             loss = (per_token_loss * attention_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
         elif self.loss_type == "dapo":
-            normalizer = attention_mask.sum().clamp(min=1.0)
+            normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(attention_mask)
             loss = (per_token_loss * attention_mask).sum() / normalizer
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}")
@@ -574,8 +575,9 @@ def test_reduce_grpo_loss_matches_reference(loss_type):
 def test_reduce_grpo_loss_requires_max_completion_length():
     per_token_loss = torch.randn(2, 3)
     mask = torch.ones_like(per_token_loss, dtype=torch.long)
-    with pytest.raises(ValueError):
-        _reduce_grpo_loss(per_token_loss, mask, "dr_grpo", max_completion_length=None)
+    reduced = _reduce_grpo_loss(per_token_loss, mask, "dr_grpo", max_completion_length=None)
+    expected = (per_token_loss * mask).sum() / (per_token_loss.size(0) * per_token_loss.size(1))
+    assert_verbose_allclose(reduced, expected)


 @pytest.mark.parametrize("loss_type,beta", [("bnpo", 0.0), ("dapo", 0.04)])
diff --git a/test/transformers/test_grpo_loss.py b/test/transformers/test_grpo_loss.py
index ae9d265ab..c0ca18bf4 100644
--- a/test/transformers/test_grpo_loss.py
+++ b/test/transformers/test_grpo_loss.py
@@ -188,3 +188,379 @@ def test_grpo_loss(B, T, V, temperature, num_iteration, beta, eps_low, eps_high,
     compare(loss2, loss3, "per_token_loss: triton-bf16 vs torch-fp32, ")
     compare(kl2, kl3, "per_token_kl: triton-bf16 vs torch-fp32, ")
     compare(logits2.grad, logits3.grad, "logits.grad: triton-bf16 vs torch-fp32, ")
+
+
+def trl_reference_grpo_loss(
+    logits,
+    old_logp,
+    ref_logp,
+    completion_ids,
+    advantages,
+    completion_mask,
+    temperature,
+    beta,
+    eps_low,
+    eps_high,
+    loss_type,
+    importance_sampling_level,
+):
+    """TRL reference implementation from grpo_trainer.py"""
+    B, L_ADD_1, V = logits.shape
+    L = L_ADD_1 - 1
+
+    logits_scaled = logits[:, :-1, :] / temperature
+    log_probs = torch.log_softmax(logits_scaled.float(), dim=-1)
+    per_token_logps = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
+
+    if old_logp is None:
+        old_logp = per_token_logps.detach()
+
+    log_ratio = per_token_logps - old_logp
+
+    if importance_sampling_level == "token":
+        log_importance_weights = log_ratio
+    else:  # sequence
+        log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+        log_importance_weights = log_importance_weights.unsqueeze(-1)
+
+    coef_1 = torch.exp(log_importance_weights)
+    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
+
+    per_token_loss1 = coef_1 * advantages.unsqueeze(-1)
+    per_token_loss2 = coef_2 * advantages.unsqueeze(-1)
+    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+    if importance_sampling_level == "sequence":
+        per_token_loss = per_token_loss.expand(B, L)
+
+    if beta != 0.0:
+        kl = torch.exp(ref_logp - per_token_logps) - (ref_logp - per_token_logps) - 1.0
+        per_token_loss = per_token_loss + beta * kl
+
+    # Loss reduction
+    if loss_type == "grpo":
+        loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+    elif loss_type == "bnpo":
+        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+    elif loss_type == "dr_grpo":
+        loss = (per_token_loss * completion_mask).sum() / (B * L)
+    elif loss_type == "dapo":
+        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+
+    return loss
+
+
+@pytest.mark.parametrize("importance_sampling_level", ["token", "sequence"])
+@pytest.mark.parametrize("loss_type", ["grpo", "bnpo", "dr_grpo", "dapo"])
+@pytest.mark.parametrize("beta", [0.0, 0.04])
+@pytest.mark.parametrize(
+    "B, T, V",
+    [
+        (2, 128, 1000),
+    ],
+)
+def test_grpo_loss_vs_trl(B, T, V, beta, loss_type, importance_sampling_level):
+    """Test that triton_grpo_loss matches TRL's exact implementation."""
+    torch.manual_seed(42)
+
+    logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32)
+    completion_ids = torch.randint(0, V, (B, T), device=device)
+    completion_mask = torch.ones(B, T, device=device, dtype=torch.float32)
+    advantages = torch.randn(B, device=device, dtype=torch.float32)
+
+    # Compute realistic old_logp and ref_logp
+    with torch.no_grad():
+        log_probs = torch.log_softmax(logits[:, :-1, :] / 0.9, dim=-1)
+        current_logp = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
+        old_logp = current_logp + torch.randn_like(current_logp) * 0.3
+        ref_logp = current_logp + torch.randn_like(current_logp) * 0.2 if beta != 0.0 else None
+
+    temperature = 0.9
+    eps_low, eps_high = 0.2, 0.4
+
+    # TRL reference
+    trl_loss = trl_reference_grpo_loss(
+        logits.clone(),
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        loss_type,
+        importance_sampling_level,
+    )
+
+    # Triton implementation
+    logits_triton = logits.clone().requires_grad_(True)
+    triton_loss, _ = triton_grpo_loss(
+        logits_triton,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature=temperature,
+        beta=beta,
+        eps_low=eps_low,
+        eps_high=eps_high,
+        importance_sampling_level=importance_sampling_level,
+        loss_type=loss_type,
+        max_completion_length=T,
+        reduce=True,
+    )
+
+    # Verify forward match
+    torch.testing.assert_close(triton_loss, trl_loss, rtol=1e-4, atol=1e-4)
+
+    # Verify backward works
+    triton_loss.backward()
+    assert logits_triton.grad is not None
+    assert not torch.isnan(logits_triton.grad).any()
+
+
+def trl_reference_grpo_loss_with_vllm_is(
+    logits,
+    old_logp,
+    ref_logp,
+    completion_ids,
+    advantages,
+    completion_mask,
+    temperature,
+    beta,
+    eps_low,
+    eps_high,
+    loss_type,
+    importance_sampling_level,
+    vllm_is_ratio,
+):
+    """TRL reference implementation with vLLM IS ratio correction."""
+    B, L_ADD_1, V = logits.shape
+    L = L_ADD_1 - 1
+
+    logits_scaled = logits[:, :-1, :] / temperature
+    log_probs = torch.log_softmax(logits_scaled.float(), dim=-1)
+    per_token_logps = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
+
+    if old_logp is None:
+        old_logp = per_token_logps.detach()
+
+    log_ratio = per_token_logps - old_logp
+
+    if importance_sampling_level == "token":
+        log_importance_weights = log_ratio
+    else:  # sequence
+        log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+        log_importance_weights = log_importance_weights.unsqueeze(-1)
+
+    coef_1 = torch.exp(log_importance_weights)
+    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
+
+    per_token_loss1 = coef_1 * advantages.unsqueeze(-1)
+    per_token_loss2 = coef_2 * advantages.unsqueeze(-1)
+    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+    if importance_sampling_level == "sequence":
+        per_token_loss = per_token_loss.expand(B, L)
+
+    # Apply vLLM IS ratio BEFORE KL penalty (matches TRL)
+    if vllm_is_ratio is not None:
+        per_token_loss = per_token_loss * vllm_is_ratio
+
+    if beta != 0.0:
+        kl = torch.exp(ref_logp - per_token_logps) - (ref_logp - per_token_logps) - 1.0
+        per_token_loss = per_token_loss + beta * kl
+
+    # Loss reduction
+    if loss_type == "grpo":
+        loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+    elif loss_type == "bnpo":
+        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+    elif loss_type == "dr_grpo":
+        loss = (per_token_loss * completion_mask).sum() / (B * L)
+    elif loss_type == "dapo":
+        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+
+    return loss
+
+
+@pytest.mark.parametrize("importance_sampling_level", ["token", "sequence"])
+@pytest.mark.parametrize("loss_type", ["grpo", "dapo"])
+@pytest.mark.parametrize("beta", [0.0, 0.04])
+@pytest.mark.parametrize(
+    "B, T, V",
+    [
+        (2, 128, 1000),
+    ],
+)
+def test_grpo_loss_with_vllm_is_ratio(B, T, V, beta, loss_type, importance_sampling_level):
+    """Test that triton_grpo_loss with vllm_is_ratio matches TRL's behavior."""
+    torch.manual_seed(42)
+
+    logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32)
+    completion_ids = torch.randint(0, V, (B, T), device=device)
+    completion_mask = torch.ones(B, T, device=device, dtype=torch.float32)
+    advantages = torch.randn(B, device=device, dtype=torch.float32)
+
+    # Compute realistic old_logp and ref_logp
+    with torch.no_grad():
+        log_probs = torch.log_softmax(logits[:, :-1, :] / 0.9, dim=-1)
+        current_logp = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
+        old_logp = current_logp + torch.randn_like(current_logp) * 0.3
+        ref_logp = current_logp + torch.randn_like(current_logp) * 0.2 if beta != 0.0 else None
+
+    # Create vLLM IS ratio (random values between 0.5 and 1.5)
+    vllm_is_ratio = torch.rand(B, T, device=device, dtype=torch.float32) + 0.5
+
+    temperature = 0.9
+    eps_low, eps_high = 0.2, 0.4
+
+    # TRL reference with vLLM IS ratio
+    trl_loss = trl_reference_grpo_loss_with_vllm_is(
+        logits.clone(),
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        loss_type,
+        importance_sampling_level,
+        vllm_is_ratio,
+    )
+
+    # Triton implementation with vLLM IS ratio
+    logits_triton = logits.clone().requires_grad_(True)
+    triton_loss, _ = triton_grpo_loss(
+        logits_triton,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature=temperature,
+        beta=beta,
+        eps_low=eps_low,
+        eps_high=eps_high,
+        importance_sampling_level=importance_sampling_level,
+        loss_type=loss_type,
+        max_completion_length=T,
+        reduce=True,
+        vllm_is_ratio=vllm_is_ratio,
+    )
+
+    # Verify forward match
+    torch.testing.assert_close(triton_loss, trl_loss, rtol=1e-4, atol=1e-4)
+
+    # Verify backward works
+    triton_loss.backward()
+    assert logits_triton.grad is not None
+    assert not torch.isnan(logits_triton.grad).any()
+
+    # Also verify that vllm_is_ratio=None gives same result as vllm_is_ratio=1
+    logits_no_ratio = logits.clone().requires_grad_(True)
+    loss_no_ratio, _ = triton_grpo_loss(
+        logits_no_ratio,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature=temperature,
+        beta=beta,
+        eps_low=eps_low,
+        eps_high=eps_high,
+        importance_sampling_level=importance_sampling_level,
+        loss_type=loss_type,
+        max_completion_length=T,
+        reduce=True,
+        vllm_is_ratio=None,
+    )
+
+    logits_ones_ratio = logits.clone().requires_grad_(True)
+    loss_ones_ratio, _ = triton_grpo_loss(
+        logits_ones_ratio,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature=temperature,
+        beta=beta,
+        eps_low=eps_low,
+        eps_high=eps_high,
+        importance_sampling_level=importance_sampling_level,
+        loss_type=loss_type,
+        max_completion_length=T,
+        reduce=True,
+        vllm_is_ratio=torch.ones(B, T, device=device),
+    )
+
+    torch.testing.assert_close(loss_no_ratio, loss_ones_ratio, rtol=1e-5, atol=1e-5)
+
+
+@pytest.mark.parametrize("beta", [0.0, 0.04])
+def test_grpo_loss_sequence_backward_matches_reference(beta):
+    """Sequence-level importance sampling should match reference gradients."""
+    pytest.importorskip("triton")
+    torch.manual_seed(0)
+
+    B, T, V = 2, 8, 32
+    logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32)
+    completion_ids = torch.randint(0, V, (B, T), device=device)
+    completion_mask = torch.ones(B, T, device=device, dtype=torch.float32)
+    advantages = torch.randn(B, device=device, dtype=torch.float32)
+
+    with torch.no_grad():
+        log_probs = torch.log_softmax(logits[:, :-1, :] / 1.1, dim=-1)
+        current_logp = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
+        old_logp = current_logp + torch.randn_like(current_logp) * 0.2
+        ref_logp = current_logp + torch.randn_like(current_logp) * 0.1 if beta != 0.0 else None
+
+    temperature = 1.1
+    eps_low, eps_high = 0.2, 0.4
+
+    logits_triton = logits.clone().requires_grad_(True)
+    triton_loss, _ = triton_grpo_loss(
+        logits_triton,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature=temperature,
+        beta=beta,
+        eps_low=eps_low,
+        eps_high=eps_high,
+        importance_sampling_level="sequence",
+        loss_type="grpo",
+        max_completion_length=T,
+        reduce=True,
+    )
+    triton_loss.backward()
+
+    logits_ref = logits.clone().requires_grad_(True)
+    reference_loss = trl_reference_grpo_loss(
+        logits_ref,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        loss_type="grpo",
+        importance_sampling_level="sequence",
+    )
+    reference_loss.backward()
+
+    torch.testing.assert_close(triton_loss, reference_loss, rtol=1e-5, atol=1e-5)
+    torch.testing.assert_close(logits_triton.grad, logits_ref.grad, rtol=1e-4, atol=1e-4)
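For completeness, a hedged sketch of where a vllm_is_ratio tensor might come from in practice: a truncated importance-sampling ratio between the trainer policy's log-probs and the log-probs vLLM assigned at sampling time. The helper name, shapes, and clamp value are illustrative; the exact construction, and the inputs["importance_sampling_ratio"] key used in the demo above, depend on the trl version:

import torch

def make_vllm_is_ratio(trainer_logps, vllm_logps, max_ratio=2.0):
    # exp(logp_trainer - logp_vllm), truncated from above for stability
    return torch.exp(trainer_logps - vllm_logps).clamp(max=max_ratio)

trainer_logps = torch.randn(2, 8)
vllm_logps = trainer_logps + 0.1 * torch.randn(2, 8)
ratio = make_vllm_is_ratio(trainer_logps, vllm_logps)  # pass as vllm_is_ratio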