[grpo] support GSPO-token & add GSPO script #5188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md
@@ -1,18 +1,56 @@
# Group Sequence Policy Optimization

[Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071) points out that GRPO computes its importance sampling weights at the token level. Because each token is sampled only once, this cannot provide an effective distribution correction; instead it injects high-variance noise into training, easily destabilizing the gradient estimates and ultimately causing training collapse. The paper therefore argues that the unit of the optimization objective should match the unit of the reward: since rewards are given at the sequence level (i.e., for the complete generated response), it is more reasonable to lift off-policy correction and optimization to the sequence level as well, rather than the token level. The three weighting strategies are compared below:

1. GRPO
GRPO computes the importance sampling ratio independently for each token:

$$
w^{\mathrm{GRPO}}_{i,t} = \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})}
$$

2. GSPO (Sequence-Level)

GSPO computes the importance sampling ratio at the sequence level:

$$
w^{\mathrm{GSPO}}_{i} = \left[ \frac{\pi_\theta (y_i \mid x)}{\pi_{\theta_{\mathrm{old}}} (y_i \mid x)} \right]^{\frac{1}{|y_i|}}
= \exp\left( \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})} \right)
$$
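As a quick numeric sanity check of the identity above (illustration only; the log-probabilities below are random stand-ins, not model outputs), the length-normalized product of token ratios and the exponential of the mean log-ratio agree:

```python
import torch

torch.manual_seed(0)
T = 5                                    # |y_i|: number of tokens in the response
logp_new = torch.randn(T) - 2.0          # stand-in for log pi_theta(y_t | x, y_<t)
logp_old = torch.randn(T) - 2.0          # stand-in for log pi_theta_old(y_t | x, y_<t)

# Left: product of token-level ratios, raised to the power 1/|y_i|
lhs = torch.prod(torch.exp(logp_new - logp_old)) ** (1.0 / T)
# Right: exponential of the mean log-ratio
rhs = torch.exp((logp_new - logp_old).mean())

assert torch.allclose(lhs, rhs)
```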

3. GSPO-token
GSPO-token combines the sequence-level and token-level importance sampling ideas:

$$
w_{i, t}^{\mathrm{GSPO-token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]}
$$

where $\mathrm{sg}[\cdot]$ denotes stop-gradient (detach()).

> Note: According to the gradient derivation (Eqs. (11) and (18) in the paper), GSPO-token is equivalent to GSPO when all tokens of a response share the same advantage. In the current GRPO implementation, every token's advantage is derived from the sequence-level reward and normalized within the group, so under this setting GSPO-token and GSPO are theoretically equivalent; GSPO-token, however, leaves room for future fine-grained (token-level) advantages.

Pseudo-code implementation:
```python
# per_token_logps, old_per_token_logps: (B, T) log-probs over completion tokens
# mask: (B, T) completion mask (1 for real tokens, 0 for padding)
log_ratio = per_token_logps - old_per_token_logps

# GRPO: one weight per token, shape (B, T)
log_importance_weights = log_ratio

# GSPO (Sequence-Level): one length-normalized weight per sequence, shape (B, 1)
seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
log_importance_weights = seq_weight.unsqueeze(-1)  # (B, 1)

# GSPO-token: sequence-level value with a token-level gradient path, shape (B, T)
seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
log_importance_weights = seq_weight.detach().unsqueeze(-1) + (per_token_logps - per_token_logps.detach())

importance_weights = torch.exp(log_importance_weights)
```

On top of GRPO training, the algorithm can be selected via the `--importance_sampling_level` argument:

- `importance_sampling_level token` (default, GRPO implementation)
- `importance_sampling_level sequence` (GSPO)
- `importance_sampling_level sequence_token` (GSPO-token)


For training, refer to [this script](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/gspo.sh).
46 changes: 42 additions & 4 deletions docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md
@@ -1,18 +1,56 @@
# Group Sequence Policy Optimization

[Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071) points out that GRPO computes importance sampling weights at the token level. This is problematic: since each token is sampled only once, the weights cannot achieve effective distribution correction and instead introduce high-variance noise during training, which easily destabilizes gradient estimates and can lead to training collapse. The paper therefore argues that the unit of the objective should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level.

Below are the three main strategies for computing importance sampling weights:

1. GRPO
GRPO computes the importance sampling ratio independently for each token, as follows:

$$
w^{\mathrm{GRPO}}_{i,t} = \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})}
$$

2. GSPO (Sequence-Level)
GSPO calculates the importance sampling ratio at the sequence level, given by:

$$
w^{\mathrm{GSPO}}_{i} = \left[ \frac{\pi_\theta (y_i \mid x)}{\pi_{\theta_{\mathrm{old}}} (y_i \mid x)} \right]^{\frac{1}{|y_i|}}
= \exp\left( \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})} \right)
$$

3. GSPO-token
GSPO-token combines both sequence-level and token-level importance sampling:

$$
w_{i, t}^{\mathrm{GSPO-token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]}
$$

where $\mathrm{sg}[\cdot]$ denotes stop-gradient (detach()).

> **NOTE:** According to gradient analysis (i.e., Eqs. (11) and (18) in the paper), when the advantage for each token is identical, GSPO-token is equivalent to GSPO. In the current implementation of GRPO, all token advantages are normalized based on the sentence-level reward within each group. Therefore, in this setting, GSPO-token and GSPO are theoretically equivalent. However, GSPO-token provides support for future fine-grained (token-level) advantages.
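To make the role of the stop-gradient concrete, here is a small self-contained sketch (names and numbers are made up for illustration): the GSPO-token weight matches the sequence-level weight in value, while its gradient flows through each token's own log-probability.

```python
import torch

# Treat the per-token log-probs as leaf tensors for this illustration
per_token_logps = torch.tensor([-1.2, -0.7, -2.1], requires_grad=True)
old_per_token_logps = torch.tensor([-1.0, -0.9, -2.0])

log_ratio = per_token_logps - old_per_token_logps
seq_weight = torch.exp(log_ratio.mean())                # w_i^GSPO (scalar)

# w_{i,t}^{GSPO-token} = sg[w_i^GSPO] * pi_theta(y_t) / sg[pi_theta(y_t)]
w_token = seq_weight.detach() * torch.exp(per_token_logps - per_token_logps.detach())

print(torch.allclose(w_token, seq_weight.expand_as(w_token)))  # True: same forward value
w_token.sum().backward()
print(per_token_logps.grad)  # each entry equals w_i^GSPO: a per-token gradient path
```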

Pseudo-code implementation:
```python
# per_token_logps, old_per_token_logps: (B, T) log-probs over completion tokens
# mask: (B, T) completion mask (1 for real tokens, 0 for padding)
log_ratio = per_token_logps - old_per_token_logps

# GRPO: one weight per token, shape (B, T)
log_importance_weights = log_ratio

# GSPO (Sequence-Level): one length-normalized weight per sequence, shape (B, 1)
seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
log_importance_weights = seq_weight.unsqueeze(-1)  # (B, 1)

# GSPO-token: sequence-level value with a token-level gradient path, shape (B, T)
seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
log_importance_weights = seq_weight.detach().unsqueeze(-1) + (per_token_logps - per_token_logps.detach())

importance_weights = torch.exp(log_importance_weights)
```
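For readers who want to execute the pseudo-code, below is a self-contained toy version with dummy tensors; shapes and values are made up for illustration, whereas in the trainer these tensors come from the rollout and the current model.

```python
import torch

torch.manual_seed(0)
B, T = 2, 4                                       # 2 completions, up to 4 tokens each
per_token_logps = torch.randn(B, T)               # log pi_theta per completion token
old_per_token_logps = torch.randn(B, T)           # log pi_theta_old per completion token
mask = torch.tensor([[1., 1., 1., 1.],
                     [1., 1., 0., 0.]])           # second completion has only 2 real tokens

log_ratio = per_token_logps - old_per_token_logps

# GRPO: one weight per token
grpo_w = torch.exp(log_ratio)                                         # (B, T)

# GSPO: one length-normalized weight per sequence
seq_log_w = (log_ratio * mask).sum(-1) / mask.sum(-1)
gspo_w = torch.exp(seq_log_w.unsqueeze(-1))                           # (B, 1)

# GSPO-token: same value as GSPO, but with a per-token gradient path
gspo_token_log_w = seq_log_w.detach().unsqueeze(-1) + (per_token_logps - per_token_logps.detach())
gspo_token_w = torch.exp(gspo_token_log_w)                            # (B, T)

print(grpo_w.shape, gspo_w.shape, gspo_token_w.shape)
print(torch.allclose(gspo_token_w, gspo_w.expand_as(gspo_token_w)))   # True
```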

Based on GRPO training, you can select different algorithms via the `--importance_sampling_level` argument:

- `importance_sampling_level token` (default, GRPO implementation)
- `importance_sampling_level sequence` (GSPO)
- `importance_sampling_level sequence_token` (GSPO-token)

For training, you can refer to [this script](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/gspo.sh).
37 changes: 37 additions & 0 deletions examples/train/grpo/internal/gspo.sh
@@ -0,0 +1,37 @@
# 8*80G GPU
# GSPO https://arxiv.org/pdf/2507.18071
# hyperparameters
# - epsilon = 3e-4 from paper section 5.1
# - epsilon_high = 4e-4 from paper section 5.1
# - steps_per_generation = 4 from paper section 5.1 (each batch of rollout data is partitioned into four minibatches for gradient updates)
# - beta = 0: no KL regularization https://github.com/volcengine/verl/pull/2775#issuecomment-3131807306

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-7B-Instruct \
--dataset AI-MO/NuminaMath-TIR#10000 \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--num_generations 16 \
--train_type full \
--reward_funcs accuracy \
--use_vllm true \
--vllm_mode colocate \
--vllm_gpu_memory_utilization 0.6 \
--vllm_max_model_len 16384 \
--max_completion_length 8192 \
--offload_optimizer true \
--offload_model true \
--sleep_level 1 \
--save_steps 1000 \
--learning_rate 1e-6 \
--save_total_limit 2 \
--logging_steps 5 \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--deepspeed zero3 \
--log_completions true \
--importance_sampling_level sequence \
--epsilon 3e-4 \
--epsilon_high 4e-4 \
--steps_per_generation 4 \
--beta 0
5 changes: 5 additions & 0 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -1292,6 +1292,11 @@ def _compute_loss(self, model, inputs):
elif self.importance_sampling_level == 'sequence':
log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
log_importance_weights = log_importance_weights.unsqueeze(-1)
elif self.importance_sampling_level == 'sequence_token':
# GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
seq_level_log_weight = seq_level_log_weight.unsqueeze(-1).detach() # Stop gradient
log_importance_weights = seq_level_log_weight + per_token_logps - per_token_logps.detach()
else:
raise ValueError(
f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "