
Commit 83cdcf8

yukiu00 and Tcc0403 authored
Add CISPO loss type support for LigerFusedLinearGRPOLoss (#1054)
## Summary

Resolve: #1057

* Add **CISPO** (`loss_type="cispo"`) support to **`LigerFusedLinearGRPOLoss`** (chunked loss path)
* Enable TRL's **`GRPOTrainer`** to work with `use_liger_kernel=True` and `loss_type="cispo"`

### Background / Motivation

CISPO (Clipped Importance Sampling Policy Optimization) is a loss variant proposed in the **MiniMax-M1** technical report. It clips the importance sampling ratio with **only an upper bound** and **detaches it from gradient computation**. TRL added `loss_type="cispo"` to `GRPOTrainer`, but Liger Kernel did not support it, causing errors when using `use_liger_kernel=True` with `loss_type="cispo"`.

### Changes

**`src/liger_kernel/chunked_loss/grpo_loss.py`**

* Add CISPO loss matching TRL's implementation
* Clip importance sampling ratio with **upper bound only** and **detach**:

  ```python
  clamped_ratios = torch.clamp(coef_1, max=epsilon_high).detach()
  ```

* Use **DAPO-style normalization** for CISPO reduction (consistent with TRL)
* Add CISPO-specific clip metric for logging compatibility:
  * Count tokens where `(coef_1 > epsilon_high) & (advantages > 0)`

**`src/liger_kernel/transformers/grpo_loss.py`**

* Add CISPO reduction logic (uses same normalizer as DAPO)
* Raise explicit error for Triton GRPO loss path (CISPO not supported there)

**`ops/grpo_loss` (Triton fused path)**

* CISPO is **not implemented** in `ops/grpo_loss` in this PR
* `loss_type="cispo"` is **only supported via the chunked loss path** (Triton fused support is a follow-up)

**`test/chunked_loss/test_grpo_loss.py`**

* Add CISPO to torch reference implementation (`TorchLMHeadGRPO`)
* Add `"cispo"` to parameterized test cases to verify parity with reference

### References

* MiniMax-M1 (CISPO introduction): https://arxiv.org/abs/2506.13585
* DAPO (normalization / reduction reference): https://arxiv.org/abs/2503.14476
* TRL CISPO implementation: https://github.com/huggingface/trl/blob/035c3ff151b953ca72cdfe0ee966bc1469a26fde/trl/trainer/grpo_trainer.py#L2030

## Testing Done

- Hardware Type: RTX3090 24GB (NVIDIA Ampere)
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Co-authored-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
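For orientation, a minimal standalone PyTorch sketch of the CISPO objective described above (upper-bound-only clipping, detached ratio, DAPO-style token normalization); the helper names are illustrative and this is not the Liger code path itself:

```python
import torch

def cispo_per_token_loss(per_token_logps, old_per_token_logps, advantages, epsilon_high):
    # Token-level importance sampling ratio between current and old policy.
    ratio = torch.exp(per_token_logps - old_per_token_logps)            # (B, T)
    # CISPO: clip with an upper bound only, then detach from the graph.
    clipped_ratio = torch.clamp(ratio, max=epsilon_high).detach()       # (B, T)
    # Gradient flows only through per_token_logps (REINFORCE-style weighting).
    return -clipped_ratio * advantages.unsqueeze(1) * per_token_logps   # (B, T)

def reduce_dapo(per_token_loss, completion_mask):
    # DAPO-style reduction, reused for CISPO: masked sum over tokens,
    # divided by the (clamped) total number of completion tokens.
    mask = completion_mask.float()
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
```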
1 parent 81f932a commit 83cdcf8

File tree

4 files changed: +73 -31 lines changed


src/liger_kernel/chunked_loss/fused_linear_ppo.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -60,7 +60,7 @@ def forward(
         epsilon_low: Lower bound for clipping the importance sampling ratio
         epsilon_high: Upper bound for clipping the importance sampling ratio
         beta: Weight for the KL penalty
-        loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
+        loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo")
         max_completion_length: Maximum completion length required for "dr_grpo"
         temperature: Temperature for the logits
         compiled: Whether to use torch compile
```
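The docstring above now advertises "cispo" alongside the existing loss types. From the user side, the switch happens in TRL's config, as the summary notes; a hedged sketch, assuming TRL's `GRPOConfig` exposes `loss_type`, `epsilon_high`, and `use_liger_kernel`, with 5.0 as the illustrative upper bound mentioned in the grpo_loss.py docstring further down:

```python
from trl import GRPOConfig

# Assumed TRL API: GRPOConfig carries loss_type and epsilon_high, plus the
# use_liger_kernel flag inherited from TrainingArguments.
config = GRPOConfig(
    output_dir="grpo-cispo-out",
    loss_type="cispo",       # routed to the Liger chunked GRPO loss path
    epsilon_high=5.0,        # CISPO typically uses a larger upper bound (illustrative)
    use_liger_kernel=True,   # enables LigerFusedLinearGRPOLoss
)

# Training then proceeds as usual, e.g.
# GRPOTrainer(model=..., reward_funcs=..., args=config, train_dataset=...).train()
```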

src/liger_kernel/chunked_loss/grpo_loss.py

Lines changed: 35 additions & 15 deletions
```diff
@@ -11,8 +11,21 @@ def k3_loss_fn(log_p, log_q):
     return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0


-def clip_coef_fn(coef, epsilon_low, epsilon_high):
-    return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
+def clip_coef_fn(coef, epsilon_low, epsilon_high, loss_type):
+    if loss_type == "cispo":
+        # CISPO: clip and detach the importance weights
+        upper_bound = epsilon_high
+        lower_bound = None
+        clipped_coef = torch.clamp(coef, lower_bound, upper_bound).detach()
+        is_lower_clipped = False
+        is_upper_clipped = coef > upper_bound
+    else:
+        upper_bound = 1 + epsilon_high
+        lower_bound = 1 - epsilon_low
+        clipped_coef = torch.clamp(coef, lower_bound, upper_bound)
+        is_lower_clipped = coef < lower_bound
+        is_upper_clipped = coef > upper_bound
+    return clipped_coef, is_lower_clipped, is_upper_clipped


 class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
@@ -29,7 +42,7 @@ def ppo_loss_fn(
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
+        loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo", "cispo"]
         max_completion_length=None,  # Required for dr_grpo
         importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
@@ -67,10 +80,15 @@ def ppo_loss_fn(
         # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
         # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
         coef_1 = torch.exp(log_importance_weights)
-        coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
-        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
-        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
-        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        coef_2, is_lower_clipped, is_upper_clipped = clip_coef_fn(coef_1, epsilon_low, epsilon_high, loss_type)
+        if loss_type == "cispo":
+            # CISPO: clip and detach the importance weights, multiply by log probs
+            # Reference: https://github.com/huggingface/trl/blob/035c3ff151b953ca72cdfe0ee966bc1469a26fde/trl/trainer/grpo_trainer.py#L2030
+            per_token_loss = -coef_2 * advantages.unsqueeze(1) * per_token_logps
+        else:
+            per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+            per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if beta != 0.0:
             # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
             kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
@@ -94,7 +112,7 @@ def ppo_loss_fn(
             if max_completion_length is None:
                 raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
             loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
-        elif loss_type == "dapo":
+        elif loss_type == "dapo" or loss_type == "cispo":
             loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
             loss = (per_token_loss * attention_mask).sum() / loss_normalizer
         else:
@@ -107,15 +125,15 @@ def ppo_loss_fn(

         # Adjust clipping metric calculation based on importance sampling level
         if importance_sampling_level == "token":
-            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            is_clipped = (is_lower_clipped & (advantages.unsqueeze(1) < 0)) | (
+                is_upper_clipped & (advantages.unsqueeze(1) > 0)
             )
         else:  # sequence level
             # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
-            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
-                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            is_clipped = (is_lower_clipped & (advantages.unsqueeze(1) < 0)) | (
+                is_upper_clipped & (advantages.unsqueeze(1) > 0)
             )
-            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+            is_clipped = is_clipped.expand_as(attention_mask)

         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics
@@ -160,7 +178,7 @@ def forward(
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
@@ -251,7 +269,9 @@ def __init__(
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo", "cispo").
+                Defaults to "dapo". For "cispo", epsilon_high is typically larger (e.g. 5.0) and
+                epsilon_low is unused.
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
```
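To make the new `clip_coef_fn` contract concrete, a small standalone check written against the helper as it appears in this diff; the input values and the `epsilon_high=5.0` bound are illustrative:

```python
import torch

def clip_coef_fn(coef, epsilon_low, epsilon_high, loss_type):
    # Re-statement of the helper introduced in this diff, for illustration only.
    if loss_type == "cispo":
        clipped_coef = torch.clamp(coef, None, epsilon_high).detach()
        is_lower_clipped = False
        is_upper_clipped = coef > epsilon_high
    else:
        clipped_coef = torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
        is_lower_clipped = coef < 1 - epsilon_low
        is_upper_clipped = coef > 1 + epsilon_high
    return clipped_coef, is_lower_clipped, is_upper_clipped

coef = torch.tensor([0.5, 1.0, 3.0, 7.0], requires_grad=True)

# Default (PPO-style) clipping: two-sided band [1 - eps_low, 1 + eps_high].
c, lo, hi = clip_coef_fn(coef, 0.2, 0.2, "grpo")
# c == [0.8, 1.0, 1.2, 1.2] and still carries a grad_fn;
# lo == [True, False, False, False], hi == [False, False, True, True]

# CISPO: upper bound only (here 5.0), detached from the autograd graph.
c, lo, hi = clip_coef_fn(coef, 0.2, 5.0, "cispo")
# c == [0.5, 1.0, 3.0, 5.0] with no grad_fn (detached);
# lo is the constant False, hi == [False, False, False, True]
```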

src/liger_kernel/transformers/grpo_loss.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -22,12 +22,14 @@ def triton_grpo_loss(
     reduce=False,
 ):
     assert logits is not None and completion_ids is not None and advantages is not None, (
-        "must provide logitscompletion_ids and advantages"
+        "must provide logits, completion_ids and advantages"
     )
     if importance_sampling_level != "token":
         raise ValueError(
             f"Triton GRPO loss only supports token-level importance sampling. Got {importance_sampling_level}."
         )
+    if loss_type == "cispo":
+        raise ValueError("Triton GRPO loss does not support loss_type='cispo'. Use the chunked GRPO loss path.")

     per_token_loss, per_token_kl, is_clipped = GrpoLossFunction.apply(
         logits,
```
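The new guard fails fast instead of silently computing an unsupported loss. A hedged, test-style sketch of the expected behavior; the tensor shapes are illustrative and the keyword arguments are assumed from the signature shown above:

```python
import pytest
import torch

from liger_kernel.transformers.grpo_loss import triton_grpo_loss


def test_triton_grpo_loss_rejects_cispo():
    # Dummy CPU tensors are enough here: the ValueError is raised before any
    # Triton kernel is launched, so no GPU is required for this check.
    logits = torch.randn(2, 8, 32)
    completion_ids = torch.randint(0, 32, (2, 8))
    advantages = torch.randn(2)

    with pytest.raises(ValueError, match="does not support loss_type='cispo'"):
        triton_grpo_loss(
            logits=logits,
            completion_ids=completion_ids,
            advantages=advantages,
            loss_type="cispo",
        )
```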

test/chunked_loss/test_grpo_loss.py

Lines changed: 34 additions & 14 deletions
```diff
@@ -58,6 +58,7 @@ def compute_per_token_components(
         epsilon_high,
         beta,
         importance_sampling_level,
+        loss_type: str = "grpo",
     ):
         attention_mask = attention_mask.to(per_token_logps.dtype)
         old_per_token_logps = (
@@ -77,28 +78,43 @@ def compute_per_token_components(
         )

         coef_1 = torch.exp(log_importance_weights)
-        coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)
         expanded_advantages = advantages.unsqueeze(1)
-        per_token_loss1 = coef_1 * expanded_advantages
-        per_token_loss2 = coef_2 * expanded_advantages
-        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        # Compute clipped coefficients and clipping flags
+        if loss_type == "cispo":
+            # CISPO: clip and detach the importance weights
+            upper_bound = epsilon_high
+            lower_bound = None
+            coef_2 = torch.clamp(coef_1, lower_bound, upper_bound).detach()
+            is_lower_clipped = False
+            is_upper_clipped = coef_1 > upper_bound
+        else:
+            upper_bound = 1 + epsilon_high
+            lower_bound = 1 - epsilon_low
+            coef_2 = torch.clamp(coef_1, lower_bound, upper_bound)
+            is_lower_clipped = coef_1 < lower_bound
+            is_upper_clipped = coef_1 > upper_bound
+
+        if loss_type == "cispo":
+            # CISPO: clip and detach the importance weights, multiply by log probs
+            # Reference: https://github.com/huggingface/trl/blob/035c3ff151b953ca72cdfe0ee966bc1469a26fde/trl/trainer/grpo_trainer.py#L2030
+            per_token_loss = -coef_2 * expanded_advantages * per_token_logps
+        else:
+            per_token_loss1 = coef_1 * expanded_advantages
+            per_token_loss2 = coef_2 * expanded_advantages
+            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         kl_div = None
         if beta != 0.0:
             ref_per_token_logps = ref_per_token_logps.float()
             kl_div = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
             per_token_loss = per_token_loss + beta * kl_div

+        # Adjust clipping metric calculation based on importance sampling level
         if importance_sampling_level == "token":
-            is_clipped = ((coef_1 < 1 - epsilon_low) & (expanded_advantages < 0)) | (
-                (coef_1 > 1 + epsilon_high) & (expanded_advantages > 0)
-            )
+            is_clipped = (is_lower_clipped & (expanded_advantages < 0)) | (is_upper_clipped & (expanded_advantages > 0))
         else:  # sequence level
             # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
-            seq_advantages = advantages
-            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (seq_advantages < 0)) | (
-                (coef_1.squeeze(-1) > 1 + epsilon_high) & (seq_advantages > 0)
-            )
-            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+            is_clipped = (is_lower_clipped & (expanded_advantages < 0)) | (is_upper_clipped & (expanded_advantages > 0))
+            is_clipped = is_clipped.expand_as(attention_mask)
         return per_token_loss, kl_div, is_clipped

     def forward(
@@ -148,6 +164,7 @@ def forward(
             self.epsilon_high,
             self.beta,
             self.importance_sampling_level,
+            self.loss_type,
         )

         # Apply masking and calculate loss based on loss_type
@@ -160,6 +177,9 @@ def forward(
         elif self.loss_type == "dapo":
             normalizer = attention_mask.sum().clamp(min=1.0)
             loss = (per_token_loss * attention_mask).sum() / normalizer
+        elif self.loss_type == "cispo":
+            normalizer = attention_mask.sum().clamp(min=1.0)
+            loss = (per_token_loss * attention_mask).sum() / normalizer
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}")

@@ -259,7 +279,7 @@ def forward(
         (False, False, True),
     ],
 )
-@pytest.mark.parametrize("loss_type", ["bnpo", "grpo", "dr_grpo", "dapo"])
+@pytest.mark.parametrize("loss_type", ["bnpo", "grpo", "dr_grpo", "dapo", "cispo"])
 @pytest.mark.parametrize("importance_sampling_level", ["token", "sequence"])
 def test_correctness(
     B,
@@ -565,7 +585,7 @@ def test_reduce_grpo_loss_matches_reference(loss_type):
         expected = (per_token_loss * mask_f).sum() / mask_f.sum().clamp(min=1.0)
     elif loss_type == "dr_grpo":
         expected = (per_token_loss * mask_f).sum() / (per_token_loss.size(0) * max_completion_length)
-    else:  # dapo
+    else:  # dapo/cispo
         expected = (per_token_loss * mask_f).sum() / mask_f.sum().clamp(min=1.0)

     assert_verbose_allclose(reduced, expected)
```
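As the reference in `test_reduce_grpo_loss_matches_reference` spells out, "cispo" reuses the DAPO reduction while "dr_grpo" normalizes differently. A small sketch of the two normalizers (shapes and values are illustrative, not the test fixture itself):

```python
import torch

torch.manual_seed(0)
B, T, max_completion_length = 4, 16, 32
per_token_loss = torch.randn(B, T)                 # (B, T) per-token losses
mask_f = (torch.rand(B, T) > 0.25).float()         # (B, T) completion mask

# dapo / cispo: normalize by the (clamped) number of unmasked tokens.
dapo_cispo = (per_token_loss * mask_f).sum() / mask_f.sum().clamp(min=1.0)

# dr_grpo: normalize by batch size * max_completion_length instead.
dr_grpo = (per_token_loss * mask_f).sum() / (per_token_loss.size(0) * max_completion_length)

print(dapo_cispo, dr_grpo)  # same numerator, different normalizers
```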
