diff --git a/src/liger_kernel/chunked_loss/fused_linear_ppo.py b/src/liger_kernel/chunked_loss/fused_linear_ppo.py
index 53f706b3c..925871d8d 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_ppo.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -41,6 +41,7 @@ def forward(
         chunk_size=1,
         sapo_temperature_pos=1.0,
         sapo_temperature_neg=1.05,
+        vllm_is_ratio=None,
     ):
         # TODO: check torch compile matmul
         """Chunked forward pass for PPO loss computation.
@@ -71,6 +72,8 @@ def forward(
             chunk_size: Size of chunks for processing in other loss modules
             sapo_temperature_pos: Temperature for positive advantages in SAPO
             sapo_temperature_neg: Temperature for negative advantages in SAPO
+            vllm_is_ratio: vLLM importance sampling ratio tensor (batch_size, seq_len) or (batch_size, 1) or None.
+                Used to correct for distribution mismatch when using vLLM for generation.
         """
         if use_ref_model:
             assert ref_per_token_logps is not None or ref_input is not None, (
@@ -80,6 +83,20 @@ def forward(
             raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
         if loss_type == "dr_grpo":
             assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
+        if vllm_is_ratio is not None:
+            B, T = attention_mask.shape
+            assert vllm_is_ratio.dim() in (1, 2), (
+                f"vllm_is_ratio must be 1D (B,) or 2D (B, T) / (B, 1), got {vllm_is_ratio.dim()}D"
+            )
+            if vllm_is_ratio.dim() == 2:
+                assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, T), (
+                    f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {T}), got {tuple(vllm_is_ratio.shape)}"
+                )
+            else:
+                assert vllm_is_ratio.shape[0] == B, (
+                    f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}"
+                )
+                vllm_is_ratio = vllm_is_ratio.unsqueeze(-1)  # (B,) -> (B, 1) for broadcasting
         # Initialize accumulators
         loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
         grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -114,6 +131,7 @@ def fused_fwd_bwd(
             ref_per_token_logps_chunk,
             old_per_token_logps_chunk,
             ref_input_chunk,
+            vllm_is_ratio_chunk,
         ):
             """Fused forward and backward for a chunk."""
             argnums = (0, 1, 5) if bias is not None else (0, 1)
@@ -127,6 +145,7 @@ def fused_fwd_bwd(
                 ref_per_token_logps_chunk=ref_per_token_logps_chunk,  # arg 6
                 old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
                 ref_input_chunk=ref_input_chunk,  # arg 8
+                vllm_is_ratio_chunk=vllm_is_ratio_chunk,  # arg 9
             )

        def accumulate_chunk(
@@ -137,6 +156,7 @@ def accumulate_chunk(
            ref_per_token_logps_chunk=None,
            old_per_token_logps_chunk=None,
            ref_input_chunk=None,
+            vllm_is_ratio_chunk=None,
        ):
            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
                input_chunk,
@@ -146,6 +166,7 @@ def accumulate_chunk(
                ref_per_token_logps_chunk,
                old_per_token_logps_chunk,
                ref_input_chunk,
+                vllm_is_ratio_chunk,
            )
            if bias is not None:
                grad_bias.add_(chunk_grad_bias[0])
@@ -196,6 +217,9 @@ def accumulate_chunk(
            if use_ref_model and ref_per_token_logps is None
            else [None] * chunks
        )
+        _vllm_is_ratio_chunks = (
+            torch.chunk(vllm_is_ratio, chunks=chunks, dim=0) if vllm_is_ratio is not None else [None] * chunks
+        )

        for (
            input_chunk,
@@ -205,6 +229,7 @@ def accumulate_chunk(
            ref_per_token_logps_chunk,
            old_per_token_logps_chunk,
            ref_input_chunk,
+            vllm_is_ratio_chunk,
        ) in zip(
            _input_chunks,
            _selected_token_ids_chunks,
@@ -213,6 +238,7 @@ def accumulate_chunk(
            _ref_per_token_logps_chunks,
            _old_per_token_logps_chunks,
            _ref_input_chunks,
+            _vllm_is_ratio_chunks,
        ):
            # Mark dynamic dimensions
            torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -224,6 +250,8 @@ def accumulate_chunk(
                torch._dynamo.mark_dynamic(ref_input_chunk, 1)
            if old_per_token_logps_chunk is not None:
                torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
+            if vllm_is_ratio_chunk is not None:
+                torch._dynamo.mark_dynamic(vllm_is_ratio_chunk, 1)

            accumulate_chunk(
                input_chunk,
@@ -233,6 +261,7 @@ def accumulate_chunk(
                ref_per_token_logps_chunk,
                old_per_token_logps_chunk,
                ref_input_chunk,
+                vllm_is_ratio_chunk,
            )

        # Combine gradients
@@ -277,6 +306,7 @@ def _compute_chunk_loss(
        ref_per_token_logps_chunk=None,
        old_per_token_logps_chunk=None,
        ref_input_chunk=None,
+        vllm_is_ratio_chunk=None,
        ref_weight=None,
        ref_bias=None,
        full_attention_mask=None,
@@ -322,6 +352,7 @@ def _compute_chunk_loss(
            importance_sampling_level=importance_sampling_level,
            sapo_temperature_pos=sapo_temperature_pos,
            sapo_temperature_neg=sapo_temperature_neg,
+            vllm_is_ratio=vllm_is_ratio_chunk,
        )

        return chunk_loss, chunk_metrics
@@ -376,4 +407,5 @@ def backward(ctx, grad_output, *grad_metrics):
            None,  # grad_chunk_size
            None,  # grad_sapo_temperature_pos
            None,  # grad_sapo_temperature_neg
+            None,  # grad_vllm_is_ratio
        )
diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py
index ae3a57a30..f6cd3be81 100644
--- a/src/liger_kernel/chunked_loss/grpo_loss.py
+++ b/src/liger_kernel/chunked_loss/grpo_loss.py
@@ -75,6 +75,7 @@ def ppo_loss_fn(
    importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
    sapo_temperature_pos=1.0,  # Temperature for positive advantages in SAPO
    sapo_temperature_neg=1.05,  # Temperature for negative advantages in SAPO
+    vllm_is_ratio=None,  # vLLM importance sampling ratio (chunk_size, seq_len) or (chunk_size, 1) or None
    **kwargs,
 ):
    """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -138,6 +139,10 @@ def ppo_loss_fn(
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

+    # Apply vLLM importance sampling correction BEFORE adding KL penalty
+    if vllm_is_ratio is not None:
+        per_token_loss = per_token_loss * vllm_is_ratio
+
    if beta != 0.0:
        # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
        kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
@@ -214,6 +219,7 @@ def forward(
        compiled=True,
        use_ref_model=True,
        chunk_size=1,
+        vllm_is_ratio=None,
    ):
        """
        Fused linear layer with GRPO loss.
@@ -239,6 +245,8 @@ def forward(
            compiled (bool): Whether to use torch compile
            use_ref_model (bool): Whether to use a reference model
            chunk_size (int): Size of chunks for processing.
+            vllm_is_ratio (torch.Tensor, optional): vLLM importance sampling ratio (batch_size, seq_len) or (batch_size, 1) or None.
+                Used to correct for distribution mismatch when using vLLM for generation.
        Returns:
            torch.Tensor: Computed loss
        """
@@ -268,6 +276,7 @@ def forward(
            importance_sampling_level=importance_sampling_level,
            sapo_temperature_pos=sapo_temperature_pos,
            sapo_temperature_neg=sapo_temperature_neg,
+            vllm_is_ratio=vllm_is_ratio,
        )

    @staticmethod
@@ -300,6 +309,7 @@ def backward(ctx, grad_output, *grad_metrics):
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_chunk_size
+            None,  # grad_vllm_is_ratio
        )


@@ -370,6 +380,7 @@ def forward(
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
+        vllm_is_ratio=None,
    ):
        return LigerFusedLinearGRPOFunction.apply(
            _input,
@@ -395,4 +406,5 @@ def forward(
            self.compiled,
            self.use_ref_model,
            self.chunk_size,
+            vllm_is_ratio,
        )
diff --git a/src/liger_kernel/ops/grpo_loss.py b/src/liger_kernel/ops/grpo_loss.py
index 3f8460d8c..595f4d2cc 100644
--- a/src/liger_kernel/ops/grpo_loss.py
+++ b/src/liger_kernel/ops/grpo_loss.py
@@ -90,6 +90,8 @@ def _grpo_loss_fwd_kernel(
    INPUT_IDS,
    COMPLETION_MASK,
    ADVANTAGES,
+    VLLM_IS_RATIO,
+    VLLM_IS_RATIO_STRIDE,
    LOSS,
    LSE,
    KL,
@@ -169,6 +171,14 @@ def _grpo_loss_fwd_kernel(
        per_token_loss = -sapo_coef * advantage
        is_clipped = 0.0  # SAPO has no clipping concept

+    # Apply vLLM importance sampling correction BEFORE adding KL penalty
+    if VLLM_IS_RATIO is not None:
+        # Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes
+        vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
+            tl.float32
+        )
+        per_token_loss = per_token_loss * vllm_is_ratio
+
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        KL += off_b * L + off_l
@@ -198,6 +208,8 @@ def _grpo_loss_bwd_kernel(
    ADVANTAGES,
    COMPLETION_MASK,
    LSE,
+    VLLM_IS_RATIO,
+    VLLM_IS_RATIO_STRIDE,
    TEMPERATURE,
    BETA: tl.constexpr,
    EPS_LOW,
@@ -271,6 +283,14 @@ def _grpo_loss_bwd_kernel(
        d_sapo_d_coef1 = 4.0 * sigmoid_val * (1.0 - sigmoid_val)
        dlogp = -advantage * d_sapo_d_coef1 * coef_1

+    # Apply vLLM IS ratio to PPO gradient (before KL gradient)
+    if VLLM_IS_RATIO is not None:
+        # Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes
+        vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
+            tl.float32
+        )
+        dlogp = dlogp * vllm_is_ratio
+
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
@@ -304,6 +324,7 @@ def forward(
        loss_type="grpo",
        sapo_temperature_pos=1.0,
        sapo_temperature_neg=1.05,
+        vllm_is_ratio=None,
    ):
        assert logits.is_contiguous() and completion_ids.is_contiguous()
        assert old_logp is None or old_logp.is_contiguous()
@@ -329,6 +350,25 @@ def forward(

        if completion_mask is not None:
            assert completion_mask.is_contiguous()
+
+        # Handle vLLM IS ratio
+        vllm_is_ratio_ptr = None
+        vllm_is_ratio_stride = L  # default to per-token (unused when ptr is None)
+        if vllm_is_ratio is not None:
+            assert vllm_is_ratio.dim() in (1, 2), (
+                f"vllm_is_ratio must be 1D (B,) or 2D (B, L) / (B, 1), got {vllm_is_ratio.dim()}D"
+            )
+            if vllm_is_ratio.dim() == 2:
+                assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, L), (
+                    f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {L}), got {tuple(vllm_is_ratio.shape)}"
+                )
+            else:
+                assert vllm_is_ratio.shape[0] == B, (
+                    f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}"
+                )
+            vllm_is_ratio = vllm_is_ratio.contiguous()
+            vllm_is_ratio_ptr = vllm_is_ratio
+            vllm_is_ratio_stride = vllm_is_ratio.shape[1] if vllm_is_ratio.dim() > 1 else 1

        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
        lse = torch.zeros_like(loss)
        is_clipped = torch.zeros_like(loss)
@@ -341,6 +381,8 @@ def forward(
            completion_ids,
            completion_mask,
            advantages,
+            vllm_is_ratio_ptr,
+            vllm_is_ratio_stride,
            loss,
            lse,
            kl,
@@ -357,6 +399,8 @@ def forward(
            **kwargs,
        )
        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
+        ctx.vllm_is_ratio = vllm_is_ratio_ptr
+        ctx.vllm_is_ratio_stride = vllm_is_ratio_stride
        ctx.infos = (
            temperature,
            beta,
@@ -376,6 +420,8 @@ def backward(ctx, *args):
        temperature, beta, eps_low, eps_high, inplace, loss_type_int, sapo_temperature_pos, sapo_temperature_neg = (
            ctx.infos
        )
+        vllm_is_ratio = ctx.vllm_is_ratio
+        vllm_is_ratio_stride = ctx.vllm_is_ratio_stride
        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1
        dlogits = logits.data if inplace else torch.empty_like(logits)
@@ -390,6 +436,8 @@ def backward(ctx, *args):
            advantages,
            completion_mask,
            lse,
+            vllm_is_ratio,
+            vllm_is_ratio_stride,
            temperature,
            beta,
            eps_low,
@@ -404,5 +452,6 @@ def backward(ctx, *args):
        )
        dlogits[:, -1, :] = 0
        # Return None for: old_logp, ref_logp, completion_ids, advantages, completion_mask,
-        # temperature, beta, eps_low, eps_high, inplace, loss_type, sapo_temperature_pos, sapo_temperature_neg
-        return dlogits, None, None, None, None, None, None, None, None, None, None, None, None, None
+        # temperature, beta, eps_low, eps_high, inplace, loss_type, sapo_temperature_pos, sapo_temperature_neg,
+        # vllm_is_ratio
+        return dlogits, None, None, None, None, None, None, None, None, None, None, None, None, None, None
diff --git a/src/liger_kernel/transformers/grpo_loss.py b/src/liger_kernel/transformers/grpo_loss.py
index 7dfcebf6f..350176c9d 100644
--- a/src/liger_kernel/transformers/grpo_loss.py
+++ b/src/liger_kernel/transformers/grpo_loss.py
@@ -22,6 +22,7 @@ def triton_grpo_loss(
    reduce=False,
    sapo_temperature_pos=1.0,
    sapo_temperature_neg=1.05,
+    vllm_is_ratio=None,
 ):
    assert logits is not None and completion_ids is not None and advantages is not None, (
        "must provide logits, completion_ids and advantages"
@@ -46,6 +47,7 @@ def triton_grpo_loss(
        loss_type,
        sapo_temperature_pos,
        sapo_temperature_neg,
+        vllm_is_ratio,
    )
    if not reduce:
        return per_token_loss, per_token_kl, is_clipped
diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py
index 00240c90f..33bedbfd7 100644
--- a/test/chunked_loss/test_grpo_loss.py
+++ b/test/chunked_loss/test_grpo_loss.py
@@ -78,6 +78,7 @@ def compute_per_token_components(
    loss_type: str = "grpo",
    sapo_temperature_pos: float = 1.0,
    sapo_temperature_neg: float = 1.05,
+    vllm_is_ratio=None,
 ):
    attention_mask = attention_mask.to(per_token_logps.dtype)
    old_per_token_logps = (
@@ -136,6 +137,11 @@ def compute_per_token_components(
        per_token_loss1 = coef_1 * expanded_advantages
        per_token_loss2 = coef_2 * expanded_advantages
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+    # Apply vLLM importance sampling correction BEFORE KL penalty
+    if vllm_is_ratio is not None:
+        per_token_loss = per_token_loss * vllm_is_ratio
+
    kl_div = None
    if beta != 0.0:
        ref_per_token_logps = ref_per_token_logps.float()
@@ -160,6 +166,7 @@ def forward(
        ref_per_token_logps=None,  # Shape: [batch_size, seq_len]
        old_per_token_logps=None,
        ref_input=None,  # Shape: [batch_size, seq_len, hidden_size]
+        vllm_is_ratio=None,  # Shape: [batch_size, seq_len] or None
    ):
        logits = x @ self.lin.weight.t()
        if self.lin.bias is not None:
@@ -201,6 +208,7 @@ def forward(
            self.loss_type,
            self.sapo_temperature_pos,
            self.sapo_temperature_neg,
+            vllm_is_ratio=vllm_is_ratio,
        )

        # Apply masking and calculate loss based on loss_type
@@ -272,8 +280,8 @@ def forward(
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
+        vllm_is_ratio=None,
    ):
-        # Pass only the arguments defined in LigerFusedLinearGRPOFunction.forward()
        return self.grpo_loss(
            x,  # _input
            self.lin.weight,  # weight
@@ -286,6 +294,7 @@ def forward(
            ref_input,  # ref_input
            self.ref_lin.weight,  # ref_weight
            self.ref_lin.bias,  # ref_bias
+            vllm_is_ratio=vllm_is_ratio,
        )


@@ -485,6 +494,74 @@ def test_correctness(
    )


+@pytest.mark.parametrize("loss_type", ["bnpo", "grpo", "dapo", "cispo", "sapo"])
+@pytest.mark.parametrize("beta", [0.0, 0.1])
+def test_correctness_with_vllm_is_ratio(loss_type, beta):
+    """Test vllm_is_ratio correctness against torch reference, and 1D/2D shape equivalence."""
+    torch.compiler.reset()
+    B, T, H, V = 4, 32, 64, 128
+    dtype = torch.float32
+    atol, rtol = 1e-5, 5e-4
+
+    _weight = torch.randn(V, H, device=device, dtype=dtype)
+    _input = torch.randn(B, T, H, device=device, dtype=dtype)
+    input1 = _input.detach().clone().requires_grad_(True)
+    input2 = _input.detach().clone().requires_grad_(True)
+
+    selected_token_ids = torch.randint(0, V, (B, T), device=device)
+    attention_mask = torch.ones(B, T, device=device)
+    attention_mask[:, -5:] = 0
+    advantages = torch.randn(B, device=device, dtype=dtype)
+    advantages[0] = -advantages[0].abs()  # ensure mixed signs for SAPO
+
+    vllm_is_ratio = torch.rand(B, T, device=device, dtype=torch.float32) * 0.999 + 0.001
+
+    torch_lm = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, beta=beta, loss_type=loss_type, use_ref_model=False)
+    liger_lm = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, beta=beta, loss_type=loss_type, use_ref_model=False)
+    torch_lm.lin.weight.data = liger_lm.lin.weight.data = _weight.clone()
+
+    loss1, aux1 = torch_lm(input1, selected_token_ids, attention_mask, advantages, vllm_is_ratio=vllm_is_ratio)
+    loss2, aux2 = liger_lm(input2, selected_token_ids, attention_mask, advantages, vllm_is_ratio=vllm_is_ratio)
+
+    assert not torch.isnan(loss1)
+    assert not torch.isnan(loss2)
+    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
+    for m1, m2 in zip(aux1, aux2):
+        assert_verbose_allclose(m1, m2, atol=atol, rtol=rtol)
+
+    loss1.backward()
+    loss2.backward()
+    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
+    assert_verbose_allclose(torch_lm.lin.weight.grad, liger_lm.lin.weight.grad, atol=atol, rtol=rtol)
+
+    # Verify 1D (B,) gives same result as (B, 1)
+    uniform_val = 0.42
+    input3 = _input.detach().clone().requires_grad_(True)
+    input4 = _input.detach().clone().requires_grad_(True)
+    liger3 = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, beta=beta, loss_type=loss_type, use_ref_model=False)
+    liger4 = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, beta=beta, loss_type=loss_type, use_ref_model=False)
+    liger3.lin.weight.data = liger4.lin.weight.data = _weight.clone()
+
+    loss3, _ = liger3(
+        input3,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        vllm_is_ratio=torch.full((B,), uniform_val, device=device),
+    )
+    loss4, _ = liger4(
+        input4,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        vllm_is_ratio=torch.full((B, 1), uniform_val, device=device),
+    )
+    assert_verbose_allclose(loss3, loss4, atol=1e-5, rtol=1e-5)
+    loss3.backward()
+    loss4.backward()
+    assert_verbose_allclose(input3.grad, input4.grad, atol=1e-5, rtol=1e-5)
+
+
 @pytest.mark.parametrize(
    "B, T, H, V",
    [
diff --git a/test/transformers/test_grpo_loss.py b/test/transformers/test_grpo_loss.py
index cb249c819..514223b7d 100644
--- a/test/transformers/test_grpo_loss.py
+++ b/test/transformers/test_grpo_loss.py
@@ -270,6 +270,267 @@ def test_grpo_loss(B, T, V, temperature, num_iteration, beta, eps_low, eps_high,
    assert_verbose_allclose(logits2.grad, logits3.grad, atol=atol, rtol=rtol)


+def torch_grpo_loss_with_vllm_is(
+    logits,
+    old_logp,
+    ref_logp,
+    completion_ids,
+    advantages,
+    completion_mask,
+    temperature,
+    beta,
+    eps_low,
+    eps_high,
+    vllm_is_ratio,
+    loss_type="grpo",
+    sapo_temperature_pos=1.0,
+    sapo_temperature_neg=1.05,
+):
+    """Reference implementation with vLLM IS ratio correction for all loss types."""
+    assert logits.is_contiguous() and completion_ids.is_contiguous()
+    logits = logits[:, :-1]
+    per_token_logps = _get_log_probs(logits / temperature, completion_ids)
+    ref_per_token_logps = ref_logp
+    if old_logp is None:
+        old_logp = per_token_logps.detach()
+    coef_1 = torch.exp(per_token_logps - old_logp)
+
+    if loss_type == "cispo":
+        coef_2 = torch.clamp(coef_1, max=eps_high).detach()
+        per_token_loss = -coef_2 * advantages.unsqueeze(1) * per_token_logps
+        is_clipped = ((coef_1 > eps_high) & (advantages.unsqueeze(1) > 0)).float()
+    elif loss_type == "sapo":
+        temp = torch.where(advantages.unsqueeze(1) > 0, sapo_temperature_pos, sapo_temperature_neg)
+        sigmoid_input = temp * (coef_1 - 1.0)
+        sapo_coef = torch.sigmoid(sigmoid_input) * 4.0 / temp
+        per_token_loss = -sapo_coef * advantages.unsqueeze(1)
+        is_clipped = torch.zeros_like(per_token_loss)
+    else:
+        coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        is_clipped = (per_token_loss1 < per_token_loss2).float()
+
+    # Apply vLLM IS correction BEFORE KL penalty
+    if vllm_is_ratio is not None:
+        per_token_loss = per_token_loss * vllm_is_ratio
+    per_token_loss = per_token_loss * completion_mask if completion_mask is not None else per_token_loss
+    per_token_kl = None
+    if beta != 0.0:
+        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+        if completion_mask is not None:
+            per_token_kl *= completion_mask
+        per_token_loss = per_token_loss + beta * per_token_kl
+    return per_token_loss, per_token_kl, is_clipped
+
+
+@pytest.mark.parametrize(
+    "temperature, num_iteration, beta, eps_low, eps_high",
+    [(0.7, num_iteration, beta, 0.2, 0.4) for num_iteration in [1, 5] for beta in [0.0, 0.04]],
+)
+@pytest.mark.parametrize(
+    "B, T, V",
+    [
+        (2, 128, 1000),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.bfloat16, 5e-2, 5e-1),
+    ],
+)
+@pytest.mark.parametrize("loss_type", ["grpo", "cispo", "sapo"])
+def test_grpo_loss_with_vllm_is_ratio(
+    B, T, V, temperature, num_iteration, beta, eps_low, eps_high, dtype, atol, rtol, loss_type
+):
+    """Test that triton_grpo_loss with vllm_is_ratio matches PyTorch reference for all loss types."""
+    _input = torch.randn(B, T + 1, V, device=device, dtype=dtype)
+
+    logits1 = _input.clone().requires_grad_(True)
+    logits2 = _input.clone().requires_grad_(True)
+    logits3 = _input.clone().float().requires_grad_(True)
+
+    completion_ids = torch.randint(0, V - 1, (B, T), dtype=torch.int64, device=device)
+    completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)
+    completion_mask[:, -20:] = 0
+
+    ref_logp = torch.randn(B, T, device=device, dtype=torch.float32) if beta != 0.0 else None
+    old_logp = torch.randn(B, T, device=device, dtype=torch.float32) if num_iteration > 1 else None
+    advantages = torch.randn(B, device=device, dtype=torch.float32)
+
+    # Create vLLM IS ratio (random values between 0.001 and 1.0 to simulate typical IS correction)
+    vllm_is_ratio = torch.rand(B, T, device=device, dtype=torch.float32) * 0.999 + 0.001
+
+    loss1, kl1, _ = torch_grpo_loss_with_vllm_is(
+        logits1,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        vllm_is_ratio,
+        loss_type=loss_type,
+    )
+    loss2, kl2, _ = triton_grpo_loss(
+        logits2,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace=True,
+        vllm_is_ratio=vllm_is_ratio,
+        loss_type=loss_type,
+    )
+    loss3, kl3, _ = torch_grpo_loss_with_vllm_is(
+        logits3,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        vllm_is_ratio,
+        loss_type=loss_type,
+    )
+
+    dy = torch.randn_like(loss3)
+    loss1.backward(dy)
+    loss2.backward(dy)
+    loss3.backward(dy)
+
+    # Compare triton bf16 vs torch fp32
+    assert_verbose_allclose(loss2, loss3, atol=atol, rtol=rtol)
+    if kl2 is not None and kl3 is not None:
+        assert_verbose_allclose(kl2, kl3, atol=atol, rtol=rtol)
+    assert_verbose_allclose(logits2.grad, logits3.grad, atol=atol, rtol=rtol)
+
+    # Verify vllm_is_ratio=None gives same result as vllm_is_ratio=ones
+    logits_none = _input.clone().float().requires_grad_(True)
+    logits_ones = _input.clone().float().requires_grad_(True)
+    loss_none, _, _ = triton_grpo_loss(
+        logits_none,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace=False,
+        vllm_is_ratio=None,
+        loss_type=loss_type,
+    )
+    loss_ones, _, _ = triton_grpo_loss(
+        logits_ones,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace=False,
+        vllm_is_ratio=torch.ones(B, T, device=device, dtype=torch.float32),
+        loss_type=loss_type,
+    )
+    assert_verbose_allclose(loss_none, loss_ones, atol=1e-5, rtol=1e-5)
+
+    # Verify (B, 1) shape gives same result as (B, T) with uniform value
+    uniform_val = 0.42
+    logits_b1 = _input.clone().float().requires_grad_(True)
+    logits_bt = _input.clone().float().requires_grad_(True)
+    loss_b1, _, _ = triton_grpo_loss(
+        logits_b1,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace=False,
+        vllm_is_ratio=torch.full((B, 1), uniform_val, device=device, dtype=torch.float32),
+        loss_type=loss_type,
+    )
+    loss_bt, _, _ = triton_grpo_loss(
+        logits_bt,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace=False,
+        vllm_is_ratio=torch.full((B, T), uniform_val, device=device, dtype=torch.float32),
+        loss_type=loss_type,
+    )
+    loss_b1.backward(dy)
+    loss_bt.backward(dy)
+    assert_verbose_allclose(loss_b1, loss_bt, atol=1e-5, rtol=1e-5)
+    assert_verbose_allclose(logits_b1.grad, logits_bt.grad, atol=1e-5, rtol=1e-5)
+
+    # Verify 1D (B,) shape gives same result as (B, 1)
+    logits_1d = _input.clone().float().requires_grad_(True)
+    logits_2d = _input.clone().float().requires_grad_(True)
+    loss_1d, _, _ = triton_grpo_loss(
+        logits_1d,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace=False,
+        vllm_is_ratio=torch.full((B,), uniform_val, device=device, dtype=torch.float32),
+        loss_type=loss_type,
+    )
+    loss_2d, _, _ = triton_grpo_loss(
+        logits_2d,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace=False,
+        vllm_is_ratio=torch.full((B, 1), uniform_val, device=device, dtype=torch.float32),
+        loss_type=loss_type,
+    )
+    loss_1d.backward(dy)
+    loss_2d.backward(dy)
+    assert_verbose_allclose(loss_1d, loss_2d, atol=1e-5, rtol=1e-5)
+    assert_verbose_allclose(logits_1d.grad, logits_2d.grad, atol=1e-5, rtol=1e-5)
+
+
 @pytest.mark.parametrize(
    "temperature, num_iteration, beta, eps_high",
    [(0.7, num_iteration, beta, 5.0) for num_iteration in [1, 5] for beta in [0.0, 0.04]],