
Commit b19b749

[algo] feat: add cispo (verl-project#4508)
### What does this PR do?

This PR adds the CISPO (https://arxiv.org/pdf/2506.13585) algorithm and a corresponding recipe to verl.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent c1a7c9e commit b19b749

File tree: 2 files changed, +112 −0 lines changed
Lines changed: 51 additions & 0 deletions
```bash
set -x

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet

train_files="['$gsm8k_train_path']"
test_files="['$gsm8k_test_path']"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    actor_rollout_ref.actor.policy_loss.loss_mode=cispo \
    actor_rollout_ref.actor.clip_ratio_low=10 \
    actor_rollout_ref.actor.clip_ratio_high=0.2 \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=256 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
    actor_rollout_ref.model.torch_dtype=bfloat16 \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_cispo_example_gsm8k' \
    trainer.experiment_name='qwen2_5_0_5b_cispo' \
    trainer.n_gpus_per_node=1 \
    trainer.nnodes=1 \
    trainer.save_freq=5 \
    trainer.test_freq=5 \
    trainer.total_epochs=3 "$@"
```
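The recipe sets an asymmetric clip range: `clip_ratio_low=10` and `clip_ratio_high=0.2`. A minimal plain-Python sketch (the helper name `cispo_clipped_ratio` is hypothetical, not part of verl) shows the effect of those bounds on a single token's importance weight: the lower bound becomes 1 − 10 = −9, which no positive ratio can ever reach, so in practice only the upper clip at 1 + 0.2 = 1.2 is active.

```python
import math

# Hypothetical helper (illustration only, not verl API): per-token CISPO
# importance weight under the recipe's bounds. With clip_low=10 the lower
# clamp bound is -9, below any positive ratio, so only the upper bound
# at 1.2 ever fires.
def cispo_clipped_ratio(log_prob, old_log_prob, clip_low=10.0, clip_high=0.2):
    ratio = math.exp(log_prob - old_log_prob)  # pi_theta / pi_theta_old
    return min(max(ratio, 1.0 - clip_low), 1.0 + clip_high)

print(cispo_clipped_ratio(-1.0, -1.5))  # ratio exp(0.5) ~ 1.649, clipped to 1.2
print(cispo_clipped_ratio(-2.0, -1.5))  # ratio exp(-0.5) ~ 0.607, passes through
```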

verl/trainer/ppo/core_algos.py

Lines changed: 61 additions & 0 deletions
@@ -1395,6 +1395,67 @@ def compute_policy_loss_geo_mean(

```python
@register_policy_loss("cispo")
def compute_policy_loss_cispo(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the clipped policy objective and related metrics for CISPO.

    See https://arxiv.org/pdf/2506.13585 for more details.
    """
    assert config is not None
    assert isinstance(config, ActorConfig)
    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio
    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio

    # Compute importance sampling ratio: π_θ / π_θ_old
    negative_approx_kl = log_prob - old_log_prob
    # Clamp for numerical stability
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    # CISPO: clip the importance sampling weights.
    # KEY: apply a stop-gradient to the clipped ratio. This prevents gradients
    # from flowing through the ratio computation and clipping; gradients only
    # flow through log_prob in the final loss term.
    clipped_ratio = torch.clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
    clipped_ratio_sg = clipped_ratio.detach()

    # CISPO objective (to maximize): J = sg(clip(ratio)) * A * log π_θ
    # Loss (to minimize): L = -J = -sg(clip(ratio)) * A * log_prob
    pg_losses = -clipped_ratio_sg * advantages * log_prob

    # Track clipping statistics
    pg_clipfrac = verl_F.masked_mean((ratio != clipped_ratio).float(), response_mask)

    # Apply rollout importance sampling weights if provided
    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(
        loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info
    )

    # For compatibility, return zero for pg_clipfrac_lower (not used in CISPO)
    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)

    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }
    return pg_loss, pg_metrics
```
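To see why the stop-gradient matters, here is a toy numeric comparison in plain Python (the helper names are hypothetical illustrations, not verl code) of the analytic per-token gradient with respect to `log_prob` under CISPO versus a vanilla PPO-clip loss. For a positive-advantage token whose ratio exceeds the clip range, PPO's clipped branch has zero gradient, while CISPO keeps a bounded, nonzero learning signal.

```python
import math

def cispo_grad_wrt_logp(logp, old_logp, adv, clip_low=10.0, clip_high=0.2):
    # L = -sg(clip(ratio)) * adv * logp, so dL/dlogp = -clip(ratio) * adv:
    # the detached weight survives as a constant coefficient.
    ratio = math.exp(logp - old_logp)
    clipped = min(max(ratio, 1.0 - clip_low), 1.0 + clip_high)
    return -clipped * adv

def ppo_clip_grad_wrt_logp(logp, old_logp, adv, eps=0.2):
    # Vanilla PPO-clip loss: L = -min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv).
    # When the clipped branch is the minimum, dL/dlogp is exactly zero,
    # because the clipped constant no longer depends on logp.
    ratio = math.exp(logp - old_logp)
    if ratio * adv <= min(max(ratio, 1.0 - eps), 1.0 + eps) * adv:
        return -ratio * adv  # unclipped branch: d(ratio)/dlogp = ratio
    return 0.0

# Positive-advantage token with ratio exp(0.5) ~ 1.65, above the 1.2 clip:
print(ppo_clip_grad_wrt_logp(-1.0, -1.5, adv=1.0))  # 0.0  (gradient dropped)
print(cispo_grad_wrt_logp(-1.0, -1.5, adv=1.0))     # -1.2 (signal kept)
```

This is exactly the trade-off the CISPO paper targets: rare but important high-ratio tokens still contribute gradient, only with a bounded weight.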
