
Commit 1a7c3a9

[grpo] support GSPO-token & add GSPO script (#5188)
* upload script
* support gspo-token
* fix script and argument
* docs
* update log weight
* update docs
1 parent ce426e1 commit 1a7c3a9

File tree

5 files changed: +134, -9 lines


docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md

Lines changed: 43 additions & 4 deletions
@@ -2,22 +2,59 @@
 
 **Version requirement**: ms-swift>=3.7
 
-[Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071) points out that GRPO computes its importance sampling weights at the token level. However, because each token is sampled only once, this cannot achieve effective distribution correction; instead it injects high-variance noise into training, easily destabilizing gradient estimates and ultimately causing training collapse. The paper therefore argues that the unit of the optimization objective should match the unit of the reward. Since rewards are usually given at the sequence level (i.e., for the complete generated response), it is more reasonable to lift off-policy correction and optimization to the sequence level rather than the token level.
+[Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071) points out that GRPO computes its importance sampling weights at the token level. However, because each token is sampled only once, this cannot achieve effective distribution correction; instead it injects high-variance noise into training, easily destabilizing gradient estimates and ultimately causing training collapse. The paper therefore argues that the unit of the optimization objective should match the unit of the reward. Since rewards are usually given at the sequence level (i.e., for the complete generated response), it is more reasonable to lift off-policy correction and optimization to the sequence level rather than the token level. The three computation strategies are compared below:
 
-In GRPO, the importance sampling ratio is computed at the token level, with the formula
+1. GRPO
+The importance sampling ratio is computed independently for each token:
 
 $$
 w^{\mathrm{GRPO}}_{i,t} = \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})}
 $$
 
-In GSPO, the importance sampling ratio is computed at the sequence level, with the formula
+2. GSPO (Sequence-Level)
+
+The importance sampling ratio is computed at the sequence level:
 
 $$
 w^{\mathrm{GSPO}}_{i} = \left[ \frac{\pi_\theta (y_i \mid x)}{\pi_{\theta_{\mathrm{old}}} (y_i \mid x)} \right]^{\frac{1}{|y_i|}}
 = \exp\left( \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})} \right)
 $$
 
-On top of GRPO training, the GSPO algorithm can be enabled with the parameter `--importance_sampling_level sequence`.
+3. GSPO-token
+GSPO-token combines the sequence-level and token-level importance sampling ideas:
+
+$$
+w_{i, t}^{\mathrm{GSPO-token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]}
+$$
+
+where $\mathrm{sg}[\cdot]$ denotes stop-gradient (detach()).
+
+> Note: According to the gradient derivation (Eqs. (11) and (18) in the paper), GSPO-token is equivalent to GSPO when all tokens share the same advantage. In the current GRPO implementation, every token's advantage is derived from the sequence-level reward and normalized within the group, so in this setting GSPO-token and GSPO are theoretically equivalent. GSPO-token, however, provides support for future finer-grained (token-level) advantages.
+
+Pseudocode implementation:
+```python
+log_ratio = per_token_logps - old_per_token_logps
+# GRPO
+log_importance_weights = log_ratio
+
+# GSPO (Sequence-Level)
+seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
+log_importance_weights = seq_weight.unsqueeze(-1) # (B,1)
+
+# GSPO-token
+seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
+log_importance_weights = seq_weight.detach().unsqueeze(-1) + (per_token_logps - per_token_logps.detach())
+
+importance_weights = torch.exp(log_importance_weights)
+```
+
+On top of GRPO training, the algorithm is selected via the `--importance_sampling_level` argument:
+
+- `importance_sampling_level token` (default, GRPO implementation)
+- `importance_sampling_level sequence` (GSPO)
+- `importance_sampling_level sequence_token` (GSPO-token)
+
+Note that `sequence_token` requires ms-swift > 3.7 (install from source).
 
 Other hyperparameters from the paper
 ```bash
@@ -26,3 +63,5 @@ $$
 --steps_per_generation 4 # from paper section 5.1 (each batch of rollout data is partitioned into four minibatches for gradient updates)
 --beta 0 # zero kl regularization https://github.com/volcengine/verl/pull/2775#issuecomment-3131807306
 ```
+
+For training, refer to [this script](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/gspo.sh).
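
For readers who want to experiment with the three weighting schemes outside the trainer, here is a minimal, self-contained PyTorch sketch that expands the pseudocode above; the toy shapes, seed, and tensor values are illustrative and not taken from the ms-swift internals.

```python
# Toy expansion of the pseudocode above; shapes and values are illustrative.
import torch

B, T = 2, 5  # batch of 2 completions, 5 generated tokens each
torch.manual_seed(0)
per_token_logps = torch.randn(B, T, requires_grad=True)                      # current policy log-probs
old_per_token_logps = per_token_logps.detach() + 0.01 * torch.randn(B, T)    # old policy log-probs
mask = torch.ones(B, T)                                                      # completion mask

log_ratio = per_token_logps - old_per_token_logps

# GRPO: one importance ratio per token, shape (B, T)
grpo_log_w = log_ratio

# GSPO (sequence-level): length-normalized sum of log-ratios, shape (B, 1)
seq_log_w = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
gspo_log_w = seq_log_w.unsqueeze(-1)

# GSPO-token: sequence-level value in the forward pass, per-token gradient path
gspo_token_log_w = seq_log_w.detach().unsqueeze(-1) + (per_token_logps - per_token_logps.detach())

for name, log_w in [("GRPO", grpo_log_w), ("GSPO", gspo_log_w), ("GSPO-token", gspo_token_log_w)]:
    print(name, torch.exp(log_w).detach())

# Forward values of GSPO and GSPO-token coincide (the per-token term is numerically zero),
# but their gradients differ: GSPO-token back-propagates through each token's log-prob.
```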

docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md

Lines changed: 43 additions & 4 deletions
@@ -2,22 +2,59 @@
 
 **Version Requirement**: ms-swift>=3.7
 
-[Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071) points out that in GRPO, importance sampling weights are computed at the token level. However, this approach samples only once per token, making it ineffective for proper distribution correction. Instead, it introduces high-variance noise into the training process, which can destabilize gradient estimation and ultimately cause model collapse. Therefore, the paper argues that the unit of optimization should match the unit of reward. Since rewards are typically assigned at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level, rather than at the token level.
+In [Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071), it is pointed out that GRPO computes importance sampling weights at the token level. However, this approach is problematic: since each token is only sampled once, it cannot realize effective distribution correction, and instead introduces high-variance noise during training, which can easily lead to unstable gradient estimates and even training collapse. Therefore, the paper argues that the unit of the objective function should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level.
 
-In GRPO, the importance sampling ratio is computed at the token level as follows:
+Below are the three main strategies for computing importance sampling weights:
+
+1. GRPO
+GRPO computes the importance sampling ratio independently for each token, as follows:
 
 $$
 w^{\mathrm{GRPO}}_{i,t} = \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})}
 $$
 
-In GSPO, the importance sampling ratio is calculated at the sequence level as:
+2. GSPO (Sequence-Level)
+GSPO calculates the importance sampling ratio at the sequence level, given by:
 
 $$
 w^{\mathrm{GSPO}}_{i} = \left[ \frac{\pi_\theta (y_i \mid x)}{\pi_{\theta_{\mathrm{old}}} (y_i \mid x)} \right]^{\frac{1}{|y_i|}}
 = \exp\left( \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})} \right)
 $$
 
-Based on GRPO training, we can use the parameter `--importance_sampling_level sequence` to apply the GSPO algorithm.
+3. GSPO-token
+GSPO-token combines both sequence-level and token-level importance sampling:
+
+$$
+w_{i, t}^{\mathrm{GSPO-token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]}
+$$
+
+where $\mathrm{sg}[\cdot]$ denotes stop-gradient (detach()).
+
+> **NOTE:** According to gradient analysis (i.e., Eqs. (11) and (18) in the paper), when the advantage for each token is identical, GSPO-token is equivalent to GSPO. In the current implementation of GRPO, all token advantages are normalized based on the sequence-level reward within each group. Therefore, in this setting, GSPO-token and GSPO are theoretically equivalent. However, GSPO-token provides support for future fine-grained (token-level) advantages.
+
+Pseudo-code implementation:
+```python
+log_ratio = per_token_logps - old_per_token_logps
+# GRPO
+log_importance_weights = log_ratio
+
+# GSPO (Sequence-Level)
+seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
+log_importance_weights = seq_weight.unsqueeze(-1) # (B,1)
+
+# GSPO-token
+seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
+log_importance_weights = seq_weight.detach().unsqueeze(-1) + (per_token_logps - per_token_logps.detach())
+
+importance_weights = torch.exp(log_importance_weights)
+```
+
+Based on GRPO training, you can select different algorithms via the `--importance_sampling_level` argument:
+
+- `importance_sampling_level token` (default, GRPO implementation)
+- `importance_sampling_level sequence` (GSPO)
+- `importance_sampling_level sequence_token` (GSPO-token)
+
 
 Other hyperparameters in the paper
 ```bash
@@ -26,3 +63,5 @@ Other hyperparameters in the paper
 --steps_per_generation 4 # from paper section 5.1 (each batch of rollout data is partitioned into four minibatches for gradient updates)
 --beta 0 # zero kl regularization https://github.com/volcengine/verl/pull/2775#issuecomment-3131807306
 ```
+
+For training, you can refer to [this script](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/gspo.sh).
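
As a sketch of why the equivalence noted above holds (our own derivation, not quoted from the paper): because the denominator in the GSPO-token weight is detached and numerically equal to the numerator, differentiating the two weights gives

$$
\nabla_\theta\, w_{i, t}^{\mathrm{GSPO-token}}
= \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right]\, \nabla_\theta \log \pi_\theta (y_{i, t} \mid x, y_{i, <t}),
\qquad
\nabla_\theta\, w_i^{\mathrm{GSPO}}
= w_i^{\mathrm{GSPO}} \cdot \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \nabla_\theta \log \pi_\theta (y_{i, t} \mid x, y_{i, <t})
$$

so when every token of sequence $i$ shares the same advantage, summing the GSPO-token terms over $t$ recovers the GSPO gradient direction, with the $\frac{1}{|y_i|}$ length normalization as the only difference.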

examples/train/grpo/internal/gspo.sh

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# 8*80G GPU
+# GSPO https://arxiv.org/pdf/2507.18071
+# hyperparameters
+# - epsilon = 3e-4 from paper section 5.1
+# - epsilon_high = 4e-4 from paper section 5.1
+# - steps_per_generation = 4 from paper section 5.1 (each batch of rollout data is partitioned into four minibatches for gradient updates)
+# - beta = 0: zero kl regularization https://github.com/volcengine/verl/pull/2775#issuecomment-3131807306
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+NPROC_PER_NODE=8 \
+swift rlhf \
+--rlhf_type grpo \
+--model Qwen/Qwen2.5-7B-Instruct \
+--dataset AI-MO/NuminaMath-TIR#10000 \
+--torch_dtype bfloat16 \
+--beta 0.0 \
+--epsilon 3e-4 \
+--epsilon_high 4e-4 \
+--steps_per_generation 4 \
+--importance_sampling_level sequence \
+--num_train_epochs 1 \
+--per_device_train_batch_size 2 \
+--gradient_accumulation_steps 8 \
+--num_generations 16 \
+--train_type full \
+--reward_funcs accuracy \
+--use_vllm true \
+--vllm_mode colocate \
+--vllm_gpu_memory_utilization 0.6 \
+--vllm_max_model_len 16384 \
+--max_completion_length 8192 \
+--offload_optimizer true \
+--offload_model true \
+--sleep_level 1 \
+--save_steps 1000 \
+--learning_rate 1e-6 \
+--save_total_limit 2 \
+--logging_steps 5 \
+--warmup_ratio 0.05 \
+--dataloader_num_workers 4 \
+--deepspeed zero3 \
+--log_completions true
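
Note that this example script runs plain GSPO (`--importance_sampling_level sequence`); per the documentation above, switching the flag to `sequence_token` selects GSPO-token instead, which requires ms-swift > 3.7 installed from source.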

swift/trainers/arguments.py

Lines changed: 1 addition & 1 deletion
@@ -310,7 +310,7 @@ class GRPOArgumentsMixin(VllmArguments):
     top_entropy_quantile: float = 1.0
 
     # GSPO https://www.arxiv.org/abs/2507.18071
-    importance_sampling_level: Literal['token', 'sequence'] = 'token'
+    importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token'
 
     wandb_log_unique_prompts: Optional[bool] = None
     generation_batch_size: Optional[int] = None
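
For context, a minimal illustration (a toy class, not the real GRPOArgumentsMixin) of how a `Literal`-typed dataclass field like this behaves: static type checkers restrict the value to the listed strings, while at runtime the trainer still guards against unknown strings itself.

```python
# Toy illustration only; the real GRPOArgumentsMixin defines many more fields.
from dataclasses import dataclass
from typing import Literal


@dataclass
class ToyGRPOArgs:
    # Mirrors the new annotation: 'sequence_token' selects GSPO-token.
    importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token'


args = ToyGRPOArgs(importance_sampling_level='sequence_token')
assert args.importance_sampling_level == 'sequence_token'
# Literal is enforced by static type checkers (e.g. mypy), not at runtime;
# the trainer's else branch is what actually rejects unknown values.
```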

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 5 additions & 0 deletions
@@ -1300,6 +1300,11 @@ def _compute_loss(self, model, inputs):
         elif self.importance_sampling_level == 'sequence':
             log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
             log_importance_weights = log_importance_weights.unsqueeze(-1)
+        elif self.importance_sampling_level == 'sequence_token':
+            # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
+            seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+            seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1)  # Stop gradient
+            log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
         else:
             raise ValueError(
                 f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
