
Commit cc14537

yukiu00 and Tcc0403 authored
Add CISPO and SAPO loss type support for Triton GRPO loss kernel (#1074)
## Summary

Add **CISPO** and **SAPO** loss type support to the Triton `ops/grpo_loss.py` kernel. This is a follow-up to:

- #1054 - Add CISPO loss type support for LigerFusedLinearGRPOLoss (chunked loss path)
- #1073 - Add SAPO loss type support for LigerFusedLinearGRPOLoss (chunked loss path)

> **Note**: This PR depends on #1073 (the SAPO PR) being merged first, as it builds on top of that branch.

### Background

PRs #1054 and #1073 added CISPO and SAPO support to the `chunked_loss` path; the `ops` (Triton kernel) path was marked as a follow-up. This PR implements that follow-up.

### Changes

**`src/liger_kernel/ops/grpo_loss.py`**

- Add loss type constants (`_LOSS_TYPE_GRPO`, `_LOSS_TYPE_CISPO`, `_LOSS_TYPE_SAPO`) with `tl.constexpr` for compile-time branching
- Implement CISPO in the forward/backward kernels:
  - Upper-bound-only clipping (no lower bound)
  - Detached coefficient (gradient flows only through `logp`)
  - Loss formula: `-coef_2 * advantage * logp`
- Implement SAPO in the forward/backward kernels:
  - Sigmoid-based soft gating instead of hard clipping
  - Separate temperatures for positive and negative advantages
  - Loss formula: `-sigmoid(τ*(ρ-1)) * 4/τ * advantage`
- Update `GrpoLossFunction` with `loss_type`, `sapo_temperature_pos`, and `sapo_temperature_neg` parameters

**`src/liger_kernel/transformers/grpo_loss.py`**

- Remove the errors that previously blocked the CISPO and SAPO loss types
- Add `sapo_temperature_pos` and `sapo_temperature_neg` parameters
- Update `_reduce_grpo_loss` to handle CISPO (DAPO normalization) and SAPO (GRPO normalization)

**`test/transformers/test_grpo_loss.py`**

- Add reference PyTorch implementations (`torch_cispo_loss`, `torch_sapo_loss`)
- Add `test_cispo_loss` and `test_sapo_loss` test functions

## Testing Done

- Hardware Type: NVIDIA GPU
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Co-authored-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
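For reference, a minimal PyTorch sketch of the two new per-token losses as summarized above (illustrative only; the function names are hypothetical, and this mirrors the formulas rather than the fused Triton kernel):

```python
import torch

def cispo_per_token_loss(logp, old_logp, advantage, eps_high):
    # CISPO: upper-bound-only clipping; the clipped ratio is detached,
    # so the gradient flows only through logp.
    coef_1 = torch.exp(logp - old_logp)
    coef_2 = torch.clamp(coef_1, max=eps_high).detach()
    return -coef_2 * advantage * logp

def sapo_per_token_loss(logp, old_logp, advantage, temp_pos=1.0, temp_neg=1.05):
    # SAPO: sigmoid soft gate instead of hard clipping, with separate
    # temperatures for positive and negative advantages.
    coef_1 = torch.exp(logp - old_logp)
    tau = torch.where(
        advantage > 0,
        torch.full_like(advantage, temp_pos),
        torch.full_like(advantage, temp_neg),
    )
    gate = torch.sigmoid(tau * (coef_1 - 1.0)) * 4.0 / tau
    return -gate * advantage
```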
1 parent ad6f0a7 commit cc14537

File tree: 4 files changed, +401 −57 lines


src/liger_kernel/chunked_loss/grpo_loss.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -339,6 +339,11 @@ def __init__(
             temperature (float): Temperature for the logits.
         """
         super().__init__()
+        # Validate SAPO temperatures to prevent division by zero or numerical instability
+        if sapo_temperature_pos <= 0:
+            raise ValueError(f"sapo_temperature_pos must be positive, got {sapo_temperature_pos}")
+        if sapo_temperature_neg <= 0:
+            raise ValueError(f"sapo_temperature_neg must be positive, got {sapo_temperature_neg}")
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
```
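These checks fail fast at construction time. A hypothetical example (the import path is inferred from the file location, and the constructor keyword arguments are assumed to match those added in #1054/#1073):

```python
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

# Raises ValueError: sapo_temperature_neg must be positive, got 0.0
loss_fn = LigerFusedLinearGRPOLoss(loss_type="sapo", sapo_temperature_neg=0.0)
```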

src/liger_kernel/ops/grpo_loss.py

Lines changed: 113 additions & 17 deletions
```diff
@@ -2,6 +2,21 @@
 import triton
 import triton.language as tl
 
+# Loss type constants for Triton constexpr branching
+# GRPO/DAPO/BNPO/DR_GRPO all use the same per-token loss computation (standard PPO clipping)
+_LOSS_TYPE_GRPO: tl.constexpr = tl.constexpr(0)
+_LOSS_TYPE_CISPO: tl.constexpr = tl.constexpr(1)
+_LOSS_TYPE_SAPO: tl.constexpr = tl.constexpr(2)
+
+_str_to_loss_type = {
+    "grpo": _LOSS_TYPE_GRPO.value,
+    "dapo": _LOSS_TYPE_GRPO.value,
+    "bnpo": _LOSS_TYPE_GRPO.value,
+    "dr_grpo": _LOSS_TYPE_GRPO.value,
+    "cispo": _LOSS_TYPE_CISPO.value,
+    "sapo": _LOSS_TYPE_SAPO.value,
+}
+
 
 @triton.jit
 def _selective_log_softmax_kernel(
@@ -83,6 +98,9 @@ def _grpo_loss_fwd_kernel(
     BETA: tl.constexpr,
     EPS_LOW,
     EPS_HIGH,
+    LOSS_TYPE: tl.constexpr,
+    SAPO_TEMP_POS,
+    SAPO_TEMP_NEG,
     L: tl.constexpr,
     N: tl.constexpr,
     BLOCK_N: tl.constexpr = 4096,
@@ -123,14 +141,33 @@ def _grpo_loss_fwd_kernel(
         OLD_LOGP += off_b * L + off_l
         old_logp = tl.load(OLD_LOGP).to(tl.float32)
     coef_1 = tl.exp(logp - old_logp)
-    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
     advantage = tl.load(ADVANTAGES).to(tl.float32)
-    per_token_loss1 = coef_1 * advantage
-    per_token_loss2 = coef_2 * advantage
-    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
-    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
-    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
-    is_clipped = is_low_clipped | is_high_clipped
+
+    # Branch based on loss type
+    if LOSS_TYPE == 0:  # GRPO/DAPO/BNPO/DR_GRPO: standard PPO clipping
+        coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+        per_token_loss1 = coef_1 * advantage
+        per_token_loss2 = coef_2 * advantage
+        per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
+        is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+        is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+        is_clipped = is_low_clipped | is_high_clipped
+
+    elif LOSS_TYPE == 1:  # CISPO: upper-bound-only clipping, detached coefficient, multiply by logp
+        # Reference: MiniMax-M1 technical report
+        # https://github.com/huggingface/trl/blob/035c3ff151b953ca72cdfe0ee966bc1469a26fde/trl/trainer/grpo_trainer.py#L2030
+        coef_2 = tl.minimum(coef_1, EPS_HIGH)  # upper bound only (EPS_HIGH is the raw bound for CISPO)
+        per_token_loss = -coef_2 * advantage * logp  # includes logp term
+        is_clipped = (coef_1 > EPS_HIGH) & (advantage > 0)
+
+    elif LOSS_TYPE == 2:  # SAPO: soft adaptive policy optimization with sigmoid gating
+        # Reference: https://huggingface.co/papers/2511.20347
+        # Formula: sigmoid(τ * (ρ - 1)) * 4 / τ
+        temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG)
+        sigmoid_input = temperature * (coef_1 - 1.0)
+        sapo_coef = tl.sigmoid(sigmoid_input) * 4.0 / temperature
+        per_token_loss = -sapo_coef * advantage
+        is_clipped = 0.0  # SAPO has no clipping concept
 
     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
@@ -165,6 +202,9 @@ def _grpo_loss_bwd_kernel(
     BETA: tl.constexpr,
     EPS_LOW,
     EPS_HIGH,
+    LOSS_TYPE: tl.constexpr,
+    SAPO_TEMP_POS,
+    SAPO_TEMP_NEG,
     loss_stride0,
     loss_stride1,
     L: tl.constexpr,
@@ -202,13 +242,35 @@ def _grpo_loss_bwd_kernel(
         OLD_LOGP += off_b * L + off_l
         old_logp = tl.load(OLD_LOGP).to(tl.float32)
     coef_1 = tl.exp(logp - old_logp)
-    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
     advantage = tl.load(ADVANTAGES).to(tl.float32)
-    per_token_loss1 = coef_1 * advantage
-    per_token_loss2 = coef_2 * advantage
-    mask = per_token_loss2 >= per_token_loss1
 
-    dlogp = -per_token_loss1 * mask
+    # Branch based on loss type for gradient computation
+    if LOSS_TYPE == 0:  # GRPO/DAPO/BNPO/DR_GRPO: standard PPO clipping
+        coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+        per_token_loss1 = coef_1 * advantage
+        per_token_loss2 = coef_2 * advantage
+        mask = per_token_loss2 >= per_token_loss1
+        dlogp = -per_token_loss1 * mask
+
+    elif LOSS_TYPE == 1:  # CISPO: coef_2 is DETACHED, so gradient only flows through logp
+        # loss = -coef_2 * advantage * logp, where coef_2 = clamp(coef_1, max=eps_high).detach()
+        # d(loss)/d(logp) = -coef_2 * advantage (coef_2 treated as constant due to detach)
+        coef_2 = tl.minimum(coef_1, EPS_HIGH)
+        dlogp = -coef_2 * advantage
+
+    elif LOSS_TYPE == 2:  # SAPO: gradient through sigmoid gating
+        # loss = -sapo_coef * advantage, where sapo_coef = sigmoid(τ*(ρ-1)) * 4/τ
+        # d(loss)/d(logp) = -advantage * d(sapo_coef)/d(coef_1) * d(coef_1)/d(logp)
+        # d(coef_1)/d(logp) = coef_1 (since coef_1 = exp(logp - old_logp))
+        # d(sapo_coef)/d(coef_1) = d/d(coef_1)[sigmoid(τ*(coef_1-1)) * 4/τ]
+        #                        = τ * sigmoid' * 4/τ = 4 * sigmoid * (1 - sigmoid)
+        # (the τ factors cancel in the derivative)
+        temperature = tl.where(advantage > 0, SAPO_TEMP_POS, SAPO_TEMP_NEG)
+        sigmoid_input = temperature * (coef_1 - 1.0)
+        sigmoid_val = tl.sigmoid(sigmoid_input)
+        d_sapo_d_coef1 = 4.0 * sigmoid_val * (1.0 - sigmoid_val)
+        dlogp = -advantage * d_sapo_d_coef1 * coef_1
+
     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
         ref_logp = tl.load(REF_LOGP).to(tl.float32)
@@ -239,11 +301,28 @@ def forward(
         eps_low,
         eps_high,
         inplace,
+        loss_type="grpo",
+        sapo_temperature_pos=1.0,
+        sapo_temperature_neg=1.05,
     ):
         assert logits.is_contiguous() and completion_ids.is_contiguous()
         assert old_logp is None or old_logp.is_contiguous()
         assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
 
+        # Validate loss_type
+        if loss_type not in _str_to_loss_type:
+            raise ValueError(f"Unknown loss_type '{loss_type}'. Supported types: {list(_str_to_loss_type.keys())}")
+
+        # Validate SAPO temperatures to prevent division by zero or numerical instability
+        if loss_type == "sapo":
+            if sapo_temperature_pos <= 0:
+                raise ValueError(f"sapo_temperature_pos must be positive, got {sapo_temperature_pos}")
+            if sapo_temperature_neg <= 0:
+                raise ValueError(f"sapo_temperature_neg must be positive, got {sapo_temperature_neg}")
+
+        # Convert loss_type string to integer for Triton constexpr
+        loss_type_int = _str_to_loss_type[loss_type]
+
         B, L_ADD_1, N = logits.shape
         L = L_ADD_1 - 1
 
@@ -270,21 +349,33 @@ def forward(
             beta,
             eps_low,
             eps_high,
+            loss_type_int,
+            sapo_temperature_pos,
+            sapo_temperature_neg,
             L,
             N,
             **kwargs,
         )
         ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
-        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
-        # return loss
+        ctx.infos = (
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            inplace,
+            loss_type_int,
+            sapo_temperature_pos,
+            sapo_temperature_neg,
+        )
         return loss, kl, is_clipped
 
     @staticmethod
     def backward(ctx, *args):
         dloss = args[0]
-        # print(dloss.shape)
         logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
-        temperature, beta, eps_low, eps_high, inplace = ctx.infos
+        temperature, beta, eps_low, eps_high, inplace, loss_type_int, sapo_temperature_pos, sapo_temperature_neg = (
+            ctx.infos
+        )
         B, L_ADD_1, N = logits.shape
         L = L_ADD_1 - 1
         dlogits = logits.data if inplace else torch.empty_like(logits)
@@ -303,10 +394,15 @@ def backward(ctx, *args):
             beta,
             eps_low,
             eps_high,
+            loss_type_int,
+            sapo_temperature_pos,
+            sapo_temperature_neg,
             *dloss.stride(),
             L,
             N,
             **kwargs,
         )
         dlogits[:, -1, :] = 0
-        return dlogits, None, None, None, None, None, None, None, None, None, None
+        # Return None for: old_logp, ref_logp, completion_ids, advantages, completion_mask,
+        # temperature, beta, eps_low, eps_high, inplace, loss_type, sapo_temperature_pos, sapo_temperature_neg
+        return dlogits, None, None, None, None, None, None, None, None, None, None, None, None, None
```
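The SAPO backward comment derives `d/dρ[sigmoid(τ*(ρ-1)) * 4/τ] = 4 * sigmoid * (1 - sigmoid)`, with the τ factors cancelling. A quick standalone autograd check of that identity (independent of the kernel):

```python
import torch

tau = 1.05
rho = torch.tensor(1.3, requires_grad=True)
gate = torch.sigmoid(tau * (rho - 1.0)) * 4.0 / tau
gate.backward()

s = torch.sigmoid(torch.tensor(tau * (1.3 - 1.0)))
manual = 4.0 * s * (1.0 - s)  # the tau factors cancel out
assert torch.allclose(rho.grad, manual)
```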

src/liger_kernel/transformers/grpo_loss.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -20,6 +20,8 @@ def triton_grpo_loss(
     max_completion_length=None,
     importance_sampling_level="token",
     reduce=False,
+    sapo_temperature_pos=1.0,
+    sapo_temperature_neg=1.05,
 ):
     assert logits is not None and completion_ids is not None and advantages is not None, (
         "must provide logits, completion_ids and advantages"
@@ -28,10 +30,6 @@ def triton_grpo_loss(
         raise ValueError(
             f"Triton GRPO loss only supports token-level importance sampling. Got {importance_sampling_level}."
         )
-    if loss_type == "cispo":
-        raise ValueError("Triton GRPO loss does not support loss_type='cispo'. Use the chunked GRPO loss path.")
-    if loss_type == "sapo":
-        raise ValueError("Triton GRPO loss does not support loss_type='sapo'. Use the chunked GRPO loss path.")
 
     per_token_loss, per_token_kl, is_clipped = GrpoLossFunction.apply(
         logits,
@@ -45,6 +43,9 @@ def triton_grpo_loss(
         eps_low,
         eps_high,
         inplace,
+        loss_type,
+        sapo_temperature_pos,
+        sapo_temperature_neg,
     )
     if not reduce:
         return per_token_loss, per_token_kl, is_clipped
@@ -69,7 +70,8 @@ def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion
         mask = torch.ones_like(per_token_loss, dtype=per_token_loss.dtype, device=per_token_loss.device)
         mask = mask.to(per_token_loss.dtype)
 
-    if loss_type == "grpo":
+    if loss_type == "grpo" or loss_type == "sapo":
+        # SAPO uses the same normalization as GRPO (per-sequence average)
         per_seq = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
         return per_seq.mean()
     if loss_type == "bnpo":
@@ -79,7 +81,8 @@ def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion
             raise ValueError("max_completion_length must be provided when using loss_type='dr_grpo'")
         batch = per_token_loss.shape[0]
         return (per_token_loss * mask).sum() / (batch * max_completion_length)
-    if loss_type == "dapo":
+    if loss_type == "dapo" or loss_type == "cispo":
+        # CISPO uses the same normalization as DAPO
        normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask)
         return (per_token_loss * mask).sum() / normalizer
     raise ValueError(f"Unsupported loss_type '{loss_type}' for Triton GRPO loss.")
```
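With the blocking errors removed, the Triton path accepts the new loss types directly. A hypothetical call (shapes follow the kernel's `(B, L+1, N)` logits layout; all arguments not shown are assumed to keep their defaults):

```python
import torch
from liger_kernel.transformers.grpo_loss import triton_grpo_loss

B, L, V = 2, 16, 1024
logits = torch.randn(B, L + 1, V, device="cuda", requires_grad=True)
completion_ids = torch.randint(0, V, (B, L), device="cuda")
advantages = torch.randn(B, device="cuda")

per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
    logits=logits,
    completion_ids=completion_ids,
    advantages=advantages,
    loss_type="sapo",
    sapo_temperature_pos=1.0,
    sapo_temperature_neg=1.05,
)
```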
