
Commit 397ab30

[GRPO] add support for dapo loss (#939)
## Summary

Add support for the DAPO loss and make it the default loss type.

## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Vaibhav Jindal <[email protected]>
1 parent 8a93398 commit 397ab30
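
For context, a brief restatement of what the new default does (the notation below is assumed here, not taken from the PR text): DAPO's token-level objective sums the clipped per-token policy losses over every active completion token and divides by the total number of active tokens in the batch, rather than averaging per sequence first as GRPO does:

$$
\mathcal{L}_{\text{DAPO}} = -\frac{1}{\sum_i |o_i|} \sum_i \sum_{t=1}^{|o_i|} \min\Big(r_{i,t}\,\hat{A}_{i,t},\; \operatorname{clip}\big(r_{i,t},\, 1-\varepsilon_{\text{low}},\, 1+\varepsilon_{\text{high}}\big)\,\hat{A}_{i,t}\Big)
$$

In the distributed case the diff below all-reduces the active-token count and divides it by the world size, so each rank normalizes by its per-process share of the global count; with gradient averaging across ranks this recovers the global token mean.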

5 files changed: +312 -59 lines changed

src/liger_kernel/chunked_loss/fused_linear_ppo.py

21 additions & 5 deletions

@@ -32,7 +32,7 @@ def forward(
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="bnpo",
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -60,7 +60,7 @@ def forward(
             epsilon_low: Lower bound for clipping the importance sampling ratio
             epsilon_high: Upper bound for clipping the importance sampling ratio
             beta: Weight for the KL penalty
-            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
             max_completion_length: Maximum completion length required for "dr_grpo"
             temperature: Temperature for the logits
             compiled: Whether to use torch compile
@@ -244,6 +244,21 @@ def accumulate_chunk(

             return loss_acc, tuple(final_metrics)

+    @staticmethod
+    def _compute_dapo_normalizer(attention_mask):
+        """Global active tokens averaged per process."""
+        normalizer = attention_mask.to(torch.float32).sum()
+        world_size = 1
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            import torch.distributed as dist
+
+            normalizer = normalizer.clone()
+            dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+            world_size = dist.get_world_size()
+
+        normalizer = normalizer / world_size
+        return torch.clamp(normalizer, min=1.0)
+
     @staticmethod
     def _compute_chunk_loss(
         input_chunk,
@@ -261,7 +276,7 @@ def _compute_chunk_loss(
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="bnpo",
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -341,10 +356,11 @@ def backward(ctx, grad_output, *grad_metrics):
             None,  # grad_epsilon_low
             None,  # grad_epsilon_high
             None,  # grad_beta
+            None,  # grad_loss_type
+            None,  # grad_max_completion_length
+            None,  # grad_importance_sampling_level
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_chunk_size
-            None,  # grad_loss_type
-            None,  # grad_max_completion_length
         )
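
A minimal sketch of how the new normalizer behaves outside a distributed run (calling the static helper directly, purely for illustration): with no process group initialized it reduces to the local count of active tokens, clamped to at least 1.

```python
import torch

from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase

# Two sequences with 3 and 2 active completion tokens -> 5 active tokens total.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# No torch.distributed process group is initialized here, so the all-reduce
# branch is skipped and the normalizer is simply the local token count.
print(LigerFusedLinearPPOBase._compute_dapo_normalizer(attention_mask))  # tensor(5.)
```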

src/liger_kernel/chunked_loss/grpo_loss.py

8 additions & 5 deletions

@@ -29,7 +29,7 @@ def ppo_loss_fn(
     epsilon_low=0.2,
     epsilon_high=0.2,
     beta=0.04,
-    loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
+    loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
     max_completion_length=None,  # Required for dr_grpo
     importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
     **kwargs,
@@ -94,6 +94,9 @@ def ppo_loss_fn(
         if max_completion_length is None:
             raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
         loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+    elif loss_type == "dapo":
+        loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+        loss = (per_token_loss * attention_mask).sum() / loss_normalizer
     else:
         raise ValueError(f"Unknown loss type: {loss_type}")

@@ -135,7 +138,7 @@ def forward(
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        loss_type="bnpo",
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -157,7 +160,7 @@ def forward(
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
@@ -235,7 +238,7 @@ def __init__(
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
-        loss_type: str = "bnpo",
+        loss_type: str = "dapo",
         max_completion_length: Optional[int] = None,
         importance_sampling_level: str = "token",
         temperature: float = 1.0,
@@ -248,7 +251,7 @@ def __init__(
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
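
To see why the new branch differs from the per-sequence GRPO average, here is a small numeric sketch using plain tensor ops (separate from the fused implementation; note that in the fused path the DAPO normalizer is computed from `full_attention_mask`, i.e. the whole batch and, when distributed, all processes, not just the current chunk):

```python
import torch

per_token_loss = torch.tensor([[0.5, 0.2, 0.1, 0.0], [0.3, 0.4, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]])

# "grpo": average within each sequence, then average the sequence means.
grpo = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()

# "dapo" (single process): one global mean over all active tokens.
dapo = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)

print(grpo)  # tensor(0.3083) -- short sequences get more weight per token
print(dapo)  # tensor(0.3000) -- every active token weighted equally
```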

src/liger_kernel/ops/grpo_loss.py

3 additions & 1 deletion

@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
     per_token_loss1 = coef_1 * advantage
     per_token_loss2 = coef_2 * advantage
     per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
-    is_clipped = per_token_loss1 < per_token_loss2
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped

     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
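
A plain-PyTorch restatement of the corrected indicator (a sketch of the same logic, not the Triton kernel): a token now counts as clipped only when the ratio has crossed the relevant bound in the direction where the clip actually binds, rather than whenever `per_token_loss1 < per_token_loss2`.

```python
import torch

def grpo_clip_indicator(coef_1, advantage, eps_low=0.2, eps_high=0.2):
    # Clipping binds below 1 - eps_low only for negative advantages and
    # above 1 + eps_high only for positive advantages, mirroring the kernel.
    is_low_clipped = (coef_1 < 1 - eps_low) & (advantage < 0)
    is_high_clipped = (coef_1 > 1 + eps_high) & (advantage > 0)
    return is_low_clipped | is_high_clipped

ratio = torch.tensor([0.5, 0.9, 1.5])
advantage = torch.tensor([-1.0, 1.0, 1.0])
print(grpo_clip_indicator(ratio, advantage))  # tensor([ True, False,  True])
```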

src/liger_kernel/transformers/grpo_loss.py

56 additions & 1 deletion

@@ -1,3 +1,6 @@
+import torch
+
+from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
 from liger_kernel.ops.grpo_loss import GrpoLossFunction


@@ -13,12 +16,20 @@ def triton_grpo_loss(
     eps_low=0.2,
     eps_high=0.4,
     inplace=True,
+    loss_type="dapo",
+    max_completion_length=None,
+    importance_sampling_level="token",
+    reduce=False,
 ):
     assert logits is not None and completion_ids is not None and advantages is not None, (
         "must provide logits、completion_ids and advantages"
     )
+    if importance_sampling_level != "token":
+        raise ValueError(
+            f"Triton GRPO loss only supports token-level importance sampling. Got {importance_sampling_level}."
+        )

-    return GrpoLossFunction.apply(
+    per_token_loss, per_token_kl, is_clipped = GrpoLossFunction.apply(
         logits,
         old_logp,
         ref_logp,
@@ -31,6 +42,50 @@ def triton_grpo_loss(
         eps_high,
         inplace,
     )
+    if not reduce:
+        return per_token_loss, per_token_kl, is_clipped
+
+    loss = _reduce_grpo_loss(
+        per_token_loss,
+        completion_mask,
+        loss_type=loss_type,
+        max_completion_length=max_completion_length,
+    )
+
+    metrics = []
+    if beta != 0.0 and per_token_kl is not None:
+        metrics.append(_masked_mean(per_token_kl, completion_mask))
+    metrics.append(_masked_mean(is_clipped.float(), completion_mask))
+    return loss, metrics
+
+
+def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
+    mask = completion_mask
+    if mask is None:
+        mask = torch.ones_like(per_token_loss, dtype=per_token_loss.dtype, device=per_token_loss.device)
+    mask = mask.to(per_token_loss.dtype)
+
+    if loss_type == "grpo":
+        per_seq = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+        return per_seq.mean()
+    if loss_type == "bnpo":
+        return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
+    if loss_type == "dr_grpo":
+        if max_completion_length is None:
+            raise ValueError("max_completion_length must be provided when using loss_type='dr_grpo'")
+        batch = per_token_loss.shape[0]
+        return (per_token_loss * mask).sum() / (batch * max_completion_length)
+    if loss_type == "dapo":
+        normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask)
+        return (per_token_loss * mask).sum() / normalizer
+    raise ValueError(f"Unsupported loss_type '{loss_type}' for Triton GRPO loss.")
+
+
+def _masked_mean(values, mask):
+    if mask is None:
+        mask = torch.ones_like(values, dtype=values.dtype, device=values.device)
+    mask = mask.to(values.dtype)
+    return (values * mask).sum() / mask.sum().clamp(min=1.0)


 # This is a demo how to use grpo_loss in GRPOTrainer. The Trl version must be 0.16
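
Since the new reduction helper is an ordinary module-level function, its behavior can be checked on toy tensors without running the Triton kernel; a minimal sketch (the helper name is exactly the one added above, the tensor values are purely illustrative):

```python
import torch

from liger_kernel.transformers.grpo_loss import _reduce_grpo_loss

per_token_loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]])
completion_mask = torch.tensor([[1, 1, 1], [1, 0, 0]])

# DAPO: sum of masked losses over the (single-process) global token count.
print(_reduce_grpo_loss(per_token_loss, completion_mask, loss_type="dapo", max_completion_length=None))  # tensor(2.5000)

# Dr. GRPO: normalize by batch_size * max_completion_length instead.
print(_reduce_grpo_loss(per_token_loss, completion_mask, loss_type="dr_grpo", max_completion_length=3))  # tensor(1.6667)
```

With `reduce=False` (the default) `triton_grpo_loss` keeps its previous contract and returns `(per_token_loss, per_token_kl, is_clipped)`; with `reduce=True` it returns the scalar loss plus the masked KL and clip-ratio metrics, as shown in the hunk above.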
