16 changes: 16 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -39,6 +39,7 @@ def forward(
compiled=True,
use_ref_model=False,
chunk_size=1,
vllm_is_ratio=None,
):
# TODO: check torch compile matmul
"""Chunked forward pass for PPO loss computation.
@@ -107,6 +108,7 @@ def fused_fwd_bwd(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
):
"""Fused forward and backward for a chunk."""
argnums = (0, 1, 5) if bias is not None else (0, 1)
@@ -120,6 +122,7 @@ def fused_fwd_bwd(
ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
ref_input_chunk=ref_input_chunk, # arg 8
vllm_is_ratio_chunk=vllm_is_ratio_chunk, # arg 9
)

def accumulate_chunk(
@@ -130,6 +133,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk=None,
old_per_token_logps_chunk=None,
ref_input_chunk=None,
vllm_is_ratio_chunk=None,
):
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
input_chunk,
@@ -139,6 +143,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
)
if bias is not None:
grad_bias.add_(chunk_grad_bias[0])
@@ -189,6 +194,9 @@ def accumulate_chunk(
if use_ref_model and ref_per_token_logps is None
else [None] * chunks
)
_vllm_is_ratio_chunks = (
torch.chunk(vllm_is_ratio, chunks=chunks, dim=0) if vllm_is_ratio is not None else [None] * chunks
)

for (
input_chunk,
@@ -198,6 +206,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
) in zip(
_input_chunks,
_selected_token_ids_chunks,
@@ -206,6 +215,7 @@ def accumulate_chunk(
_ref_per_token_logps_chunks,
_old_per_token_logps_chunks,
_ref_input_chunks,
_vllm_is_ratio_chunks,
):
# Mark dynamic dimensions
torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -217,6 +227,8 @@ def accumulate_chunk(
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
if old_per_token_logps_chunk is not None:
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
if vllm_is_ratio_chunk is not None:
torch._dynamo.mark_dynamic(vllm_is_ratio_chunk, 1)

accumulate_chunk(
input_chunk,
@@ -226,6 +238,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
)

# Combine gradients
@@ -270,6 +283,7 @@ def _compute_chunk_loss(
ref_per_token_logps_chunk=None,
old_per_token_logps_chunk=None,
ref_input_chunk=None,
vllm_is_ratio_chunk=None,
ref_weight=None,
ref_bias=None,
full_attention_mask=None,
@@ -311,6 +325,7 @@ def _compute_chunk_loss(
loss_type=loss_type,
max_completion_length=max_completion_length,
importance_sampling_level=importance_sampling_level,
vllm_is_ratio=vllm_is_ratio_chunk,
)

return chunk_loss, chunk_metrics
@@ -363,4 +378,5 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_compiled
None, # grad_use_ref_model
None, # grad_chunk_size
None, # grad_vllm_is_ratio
)
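
For background on the `None, # grad_vllm_is_ratio` entry above: a custom `torch.autograd.Function` must return one gradient per `forward` input, and inputs treated as constants (here the importance-sampling ratio) get `None`. A minimal standalone sketch of that convention, not code from this PR:

```python
import torch

class ScaledLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, is_ratio):
        # is_ratio is consumed as a constant weight, like vllm_is_ratio above.
        ctx.save_for_backward(is_ratio)
        return (x * is_ratio).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (is_ratio,) = ctx.saved_tensors
        # One gradient slot per forward input; the constant ratio gets None,
        # which is exactly what the new grad_vllm_is_ratio entry does.
        return grad_output * is_ratio, None

x = torch.randn(2, 3, requires_grad=True)
ratio = torch.rand(2, 3)
ScaledLoss.apply(x, ratio).backward()
print(torch.allclose(x.grad, ratio))  # True
```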
24 changes: 18 additions & 6 deletions src/liger_kernel/chunked_loss/grpo_loss.py
@@ -30,8 +30,9 @@ def ppo_loss_fn(
epsilon_high=0.2,
beta=0.04,
loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo"]
- max_completion_length=None, # Required for dr_grpo
+ max_completion_length=None, # Optional for dr_grpo (defaults to sequence length)
importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
vllm_is_ratio=None, # vLLM importance sampling ratio (chunk_size, seq_len) or None
**kwargs,
):
"""GRPO Loss Function matching GRPOTrainer implementation."""
@@ -71,6 +72,11 @@ def ppo_loss_fn(
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

# Apply vLLM importance sampling correction BEFORE adding KL penalty
if vllm_is_ratio is not None:
per_token_loss = per_token_loss * vllm_is_ratio

if beta != 0.0:
# Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
@@ -91,9 +97,8 @@ def ppo_loss_fn(
loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
elif loss_type == "dr_grpo":
# Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
- if max_completion_length is None:
- raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
- loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+ max_len = max_completion_length if max_completion_length is not None else attention_mask.shape[1]
+ loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_len)
elif loss_type == "dapo":
loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
loss = (per_token_loss * attention_mask).sum() / loss_normalizer
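
A quick numeric check of the new dr_grpo fallback (toy tensors, not from the repository's tests): with `max_completion_length=None`, normalization now falls back to the mask's sequence dimension instead of raising:

```python
import torch

per_token_loss = torch.ones(2, 5)
attention_mask = torch.ones(2, 5)
full_attention_mask = attention_mask

max_completion_length = None  # previously a ValueError for dr_grpo
max_len = max_completion_length if max_completion_length is not None else attention_mask.shape[1]
loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_len)
print(loss)  # 10.0 / (2 * 5) = tensor(1.)
```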
@@ -145,6 +150,7 @@ def forward(
compiled=True,
use_ref_model=True,
chunk_size=1,
vllm_is_ratio=None,
):
"""
Fused linear layer with GRPO loss.
@@ -161,12 +167,14 @@ def forward(
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
beta (float): Weight for the KL penalty
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
- max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+ max_completion_length (int, optional): Maximum completion length; if None, defaults to sequence length.
importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
temperature (float): Temperature for the logits
compiled (bool): Whether to use torch compile
use_ref_model (bool): Whether to use a reference model
chunk_size (int): Size of chunks for processing.
vllm_is_ratio (torch.Tensor, optional): vLLM importance sampling ratio (batch_size, seq_len) or None.
Used to correct for distribution mismatch when using vLLM for generation.
Returns:
torch.Tensor: Computed loss
"""
@@ -194,6 +202,7 @@ def forward(
use_ref_model=use_ref_model,
chunk_size=chunk_size,
importance_sampling_level=importance_sampling_level,
vllm_is_ratio=vllm_is_ratio,
)

@staticmethod
@@ -224,6 +233,7 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_compiled
None, # grad_use_ref_model
None, # grad_chunk_size
None, # grad_vllm_is_ratio
)


@@ -252,7 +262,7 @@ def __init__(
epsilon_low (float): Lower bound for the importance sampling ratio.
epsilon_high (float): Upper bound for the importance sampling ratio.
loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
- max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+ max_completion_length (int, optional): Maximum completion length; if None, defaults to sequence length.
importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
temperature (float): Temperature for the logits.
"""
@@ -281,6 +291,7 @@ def forward(
ref_input=None,
ref_weight=None,
ref_bias=None,
vllm_is_ratio=None,
):
return LigerFusedLinearGRPOFunction.apply(
_input,
Expand All @@ -304,4 +315,5 @@ def forward(
self.compiled,
self.use_ref_model,
self.chunk_size,
vllm_is_ratio,
)
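
One plausible way to construct the new `vllm_is_ratio` argument, shown as a hedged sketch: it assumes a truncated importance-sampling setup in the style of TRL's vLLM correction, and `vllm_per_token_logps` and the cap `C` are illustrative names, not defined anywhere in this PR:

```python
import torch

# Hypothetical per-token logprobs for the sampled tokens: from the policy
# that produced old_per_token_logps, and from vLLM's generation engine.
old_per_token_logps = torch.randn(4, 32)
vllm_per_token_logps = torch.randn(4, 32)

# Ratio between the two distributions; clamping at a cap C ("truncated
# importance sampling") bounds the variance of the correction.
C = 2.0
vllm_is_ratio = torch.exp(old_per_token_logps - vllm_per_token_logps).clamp(max=C)

# Passed as LigerFusedLinearGRPOLoss(...)(..., vllm_is_ratio=vllm_is_ratio),
# the ratio multiplies the clipped per-token PPO loss before the KL penalty
# is added, correcting for sampler/policy distribution mismatch.
```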