32 changes: 32 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -41,6 +41,7 @@ def forward(
chunk_size=1,
sapo_temperature_pos=1.0,
sapo_temperature_neg=1.05,
vllm_is_ratio=None,
):
# TODO: check torch compile matmul
"""Chunked forward pass for PPO loss computation.
@@ -71,6 +72,8 @@ def forward(
chunk_size: Size of chunks for processing in other loss modules
sapo_temperature_pos: Temperature for positive advantages in SAPO
sapo_temperature_neg: Temperature for negative advantages in SAPO
vllm_is_ratio: vLLM importance sampling ratio tensor (batch_size, seq_len) or (batch_size, 1) or None.
Used to correct for distribution mismatch when using vLLM for generation.
"""
if use_ref_model:
assert ref_per_token_logps is not None or ref_input is not None, (
@@ -80,6 +83,20 @@
raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
if loss_type == "dr_grpo":
assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
if vllm_is_ratio is not None:
B, T = attention_mask.shape
assert vllm_is_ratio.dim() in (1, 2), (
f"vllm_is_ratio must be 1D (B,) or 2D (B, T) / (B, 1), got {vllm_is_ratio.dim()}D"
)
if vllm_is_ratio.dim() == 2:
assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, T), (
f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {T}), got {tuple(vllm_is_ratio.shape)}"
)
else:
assert vllm_is_ratio.shape[0] == B, (
f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}"
)
vllm_is_ratio = vllm_is_ratio.unsqueeze(-1) # (B,) -> (B, 1) for broadcasting
# Initialize accumulators
loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
grad_weight = torch.zeros_like(weight) # [V, H]
@@ -114,6 +131,7 @@ def fused_fwd_bwd(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
):
"""Fused forward and backward for a chunk."""
argnums = (0, 1, 5) if bias is not None else (0, 1)
@@ -127,6 +145,7 @@
ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
ref_input_chunk=ref_input_chunk, # arg 8
vllm_is_ratio_chunk=vllm_is_ratio_chunk, # arg 9
)

def accumulate_chunk(
@@ -137,6 +156,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk=None,
old_per_token_logps_chunk=None,
ref_input_chunk=None,
vllm_is_ratio_chunk=None,
):
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
input_chunk,
@@ -146,6 +166,7 @@
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
)
if bias is not None:
grad_bias.add_(chunk_grad_bias[0])
@@ -196,6 +217,9 @@ def accumulate_chunk(
if use_ref_model and ref_per_token_logps is None
else [None] * chunks
)
_vllm_is_ratio_chunks = (
torch.chunk(vllm_is_ratio, chunks=chunks, dim=0) if vllm_is_ratio is not None else [None] * chunks
)

for (
input_chunk,
@@ -205,6 +229,7 @@
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
) in zip(
_input_chunks,
_selected_token_ids_chunks,
@@ -213,6 +238,7 @@
_ref_per_token_logps_chunks,
_old_per_token_logps_chunks,
_ref_input_chunks,
_vllm_is_ratio_chunks,
):
# Mark dynamic dimensions
torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -224,6 +250,8 @@
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
if old_per_token_logps_chunk is not None:
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
if vllm_is_ratio_chunk is not None:
torch._dynamo.mark_dynamic(vllm_is_ratio_chunk, 1)

accumulate_chunk(
input_chunk,
@@ -233,6 +261,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
vllm_is_ratio_chunk,
)

# Combine gradients
@@ -277,6 +306,7 @@ def _compute_chunk_loss(
ref_per_token_logps_chunk=None,
old_per_token_logps_chunk=None,
ref_input_chunk=None,
vllm_is_ratio_chunk=None,
ref_weight=None,
ref_bias=None,
full_attention_mask=None,
@@ -322,6 +352,7 @@ def _compute_chunk_loss(
importance_sampling_level=importance_sampling_level,
sapo_temperature_pos=sapo_temperature_pos,
sapo_temperature_neg=sapo_temperature_neg,
vllm_is_ratio=vllm_is_ratio_chunk,
)

return chunk_loss, chunk_metrics
@@ -376,4 +407,5 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_chunk_size
None, # grad_sapo_temperature_pos
None, # grad_sapo_temperature_neg
None, # grad_vllm_is_ratio
)
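
A minimal sketch of the broadcasting contract the new validation block establishes (made-up tensor sizes, plain PyTorch; this is illustration only, not the fused path itself):

```python
import torch

# After validation, vllm_is_ratio is (B, T) per-token, (B, 1) per-sequence,
# or (B,) unsqueezed to (B, 1), so it always broadcasts against a (B, T) loss.
B, T = 2, 5
per_token_loss = torch.randn(B, T)        # stand-in for the PPO per-token loss
ratio_seq = torch.tensor([0.9, 1.1])      # per-sequence ratio, shape (B,)
ratio_tok = torch.rand(B, T)              # per-token ratio, shape (B, T)

ratio_seq = ratio_seq.unsqueeze(-1)       # (B,) -> (B, 1), mirroring the code above
assert (per_token_loss * ratio_seq).shape == (B, T)   # broadcasts over the sequence dim
assert (per_token_loss * ratio_tok).shape == (B, T)   # element-wise per token

# torch.chunk along dim=0 keeps the trailing dimension, so each chunk still broadcasts.
for chunk in torch.chunk(ratio_seq, chunks=2, dim=0):
    assert chunk.shape == (1, 1)
```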
12 changes: 12 additions & 0 deletions src/liger_kernel/chunked_loss/grpo_loss.py
@@ -75,6 +75,7 @@ def ppo_loss_fn(
importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
sapo_temperature_pos=1.0, # Temperature for positive advantages in SAPO
sapo_temperature_neg=1.05, # Temperature for negative advantages in SAPO
vllm_is_ratio=None, # vLLM importance sampling ratio (chunk_size, seq_len) or (chunk_size, 1) or None
**kwargs,
):
"""GRPO Loss Function matching GRPOTrainer implementation."""
@@ -138,6 +139,10 @@ def ppo_loss_fn(
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

# Apply vLLM importance sampling correction BEFORE adding KL penalty
if vllm_is_ratio is not None:
per_token_loss = per_token_loss * vllm_is_ratio

if beta != 0.0:
# Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
@@ -214,6 +219,7 @@ def forward(
compiled=True,
use_ref_model=True,
chunk_size=1,
vllm_is_ratio=None,
):
"""
Fused linear layer with GRPO loss.
@@ -239,6 +245,8 @@
compiled (bool): Whether to use torch compile
use_ref_model (bool): Whether to use a reference model
chunk_size (int): Size of chunks for processing.
vllm_is_ratio (torch.Tensor, optional): vLLM importance sampling ratio (batch_size, seq_len) or (batch_size, 1) or None.
Used to correct for distribution mismatch when using vLLM for generation.
Returns:
torch.Tensor: Computed loss
"""
@@ -268,6 +276,7 @@ def forward(
importance_sampling_level=importance_sampling_level,
sapo_temperature_pos=sapo_temperature_pos,
sapo_temperature_neg=sapo_temperature_neg,
vllm_is_ratio=vllm_is_ratio,
)

@staticmethod
@@ -300,6 +309,7 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_compiled
None, # grad_use_ref_model
None, # grad_chunk_size
None, # grad_vllm_is_ratio
)


@@ -370,6 +380,7 @@ def forward(
ref_input=None,
ref_weight=None,
ref_bias=None,
vllm_is_ratio=None,
):
return LigerFusedLinearGRPOFunction.apply(
_input,
@@ -395,4 +406,5 @@
self.compiled,
self.use_ref_model,
self.chunk_size,
vllm_is_ratio,
)
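
A rough eager-mode sketch of the ordering that the `ppo_loss_fn` change relies on: the vLLM ratio scales only the clipped policy term, and the k3 KL penalty is added afterwards, unscaled. Variable names and the spelled-out k3 form are inferred from the surrounding code for illustration; this is not a separate API:

```python
import torch

def grpo_per_token_loss_sketch(per_token_logps, old_logps, ref_logps, advantages,
                               vllm_is_ratio=None, beta=0.04, eps_low=0.2, eps_high=0.2):
    # Clipped GRPO surrogate with token-level importance weights.
    coef_1 = torch.exp(per_token_logps - old_logps)
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
    loss1 = coef_1 * advantages.unsqueeze(1)
    loss2 = coef_2 * advantages.unsqueeze(1)
    per_token_loss = -torch.min(loss1, loss2)

    # vLLM importance-sampling correction scales the policy term only ...
    if vllm_is_ratio is not None:
        per_token_loss = per_token_loss * vllm_is_ratio   # (B, T) * (B, T) or (B, 1)

    # ... so the k3 KL penalty added afterwards stays unscaled.
    if beta != 0.0:
        diff = ref_logps - per_token_logps
        per_token_loss = per_token_loss + beta * (torch.exp(diff) - diff - 1.0)
    return per_token_loss

B, T = 2, 4
out = grpo_per_token_loss_sketch(
    torch.randn(B, T), torch.randn(B, T), torch.randn(B, T),
    torch.randn(B), vllm_is_ratio=torch.rand(B, 1),
)
assert out.shape == (B, T)
```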
53 changes: 51 additions & 2 deletions src/liger_kernel/ops/grpo_loss.py
@@ -90,6 +90,8 @@ def _grpo_loss_fwd_kernel(
INPUT_IDS,
COMPLETION_MASK,
ADVANTAGES,
VLLM_IS_RATIO,
VLLM_IS_RATIO_STRIDE,
LOSS,
LSE,
KL,
@@ -169,6 +171,14 @@ def _grpo_loss_fwd_kernel(
per_token_loss = -sapo_coef * advantage
is_clipped = 0.0 # SAPO has no clipping concept

# Apply vLLM importance sampling correction BEFORE adding KL penalty
if VLLM_IS_RATIO is not None:
# Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes
vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
tl.float32
)
per_token_loss = per_token_loss * vllm_is_ratio

if BETA != 0.0:
REF_LOGP += off_b * L + off_l
KL += off_b * L + off_l
@@ -198,6 +208,8 @@ def _grpo_loss_bwd_kernel(
ADVANTAGES,
COMPLETION_MASK,
LSE,
VLLM_IS_RATIO,
VLLM_IS_RATIO_STRIDE,
TEMPERATURE,
BETA: tl.constexpr,
EPS_LOW,
@@ -271,6 +283,14 @@ def _grpo_loss_bwd_kernel(
d_sapo_d_coef1 = 4.0 * sigmoid_val * (1.0 - sigmoid_val)
dlogp = -advantage * d_sapo_d_coef1 * coef_1

# Apply vLLM IS ratio to PPO gradient (before KL gradient)
if VLLM_IS_RATIO is not None:
# Use modulo to support both (B, L) per-token and (B, 1) per-sequence shapes
vllm_is_ratio = tl.load(VLLM_IS_RATIO + off_b * VLLM_IS_RATIO_STRIDE + off_l % VLLM_IS_RATIO_STRIDE).to(
tl.float32
)
dlogp = dlogp * vllm_is_ratio

if BETA != 0.0:
REF_LOGP += off_b * L + off_l
ref_logp = tl.load(REF_LOGP).to(tl.float32)
@@ -304,6 +324,7 @@ def forward(
loss_type="grpo",
sapo_temperature_pos=1.0,
sapo_temperature_neg=1.05,
vllm_is_ratio=None,
):
assert logits.is_contiguous() and completion_ids.is_contiguous()
assert old_logp is None or old_logp.is_contiguous()
@@ -329,6 +350,25 @@
if completion_mask is not None:
assert completion_mask.is_contiguous()

# Handle vLLM IS ratio
vllm_is_ratio_ptr = None
vllm_is_ratio_stride = L # default to per-token (unused when ptr is None)
if vllm_is_ratio is not None:
assert vllm_is_ratio.dim() in (1, 2), (
f"vllm_is_ratio must be 1D (B,) or 2D (B, L) / (B, 1), got {vllm_is_ratio.dim()}D"
)
if vllm_is_ratio.dim() == 2:
assert vllm_is_ratio.shape[0] == B and vllm_is_ratio.shape[1] in (1, L), (
f"vllm_is_ratio shape must be ({B}, 1) or ({B}, {L}), got {tuple(vllm_is_ratio.shape)}"
)
else:
assert vllm_is_ratio.shape[0] == B, (
f"vllm_is_ratio shape must be ({B},), got {tuple(vllm_is_ratio.shape)}"
)
vllm_is_ratio = vllm_is_ratio.contiguous()
vllm_is_ratio_ptr = vllm_is_ratio
vllm_is_ratio_stride = vllm_is_ratio.shape[1] if vllm_is_ratio.dim() > 1 else 1

loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
lse = torch.zeros_like(loss)
is_clipped = torch.zeros_like(loss)
@@ -341,6 +381,8 @@
completion_ids,
completion_mask,
advantages,
vllm_is_ratio_ptr,
vllm_is_ratio_stride,
loss,
lse,
kl,
@@ -357,6 +399,8 @@
**kwargs,
)
ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
ctx.vllm_is_ratio = vllm_is_ratio_ptr
ctx.vllm_is_ratio_stride = vllm_is_ratio_stride
ctx.infos = (
temperature,
beta,
@@ -376,6 +420,8 @@
temperature, beta, eps_low, eps_high, inplace, loss_type_int, sapo_temperature_pos, sapo_temperature_neg = (
ctx.infos
)
vllm_is_ratio = ctx.vllm_is_ratio
vllm_is_ratio_stride = ctx.vllm_is_ratio_stride
B, L_ADD_1, N = logits.shape
L = L_ADD_1 - 1
dlogits = logits.data if inplace else torch.empty_like(logits)
@@ -390,6 +436,8 @@
advantages,
completion_mask,
lse,
vllm_is_ratio,
vllm_is_ratio_stride,
temperature,
beta,
eps_low,
@@ -404,5 +452,6 @@
)
dlogits[:, -1, :] = 0
# Return None for: old_logp, ref_logp, completion_ids, advantages, completion_mask,
# temperature, beta, eps_low, eps_high, inplace, loss_type, sapo_temperature_pos, sapo_temperature_neg
return dlogits, None, None, None, None, None, None, None, None, None, None, None, None, None
# temperature, beta, eps_low, eps_high, inplace, loss_type, sapo_temperature_pos, sapo_temperature_neg,
# vllm_is_ratio
return dlogits, None, None, None, None, None, None, None, None, None, None, None, None, None, None
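
The `off_l % VLLM_IS_RATIO_STRIDE` trick is what lets a single kernel path serve both layouts: with a stride of L (per-token ratios) the offset is the usual row-major index, and with a stride of 1 (per-sequence ratios) every token in a row resolves to the same element. A plain-Python sketch of the same addressing:

```python
def vllm_ratio_offset(off_b: int, off_l: int, stride: int) -> int:
    # Same expression the kernels use: base + off_b * STRIDE + off_l % STRIDE.
    return off_b * stride + off_l % stride

B, L = 2, 4
# Per-token layout, shape (B, L): stride == L, so row 1 covers flat indices 4..7.
assert [vllm_ratio_offset(1, l, L) for l in range(L)] == [4, 5, 6, 7]
# Per-sequence layout, shape (B, 1) or (B,): stride == 1, so row 1 always reads index 1.
assert [vllm_ratio_offset(1, l, 1) for l in range(L)] == [1, 1, 1, 1]
```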
2 changes: 2 additions & 0 deletions src/liger_kernel/transformers/grpo_loss.py
@@ -22,6 +22,7 @@ def triton_grpo_loss(
reduce=False,
sapo_temperature_pos=1.0,
sapo_temperature_neg=1.05,
vllm_is_ratio=None,
):
assert logits is not None and completion_ids is not None and advantages is not None, (
"must provide logits, completion_ids and advantages"
@@ -46,6 +47,7 @@
loss_type,
sapo_temperature_pos,
sapo_temperature_neg,
vllm_is_ratio,
)
if not reduce:
return per_token_loss, per_token_kl, is_clipped
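
For orientation, a hypothetical call through the `triton_grpo_loss` wrapper. Only the `logits`/`completion_ids`/`advantages` requirement and the `reduce`, `sapo_temperature_*`, and `vllm_is_ratio` keywords are visible in this diff; the remaining keyword names (`ref_logp`, etc.) and default behavior are assumptions based on the underlying `GrpoLossFunction`, so treat this as a sketch rather than documented usage:

```python
import torch
from liger_kernel.transformers.grpo_loss import triton_grpo_loss

B, L, V = 2, 16, 1024  # illustrative sizes; logits carry one extra position: (B, L + 1, V)
device = "cuda"
logits = torch.randn(B, L + 1, V, device=device, dtype=torch.bfloat16)
completion_ids = torch.randint(0, V, (B, L), device=device)
advantages = torch.randn(B, device=device)
ref_logp = torch.randn(B, L, device=device, dtype=torch.float32)  # assumed keyword, used when beta != 0
vllm_is_ratio = torch.rand(B, 1, device=device)  # per-sequence correction; (B, L) also accepted

# With reduce=False the wrapper returns per-token tensors, per the early-return branch above.
per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
    logits=logits,
    completion_ids=completion_ids,
    advantages=advantages,
    ref_logp=ref_logp,            # assumed keyword name
    vllm_is_ratio=vllm_is_ratio,
    reduce=False,
)
```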