[grpo] support GSPO-token & add GSPO script #5188
Open
hjh0119 wants to merge 8 commits into modelscope:main from hjh0119:gspo-script (base: main)
+134 −9
Commits (8):
- 17959a3 upload script
- 30b5aa3 Merge remote-tracking branch 'origin' into gspo-script
- 9b54b00 support gspo-token
- 6751146 fix script and argument
- ff62fc6 docs
- 06928e7 update log weight
- de01a83 Merge remote-tracking branch 'origin' into gspo-script
- 515357a update docs
@@ -1,18 +1,56 @@
# Group Sequence Policy Optimization

[Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071) points out that GRPO computes its importance sampling weights at the token level. Because each token is sampled only once, this cannot achieve effective distribution correction; instead it introduces high-variance noise into training, which easily destabilizes the gradient estimates and can ultimately cause training collapse. The paper therefore argues that the unit of the optimization objective should be consistent with the unit of the reward: since rewards are usually given at the sequence level (i.e., for the complete generated response), it is more reasonable to lift the off-policy correction and optimization to the sequence level rather than the token level. The three weighting strategies are compared below:

1. GRPO

   The importance sampling ratio is computed independently for each token:

$$
w^{\mathrm{GRPO}}_{i,t} = \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})}
$$

2. GSPO (Sequence-Level)

   The importance sampling ratio is computed at the sequence level:

$$
w^{\mathrm{GSPO}}_{i} = \left[ \frac{\pi_\theta (y_i \mid x)}{\pi_{\theta_{\mathrm{old}}} (y_i \mid x)} \right]^{\frac{1}{|y_i|}}
= \exp\left( \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})} \right)
$$

3. GSPO-token

   GSPO-token combines the sequence-level and token-level importance sampling ideas:

$$
w_{i, t}^{\mathrm{GSPO-token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]}
$$

where $\mathrm{sg}[\cdot]$ denotes stop-gradient (detach()).

> Note: According to the gradient derivation (Eqs. (11) and (18) in the paper), GSPO-token is equivalent to GSPO when all tokens of a sequence share the same advantage. In the current GRPO implementation, every token's advantage is derived from the sequence-level reward and normalized within the group, so under this setting GSPO-token and GSPO are theoretically equivalent.

Pseudo-code implementation:
```python
import torch

# per_token_logps, old_per_token_logps, mask: tensors of shape (B, T)
log_ratio = per_token_logps - old_per_token_logps

# GRPO: per-token weights
log_importance_weights = log_ratio

# GSPO (Sequence-Level): length-normalized sequence weight, broadcast over tokens
seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
log_importance_weights = seq_weight.unsqueeze(-1)  # (B, 1)

# GSPO-token: detached sequence weight plus a per-token gradient path
seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
log_importance_weights = seq_weight.detach().unsqueeze(-1) + (per_token_logps - per_token_logps.detach())

importance_weights = torch.exp(log_importance_weights)
```
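
As a quick sanity check (a minimal sketch with made-up log-probabilities, not part of the PR), the sequence-level GSPO weight can be verified to equal the length-normalized geometric mean of the per-token GRPO ratios:

```python
import torch

# Toy tensors (hypothetical values): 2 sequences, 3 tokens each.
per_token_logps = torch.tensor([[-1.0, -0.5, -2.0],
                                [-0.8, -1.2, -0.3]])
old_per_token_logps = torch.tensor([[-1.1, -0.6, -1.8],
                                    [-0.7, -1.0, -0.5]])

token_ratio = torch.exp(per_token_logps - old_per_token_logps)             # GRPO weights, shape (2, 3)
geo_mean = token_ratio.prod(-1) ** (1.0 / token_ratio.shape[-1])           # geometric mean of per-token ratios
seq_weight = torch.exp((per_token_logps - old_per_token_logps).mean(-1))   # GSPO sequence weight

print(torch.allclose(geo_mean, seq_weight))  # True: the two forms coincide
```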

On top of GRPO training, different algorithms can be selected via the `--importance_sampling_level` argument:

- `importance_sampling_level token` (default, the GRPO implementation)
- `importance_sampling_level sequence` (GSPO)
- `importance_sampling_level sequence_token` (GSPO-token)

For training, refer to [this script](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/gspo.sh).
@@ -1,18 +1,56 @@
# Group Sequence Policy Optimization

In [Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071), it is pointed out that GRPO computes importance sampling weights at the token level. However, this approach is problematic: since each token is only sampled once, it cannot realize effective distribution correction, and instead introduces high-variance noise during training, which can easily lead to unstable gradient estimates and even training collapse. Therefore, the paper argues that the unit of the objective function should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level.

Below are the three main strategies for computing importance sampling weights:

1. GRPO

   GRPO computes the importance sampling ratio independently for each token, as follows:

$$
w^{\mathrm{GRPO}}_{i,t} = \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})}
$$

2. GSPO (Sequence-Level)

   GSPO calculates the importance sampling ratio at the sequence level, given by:

$$
w^{\mathrm{GSPO}}_{i} = \left[ \frac{\pi_\theta (y_i \mid x)}{\pi_{\theta_{\mathrm{old}}} (y_i \mid x)} \right]^{\frac{1}{|y_i|}}
= \exp\left( \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} \log \frac{\pi_\theta (y_{i, t} \mid x, y_{i, <t})}{\pi_{\theta_{\mathrm{old}}} (y_{i, t} \mid x, y_{i, <t})} \right)
$$

3. GSPO-token

   GSPO-token combines both sequence-level and token-level importance sampling:

$$
w_{i, t}^{\mathrm{GSPO-token}} = \mathrm{sg}\left[w_i^{\mathrm{GSPO}}\right] \cdot \frac{\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})}{\mathrm{sg}\left[\pi_{\theta}(y_{i, t} \mid x, y_{i, < t})\right]}
$$

where $\mathrm{sg}[\cdot]$ denotes stop-gradient (detach()).

> **NOTE:** According to the gradient analysis (i.e., Eqs. (11) and (18) in the paper), when the advantage of every token in a sequence is identical, GSPO-token is equivalent to GSPO. In the current GRPO implementation, all token advantages are normalized based on the sequence-level reward within each group. Therefore, in this setting, GSPO-token and GSPO are theoretically equivalent. However, GSPO-token provides support for future fine-grained (token-level) advantages.

Pseudo-code implementation:
```python
import torch

# per_token_logps, old_per_token_logps, mask: tensors of shape (B, T)
log_ratio = per_token_logps - old_per_token_logps

# GRPO: per-token weights
log_importance_weights = log_ratio

# GSPO (Sequence-Level): length-normalized sequence weight, broadcast over tokens
seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
log_importance_weights = seq_weight.unsqueeze(-1)  # (B, 1)

# GSPO-token: detached sequence weight plus a per-token gradient path
seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)
log_importance_weights = seq_weight.detach().unsqueeze(-1) + (per_token_logps - per_token_logps.detach())

importance_weights = torch.exp(log_importance_weights)
```
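
To make the NOTE above concrete, here is a minimal, self-contained sketch (toy tensors and a simplified surrogate loss, not the trainer's actual clipped GRPO loss) that checks that GSPO and GSPO-token agree in both forward value and gradient when every token of a sequence shares the same advantage:

```python
import torch

torch.manual_seed(0)
B, T = 2, 4
per_token_logps = torch.randn(B, T, requires_grad=True)
old_per_token_logps = (per_token_logps + 0.1 * torch.randn(B, T)).detach()
mask = torch.ones(B, T)
advantages = torch.tensor([[0.5], [1.5]]).expand(B, T)  # identical advantage for all tokens of a sequence

log_ratio = per_token_logps - old_per_token_logps
seq_weight = (log_ratio * mask).sum(-1) / mask.sum(-1)

# GSPO: sequence-level weight broadcast over tokens
w_gspo = torch.exp(seq_weight).unsqueeze(-1).expand(B, T)
# GSPO-token: detached sequence weight plus a per-token gradient path
w_gspo_token = torch.exp(seq_weight.detach().unsqueeze(-1) + (per_token_logps - per_token_logps.detach()))

# Simplified surrogate loss: negative weighted advantage, averaged over tokens
loss_gspo = -(w_gspo * advantages * mask).sum() / mask.sum()
loss_gspo_token = -(w_gspo_token * advantages * mask).sum() / mask.sum()

grad_gspo, = torch.autograd.grad(loss_gspo, per_token_logps, retain_graph=True)
grad_gspo_token, = torch.autograd.grad(loss_gspo_token, per_token_logps)

print(torch.allclose(w_gspo, w_gspo_token))        # True: identical forward values
print(torch.allclose(grad_gspo, grad_gspo_token))  # True: identical gradients under uniform advantages
```

With token-level advantages that differ within a sequence, the two gradients would no longer coincide, which is exactly the flexibility GSPO-token leaves open.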

Based on GRPO training, you can select different algorithms via the `--importance_sampling_level` argument:

- `importance_sampling_level token` (default, GRPO implementation)
- `importance_sampling_level sequence` (GSPO)
- `importance_sampling_level sequence_token` (GSPO-token)

For training, you can refer to [this script](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/gspo.sh).
@@ -0,0 +1,37 @@
# 8*80G GPU
# GSPO https://arxiv.org/pdf/2507.18071
# hyperparameters
# - epsilon = 3e-4, from paper section 5.1
# - epsilon_high = 4e-4, from paper section 5.1
# - steps_per_generation = 4, from paper section 5.1 (each batch of rollout data is partitioned into four minibatches for gradient updates)
# - beta = 0: zero KL regularization, see https://github.com/volcengine/verl/pull/2775#issuecomment-3131807306

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --dataset AI-MO/NuminaMath-TIR#10000 \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --num_generations 16 \
    --train_type full \
    --reward_funcs accuracy \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.6 \
    --vllm_max_model_len 16384 \
    --max_completion_length 8192 \
    --offload_optimizer true \
    --offload_model true \
    --sleep_level 1 \
    --save_steps 1000 \
    --learning_rate 1e-6 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --deepspeed zero3 \
    --log_completions true \