329 changes: 222 additions & 107 deletions src/liger_kernel/chunked_loss/fused_linear_ppo.py

Large diffs are not rendered by default.

31 changes: 5 additions & 26 deletions src/liger_kernel/chunked_loss/grpo_loss.py
@@ -59,14 +59,12 @@ def clip_coef_fn(coef, epsilon_low, epsilon_high, loss_type):
 class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
     @staticmethod
     def ppo_loss_fn(
-        log_probs,
-        selected_token_ids,
+        per_token_logps,
         attention_mask,
         advantages,
         full_attention_mask,
         ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
         old_per_token_logps=None,
-        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
@@ -88,19 +86,9 @@ def ppo_loss_fn(
                 f"Use importance_sampling_level='token' instead."
             )
 
-        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
-            -1
-        )  # (batch_size, seq_len)
-
         # Get reference model probabilities
         if ref_per_token_logps is None:
-            if ref_log_probs is not None:
-                with torch.no_grad():
-                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
-                        -1
-                    )
-            else:
-                ref_per_token_logps = per_token_logps.detach()
+            ref_per_token_logps = per_token_logps.detach()
 
         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
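The removed gather above is the heart of this refactor: ppo_loss_fn no longer receives the full log_probs tensor and selected_token_ids, but the already-gathered per_token_logps. A minimal sketch of that gather with toy shapes (tensor names follow the removed code; this is not the library's actual call site):

```python
import torch

# Assumed toy shapes; in practice these come from the model's output chunk.
chunk_size, seq_len, vocab_size = 2, 6, 32

log_probs = torch.randn(chunk_size, seq_len, vocab_size).log_softmax(dim=-1)
selected_token_ids = torch.randint(vocab_size, (chunk_size, seq_len))

# Pick out each selected token's log-probability -> (chunk_size, seq_len),
# which is what ppo_loss_fn now expects as per_token_logps.
per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
```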
@@ -186,18 +174,9 @@ def ppo_loss_fn(
             loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
             loss = (per_token_loss * attention_mask).sum() / loss_normalizer
         elif loss_type == "luspo":
-            # LUSPO: loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
-            # Reformulated as: sum_i(sum_j(per_token_loss_ij) * seq_len_i) / numel
-            # to avoid (B,T) * (B,1) broadcast which amplifies torch.compile differences.
-            seq_lens = attention_mask.sum(-1)  # (chunk_B,)
-            per_seq_sum = per_token_loss.sum(-1)  # (chunk_B,)
-            weighted = per_seq_sum * seq_lens  # (chunk_B,)
-            if importance_sampling_level == "sequence" and beta == 0.0:
-                # per_token_loss stays (B, 1), so .mean() divides by B
-                loss = weighted.sum() / full_attention_mask.shape[0]
-            else:
-                # per_token_loss is (B, T), .mean() divides by B*T
-                loss = weighted.sum() / (full_attention_mask.shape[0] * full_attention_mask.shape[1])
+            # Match TRL exactly: loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
+            weighted = per_token_loss * attention_mask.sum(1, keepdim=True)
+            loss = weighted.sum() / (full_attention_mask.shape[0] * weighted.shape[1])
         else:
             raise ValueError(f"Unknown loss type: {loss_type}")
 
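The new LUSPO branch replaces the broadcast-avoiding reformulation with TRL's literal formula, loss = (per_token_loss * mask.sum(1, keepdim=True)).mean(). Because the chunked-loss base sums partial losses over chunks, each chunk divides by the full batch's numel rather than taking a local mean. A small self-check under assumed shapes and a hypothetical chunk size (not the library code), showing the chunked sum reproduces the global mean:

```python
import torch

B, T = 4, 8  # assumed full batch size and sequence length
per_token_loss = torch.randn(B, T)
mask = (torch.rand(B, T) > 0.2).float()

# TRL reference: weight each token's loss by its sequence length, then mean.
trl_loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()

# Chunked formulation: each chunk contributes weighted.sum() / (B * T),
# so summing over chunks equals the global mean above.
chunked_loss = sum(
    (pl * m.sum(1, keepdim=True)).sum() / (B * T)
    for pl, m in zip(per_token_loss.split(2, dim=0), mask.split(2, dim=0))
)

assert torch.allclose(trl_loss, chunked_loss)
```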
1 change: 1 addition & 0 deletions src/liger_kernel/ops/__init__.py
@@ -52,6 +52,7 @@
 from liger_kernel.ops.group_norm import group_norm_backward  # noqa: F401
 from liger_kernel.ops.group_norm import group_norm_forward  # noqa: F401
 from liger_kernel.ops.grpo_loss import GrpoLossFunction  # noqa: F401
+from liger_kernel.ops.grpo_loss import fused_linear_grpo_loss  # noqa: F401
 from liger_kernel.ops.jsd import LigerJSDFunction  # noqa: F401
 from liger_kernel.ops.jsd import jsd_backward  # noqa: F401
 from liger_kernel.ops.jsd import jsd_forward  # noqa: F401
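With this re-export, the new fused entry point is importable from the package root alongside the existing autograd Function. A sketch (only the import is shown in this diff; the call signature of fused_linear_grpo_loss is not):

```python
from liger_kernel.ops import GrpoLossFunction, fused_linear_grpo_loss
```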