Skip to content
Merged
Show file tree
Hide file tree
Changes from 91 commits
Commits
Show all changes
96 commits
Select commit Hold shift + click to select a range
b27fb9b
wip
hjh0119 Jul 29, 2025
3a9f23c
wip
hjh0119 Jul 29, 2025
061a2d5
revert prompt_ids
hjh0119 Jul 30, 2025
b1ac613
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Jul 31, 2025
8d2b170
remove tokenizer in reward
hjh0119 Aug 1, 2025
3da665b
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 1, 2025
0dea74d
encode ids
hjh0119 Aug 1, 2025
1114cea
wip replace ids
hjh0119 Aug 3, 2025
ebf1b35
fix adv
hjh0119 Aug 4, 2025
67d124a
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 4, 2025
ffcd9b4
wip
hjh0119 Aug 4, 2025
05ae143
wip
hjh0119 Aug 4, 2025
fdb1ae8
wip
hjh0119 Aug 4, 2025
725e1f6
wip
hjh0119 Aug 5, 2025
dcc9052
wip
hjh0119 Aug 5, 2025
f72712e
wip
hjh0119 Aug 6, 2025
a8aeb73
refactor v1
hjh0119 Aug 8, 2025
8555f31
rename completion id
hjh0119 Aug 8, 2025
4b30574
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 8, 2025
4c0ada9
fix typo & bugs
hjh0119 Aug 8, 2025
b9e4c04
compute loss for dynamic batch size
hjh0119 Aug 10, 2025
6ad2700
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 10, 2025
eea4485
fix tiny bugs
hjh0119 Aug 11, 2025
5b927de
dynamic rollout advantages
hjh0119 Aug 11, 2025
b0c52b7
fix score_completions
hjh0119 Aug 11, 2025
e25c2e4
fix split mini batch
hjh0119 Aug 11, 2025
bf035b3
docstring for split mini batches
hjh0119 Aug 11, 2025
60fb903
fix gather device
hjh0119 Aug 11, 2025
5108fce
fix rollout async infer
hjh0119 Aug 12, 2025
7deeec3
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 13, 2025
69f6e43
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 13, 2025
0812129
thinking tips scheduler
hjh0119 Aug 13, 2025
fc9d8ed
resolve dynamic sampling"
hjh0119 Aug 14, 2025
fe9ae1b
wip for chunk loss
hjh0119 Aug 14, 2025
95860db
version2
hjh0119 Aug 15, 2025
f3dcea1
fix gemini
hjh0119 Aug 15, 2025
2d01fb3
batch metrics
hjh0119 Aug 16, 2025
09b2151
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 16, 2025
d750c77
fix merge ouput
hjh0119 Aug 17, 2025
2aee82d
remove tests
hjh0119 Aug 17, 2025
514226b
fix server rollout & same prompt bewteen process
hjh0119 Aug 17, 2025
7d67c49
fix
hjh0119 Aug 18, 2025
1a3c838
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 18, 2025
54c3e15
merge main
hjh0119 Aug 18, 2025
13e6447
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 18, 2025
e286cb0
update
hjh0119 Aug 18, 2025
ca8df3e
revert chmod
hjh0119 Aug 18, 2025
3083a00
update docs
hjh0119 Aug 18, 2025
dec32c5
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 18, 2025
8d3ba6c
revert make docs
hjh0119 Aug 18, 2025
1ef3964
update images
hjh0119 Aug 18, 2025
1c2b8e7
update images
hjh0119 Aug 18, 2025
97e870e
global inputs for reward model
hjh0119 Aug 18, 2025
0e987c2
pass loss scale
hjh0119 Aug 18, 2025
aae79b0
tool call scheduler
hjh0119 Aug 18, 2025
c0a5dc0
move toolcall scheduler to external plugin
hjh0119 Aug 19, 2025
17a8369
update deepeyes
hjh0119 Aug 19, 2025
1fbd02a
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 19, 2025
148db1f
update toolcall scheduler
hjh0119 Aug 19, 2025
4bbad40
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 19, 2025
985851d
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 19, 2025
d8f3268
update deepeyes script
hjh0119 Aug 19, 2025
5f91719
fix script
hjh0119 Aug 19, 2025
565f762
use safer ast literal_eval
hjh0119 Aug 19, 2025
baffac7
compatible with sp
hjh0119 Aug 19, 2025
092be5d
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 19, 2025
cb62350
fix advantages & sort outputs
hjh0119 Aug 19, 2025
631bd9a
fix sp
hjh0119 Aug 19, 2025
1074caa
get trajectory inputs
hjh0119 Aug 19, 2025
abe68cc
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 19, 2025
30d2771
lint
hjh0119 Aug 19, 2025
7088805
multi turn reward example
hjh0119 Aug 20, 2025
ea8aae5
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 20, 2025
34731c4
update multi turn docs
hjh0119 Aug 20, 2025
6146e0a
restrict rollout async engine
hjh0119 Aug 20, 2025
8980a20
update docs
hjh0119 Aug 20, 2025
dc62577
check dynamic num and simplify logic for normal training
hjh0119 Aug 20, 2025
d5df11b
flag dynamic num samples
hjh0119 Aug 21, 2025
ab37b73
fix docstring typo
hjh0119 Aug 21, 2025
4cef0f9
fix chunked inputs
hjh0119 Aug 21, 2025
d48671f
log last turn metrics
hjh0119 Aug 21, 2025
a8a7abd
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 21, 2025
e999e77
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 21, 2025
e9968d6
last turn metrics
hjh0119 Aug 22, 2025
fc7b163
more profiling
hjh0119 Aug 22, 2025
fe2681e
fix multi turn script
hjh0119 Aug 22, 2025
93e60e1
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 25, 2025
10003ee
fix docstring
hjh0119 Aug 25, 2025
1709428
fix engine
hjh0119 Aug 25, 2025
6eed04b
fix log completion
hjh0119 Aug 25, 2025
70bd526
exp link for script
hjh0119 Aug 25, 2025
14c2e75
log num_turns
hjh0119 Aug 25, 2025
d6b9915
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 25, 2025
2fb5d92
fix args
hjh0119 Aug 25, 2025
d71ef82
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 25, 2025
864c1f9
align num of device of script to exp
hjh0119 Aug 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/resources/grpo_multi_turn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 10 additions & 8 deletions docs/source/Instruction/GRPO/DeveloperGuide/GYM环境训练.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# GYM环境训练

注意:该 feature 需要使用 ms-swift>=3.7 且目前仅支持纯文本模型
**注意** GYM环境训练逻辑已在 ms-swift 3.8 中进行重构,如果您的 ms-swift 版本低于该版本,请参考对应版本的文档。

## Gym接口

Expand Down Expand Up @@ -105,12 +105,17 @@ RolloutResponseChoice(
messages=None)
"""
```
在 `rollout` 命令中使用参数 `use_gym_env` 来指定使用gym作为训练的环境接口
GYM环境训练可以视作一种特殊的多轮训练,区别在于使用GYM环境训练,奖励信息通过环境直接获取。

在 `rollout` 命令中使用参数 `use_gym_env` 来指定使用gym作为训练的环境接口。我们提供了兼容GYM环境的多轮规划器参考实现,见[内置多轮调度器实现](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/multi_turn.py)中的 GymScheduler 类


```bash
CUDA_VISIBLE_DEVICES=0 \
swift rollout \
--model xxx \
--use_gym_env true \
--multi_turn_scheduler gym_scheduler \
--max_turns xxx
```

Expand All @@ -133,14 +138,11 @@ swift rollout \
```json
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "告诉我明天的天气"}],"env_config":{"name":"custom_env","other_config":"xxxx"},"ctx_config":{"name":"custom_ctx","other_config":"xxxx"}}
```
2. gym 环境目前仅兼容纯文本模型和 AsyncEngine
2. 默认仅对最后一轮response进行训练,如果gym涉及到多轮response生成,使用参数`--loss_scale default`对所有轮次的response进行训练,具体参考[文档](./多轮训练.md#损失掩码)

3. 默认仅对最后一轮response进行训练,如果gym涉及到多轮response生成,使用参数`--loss_scale default`对所有轮次的response进行训练,具体参考[文档](./多轮训练.md#损失掩码)

4. 数据流程
3. 数据流程
整个gym数据流程如下:
<img src="../../../../resources/gym_env.png" width="400" />


5. 奖励日志
4. 奖励日志
由于gym的奖励是在step函数内计算完成,所以需要手动通过`info`返回日志,最终的记录会放在completions.jsonl中的`trajectory_infos`字段.
210 changes: 122 additions & 88 deletions docs/source/Instruction/GRPO/DeveloperGuide/多轮训练.md
Original file line number Diff line number Diff line change
@@ -1,79 +1,81 @@
# 多轮训练

注意:该 feature 需要使用 ms-swift>=3.6
**注意** 多轮训练逻辑已在 ms-swift 3.8 中进行重构,如果您的 ms-swift 版本低于该版本,请参考对应版本的文档。

在强化学习训练场景中,模型采样可能需要与环境进行多轮交互(如工具调用、外部API访问等)。这种交互式训练要求模型能够根据环境反馈信息进行连续推理。本文档将详细介绍如何在 GRPO 训练中自定义多轮训练流程。
在强化学习训练场景中,模型采样可能需要与环境进行多轮交互(如工具调用)。这种交互式训练要求模型能够根据环境反馈信息进行连续推理。本文档将详细介绍如何在 GRPO 训练中自定义多轮训练流程。

以下是多轮训练示例图,模型可能涉及多轮 rollout,包括环境交互、工具调用等步骤:

根据环境反馈插入方式不同,多轮可以分为:

- 新一轮推理:环境反馈结果作为 query,模型进行新一轮对话轮次进行响应
- 当轮续写:环境反馈结果插入模型当前回复中,模型在此基础上继续续写后续内容


我们可以自定义并通过参数 `multi_turn_scheduler` 设置多轮采样的规划器来实现多轮采样逻辑
```
--multi_turn_scheduler xxx
--max_turns xxx
```
两种方式的实现例子可以参考[最佳实践](#最佳实践)
![多轮示例图](../../../../resources/grpo_multi_turn.png)

## 多轮规划器 MultiTurnScheduler
多轮规划器是多轮训练的核心组件,其工作流程如下图所示:

`MultiTurnScheduler` 是一个抽象基类,提供了默认的多轮对话管理逻辑,其工作流程如下图所示:

<img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/multiturn_pipeline.png " width="300" />

多轮规划器主要承担两大核心功能:
- **终止条件判断**:通过 `check_finished` 方法判断当前轮次推理是否应该结束
- **推理请求构造**:通过 `step` 方法构建下一轮推理的请求对象

多轮规划器主要承担两大功能:
- 终止条件判断:通过 check_finished 方法判断当前轮次推理是否应该结束
- 推理请求构造:通过 step 方法构建下一轮推理的请求对象
抽象基类 `MultiTurnScheduler` 的核心方法如下:

抽象基类 MultiTurnScheduler 代码如下
```python
class MultiTurnScheduler(ABC):

def __init__(self, max_turns: Optional[int] = None, *args, **kwargs):
self.max_turns = max_turns

@abstractmethod
def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
current_turn: int) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', Dict]]:
pass

def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
current_turn: int) -> Dict:
"""
处理对话轮次之间的转换。

Args:
infer_request: 当前推理请求
response_choice: 当前轮次的响应
current_turn: 当前轮次数

Returns:
Dict[str, Any]: 包含推理结果的字典,结构如下:
- infer_request (必需): 下一轮的推理请求对象
- response_token_ids (可选): 每个 rollout 轮次的响应 token IDs
- response_loss_mask (可选): 每个 rollout 轮次响应的损失掩码
- rollout_infos (可选): 额外信息数据
"""
raise NotImplementedError

def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
current_turn: int) -> bool:
if result.finish_reason == 'length':
"""
检查多轮 rollout 是否应该结束的默认终止逻辑。

默认终止条件:
1. 当响应达到长度限制时 (finish_reason == 'length')
2. 当对话达到最大轮数时 (如果设置了 max_turns)

Args:
infer_request: 推理请求对象
response_choice: 包含生成结果的响应选择,包括 finish_reason
current_turn: 当前对话轮数

Returns:
bool: True 表示终止对话,False 表示继续
"""
if response_choice.finish_reason == 'length':
return True
if self.max_turns and current_turn >= self.max_turns:
return True
return False
```

> 如果你想要奖励函数获取多轮交互中的信息,请在 step 方法中返回额外的 dict 对象, 在奖励函数中的 kwargs中,获取 `multi_turn_infos`
`step` 和 `check_finished` 方法接收的参数说明:
- **infer_request**: 当前的推理请求
- **response_choice**: 当前轮次的推理结果
- **current_turn**: 当前推理轮次(从 1 开始)

```python
class Scheduler():
def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
current_turn: int) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', Dict]]:
...
return infer_request, extra_dict
<details><summary>入参示例(点击展开)</summary>

class RewardFunction():
def __call__(self, completions, **kwargs):
infos = kwargs.get('multi_turn_infos', {})
...
```

step 和 check_finished 方法接收参数:
- infer_request: 上轮的推理请求,包括
- `messages` 键包含了模型的交互历史(注意:已包括当前模型推理结果)
- 多模态信息,如 `images`
- `data_dict` 包含了数据集中的其他列
- result: 上轮的推理结果,
- current_turn: 当前推理轮次 (从1开始)

入参示例
```python
infer_request
"""
Expand All @@ -93,9 +95,9 @@ RolloutInferRequest(
}
)
"""
result
response_choice
"""
RolloutResponseChoice(
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(
role='assistant',
Expand All @@ -104,78 +106,110 @@ RolloutResponseChoice(
logprobs=None,
messages=None)
"""
# result.messages will be copied at the end of multi-turn inference.
# response_choice.messages will be copied at the end of multi-turn inference.
```
</details>

默认的 `check_finished` 逻辑会在两种情况下停止推理
<br>
<br>

默认的 `check_finished` 逻辑会在以下两种情况下停止推理:
- 模型回复被截断,即超出了 `max_completion_length`
- 模型推理轮数超出了限制的最大轮数

完整的默认多轮 rollout 逻辑请参考该类的 `run` 方法,我们也可以通过重载`run` 方法来实现自定义多轮逻辑。

推荐使用 AsyncEngine 来实现高效的批量数据异步多轮采样(只支持 external server mode),AsyncEngine 在多轮推理时能够减小推理过程中的计算气泡(如图)

<img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/asyncengine.png" width="400" />
## 设置多轮训练参数

在 swift rollout 命令中,设置 multi_turn_scheduler 参数指定规划器

在 `rollout` 命令中使用参数 `use_async_engine` 来指定engine的种类
```bash
CUDA_VISIBLE_DEVICES=0 \
swift rollout \
--model xxx \
--model Qwen/Qwen3-1.7B \
--use_async_engine true \
--multi_turn_scheduler xxx \
--max_turns xxx
--multi_turn_scheduler thinking_tips_scheduler \
--vllm_max_model_len 32768 \
--vllm_gpu_memory_utilization 0.8 \
--max_turns 3
```

通过参数`external_plugins`, 我们可以将本地的多轮规划器注册进 ms-swift 中,具体实现参考[代码](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)

多轮训练脚本参考
> 通过参数 `external_plugins`,我们可以将本地的多轮规划器注册到 ms-swift 中,具体实现请参考[代码](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)。

- [server mode](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/external/vllm_multi_turn.sh)
- [colocate mode](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/vllm_multi_turn.sh)
多轮训练脚本请参考[脚本](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/external/vllm_multi_turn.sh)。


## 最佳实践
[插件代码示例](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)中提供了两种多轮规划器的例子,实现在数学问题中提示模型再次思考并给出答案,分别对应两种多轮推理:
对于多轮 rollout,我们使用 AsyncEngine 来实现高效的批量数据异步多轮采样。AsyncEngine 在多轮推理时能够减少推理过程中的计算气泡:

- 第一种方式(新一轮推理):新插入一轮对话,提示模型的答案错误,需要重新思考(math_tip_trick_multi_turn)
- 第二种方式(续写):回溯到模型的思考阶段,并加入思考错误的提示 (math_tip_trick)
<img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/asyncengine.png" width="400" />

在 `rollout` 命令中使用参数 `use_async_engine` 来指定 engine 的种类(默认使用 async engine):

## 注意事项

### 奖励函数
注意在奖励函数中,接受的 `completions` 参数为最后一轮模型回复,如果奖励函数需要根据模型多轮回复计算奖励,需要获取 `messages` 键来获取完整的多轮对话记录
## 高级设置

```python
class Reward(ORM):
### 自定义多轮交互逻辑
在以上默认逻辑中,我们用一条轨迹来计算多轮 rollout 的损失,这里需要假设多轮交互的过程中,模型的历史信息没有收到改变。

def __call__(completions, **kwargs):
print(kwargs.keys())
# dict_keys(['problem', 'solution', 'messages', 'is_truncated'])
messages = kwargs.get('messages')
...
```
而在一些多轮场景中,我们可以需要在多轮 rollout 过程中动态地修改模型的历史信息(比如压缩历史信息),此时,我们需要将每轮的 rollout 单独作为一条轨迹进行训练。

比较常见的一种场景是对于思考类模型,在实际推理过程中,模型通常只会保留最后一轮的思考内容,而忽略历史模型回复中的思考内容。

对于这类场景,我们需要重写多轮规划器中的交互逻辑,即重载 `run` 方法,从而单独返回每一轮的 Rollout 的结果。

框架内置的 `ThinkingModelTipsScheduler` 类展示了如何通过重写 `run()` 方法来实现完全自定义的多轮推理逻辑。请参考[内置多轮调度器实现](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/multi_turn.py)

**注意**: 这种情况下,相同轨迹的数据会拆分为多条数据,在奖励相关的处理中,需要对相同轨迹的数据分配同样的reward。

可以在kwargs中获取 trajectory_inputs 获取完整轨迹的数据,具体实现参考[MultiTurnThinkingTips类](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)

### 返回 response token ids
在默认的多轮交互流程中,规划器先把模型生成的文本字符串返回给 trainer,trainer 再将其重新 encode 为 token id,用于后续训练。为了避免这一步重复编码的开销,你可以让规划器直接返回 response_token_ids,省去 trainer 侧的再次 encode。

具体做法如下:

- 在 response_choice 对象中读取 token_ids 属性,即可获得本次 rollout 生成的 token 序列。
- 在 step/run 方法的返回值里加入 response_token_ids,trainer 便能直接使用这些 token id 参与训练,无需重新编码。

具体实现可以参考[ThinkingModelTipsScheduler](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/multi_turn.py)类

### 损失掩码

在工具调用或环境交互返回结果时,若需将返回内容作为模型响应的一部分,建议对这些插入内容进行掩码处理,以确保模型在训练过程中不会对这些外部生成的内容计算损失。

这里需要通过设置参数 loss_scale ,实现自定义掩码逻辑,具体参考[定制化loss_scale文档](../../../Customization/插件化.md#定制化loss_scale)。
我们可以通过两种方式设置损失掩码

**第一种:设置 loss_scale**

ms-swift 提供 loss_scale 参数来对模型回复部分的内容进行损失缩放设置。比如设置`--loss_scale last_round`,可以将非最后一轮的模型回复的损失置零。我们也可以实现自定义 loss_scale,具体请参考[定制化 loss_scale 文档](../../../Customization/插件化.md#定制化loss_scale)。

> 注:在GRPO中,loss_scale 只提供掩码功能,不提供缩放功能。

默认 loss_scale 值:
**第二种:设置loss_mask**

grpo训练(即设置`multi_turn_scheduler`),loss_scale 默认为`default`,即对 messages 中的 每一轮 response 进行训练
> 如果数据集中本身包含 assistant response 也会被计算入内,如果想要排除数据集中的response , 需要自定义 loss_scale
在`step`或者`run`方法中设置 response_loss_mask, 可以在规划器中自定义损失掩码。前提需要返回response token ids,返回的 response_loss_mask 需要与 response token ids等长。当返回 response_loss_mask 时,loss_scale 参数失效。

如果只想只计算最后一轮 response(rollout结果)损失,请修改为`last_round`
response_loss_mask 返回可以参考[ToolCallScheduler类](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)

### 奖励函数相关

注意 loss_scale 可以用于
在奖励函数中获取多轮 Rollout 中的信息

在`step`或者`run`方法中,返回 `rollout_infos` 对象,在奖励函数的 kwargs 中获取 `rollout_infos`:

```python
class Scheduler():
def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
current_turn: int) -> Dict:
...
return {'infer_request': infer_request, 'rollout_infos': extra_dict}

class RewardFunction():
def __call__(self, completions, **kwargs):
infos = kwargs.get('rollout_infos', {})
...
```

1. 标注需要训练的 tokens (0为不训练)
2. 放缩 tokens 的训练权重
### 在 Scheduler 中获取额外的数据集信息

而 GRPO 中暂不支持 loss_scale 的权重设置
在训练侧设置参数`--vllm_server_pass_dataset`,可将数据集中的其他列传入多轮规划器。在`infer_request.data_dict`中获取
8 changes: 7 additions & 1 deletion docs/source/Instruction/GRPO/GetStarted/GRPO.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,13 @@ swift rollout \
- {reward_func_name}:特定奖励
- entropy:entropy token 均值,在设置`log_entropy`时记录

设置 `report_to wandb/swanlab` 将训练动态推送到对应的平台
设置 `report_to wandb/swanlab` 将训练动态Table推送到对应的平台

如果需要在Table中额外记录其他列,请在 `GRPOTrainer._generate_and_score_completions` 方法中,设置 metrics_to_gather 字典。

默认自动检测
- `image`:视觉数据集图像输入。(暂时只支持wandb)
- `solution`:数据集中的 solution 列。

## FAQ
**1. 训练过程中 loss 等于0 / 接近0 / 小于0**
Expand Down
1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ reward模型参数将在PPO、GRPO中使用。
- vllm_server_host:vLLM server host地址,默认为None,使用外部vLLM server时使用。
- vllm_server_port vLLM server 服务端口,默认为8000。
- vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。
- vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。
- async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`.
- vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。)
- vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。
Expand Down
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ The meanings of the following parameters can be referenced [here](https://huggin
- vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
- vllm_server_port: The service port of the vLLM server. Default is 8000.
- vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds.
- vllm_server_pass_dataset: pass additional dataset information through to the vLLM server for multi-turn training.
- async_generate: Use async rollout to improve train speed. Note that rollout will use the model updated in the previous round when enabled. Multi-turn scenarios are not supported. Default is `false`.
- vllm_mode colocate parameter (For more parameter support, refer to the [vLLM Arguments](#vLLM-Arguments).)
- vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
Expand Down
Loading
Loading