
Commit a5241be

Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/add_task_scheduler

2 parents 41bd02f + 9f1719e

File tree: 66 files changed (+641, −766 lines)


benchmark/config/countdown-template.yaml

Lines changed: 2 additions & 4 deletions
````diff
@@ -35,11 +35,9 @@ buffer:
       rollout_args:
         temperature: 1.0
         logprobs: 0
+    default_workflow_type: math_workflow
+    default_reward_fn_type: countdown_reward
     eval_tasksets: []
-    default_workflow_type: math_workflow
-    default_reward_fn_type: countdown_reward
-    system_prompt: null
-    reply_prefix: null
   trainer_input:
     experience_buffer:
       name: experience_buffer
````

benchmark/config/gsm8k-template.yaml

Lines changed: 3 additions & 5 deletions
````diff
@@ -40,11 +40,9 @@ buffer:
       rollout_args:
         temperature: 1.0
         logprobs: 0
+    default_workflow_type: math_workflow
+    default_reward_fn_type: math_reward
     eval_tasksets: []
-    default_workflow_type: math_workflow
-    default_reward_fn_type: math_reward
-    system_prompt: null
-    reply_prefix: null
   trainer_input:
     experience_buffer:
       name: experience_buffer
@@ -79,7 +77,7 @@ trainer:
   enable_preview: true
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 10240
+  max_token_len_per_gpu: 10240
   ulysses_sequence_parallel_size: 1
 monitor:
   monitor_type: wandb
````

docs/sphinx_doc/source/tutorial/example_async_mode.md

Lines changed: 6 additions & 6 deletions
````diff
@@ -39,14 +39,14 @@ buffer:
       response_key: 'answer'
       rollout_args:
         temperature: 1.0
-      default_workflow_type: 'math_workflow'
+    default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
       storage_type: queue
       path: 'sqlite:///gsm8k.db'
 explorer:
-  runner_num: 32
+  runner_per_model: 8
   rollout_model:
     engine_num: 4
 synchronizer:
@@ -86,7 +86,7 @@ buffer:
       response_key: 'answer'
       rollout_args:
         temperature: 1.0
-      default_workflow_type: 'math_workflow'
+    default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
@@ -98,7 +98,7 @@ synchronizer:
 trainer:
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 16384
+  max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
 ```
@@ -133,7 +133,7 @@ cluster: # important
   gpu_per_node: 8
 explorer:
   name: 'explorer_new' # important
-  runner_num: 64
+  runner_per_model: 8
   rollout_model:
     engine_num: 8
 buffer:
@@ -150,7 +150,7 @@ buffer:
       response_key: 'answer'
       rollout_args:
         temperature: 1.0
-      default_workflow_type: 'math_workflow'
+    default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
````

docs/sphinx_doc/source/tutorial/example_reasoning_basic.md

Lines changed: 4 additions & 3 deletions
````diff
@@ -77,6 +77,7 @@ buffer:
       response_key: 'answer'
       rollout_args:
         temperature: 1.0
+    default_workflow_type: 'math_workflow'
     eval_tasksets:
       - name: gsm8k-eval
         storage_type: file
@@ -86,15 +87,15 @@ buffer:
         format:
           prompt_key: 'question'
           response_key: 'answer'
-    default_workflow_type: 'math_workflow'
+        default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
       storage_type: queue
       path: 'sqlite:///gsm8k.db'
 explorer:
   eval_interval: 50
-  runner_num: 16
+  runner_per_model: 16
   rollout_model:
     engine_num: 1
 synchronizer:
@@ -117,7 +118,7 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 
 ## Optional: RFT with SFT Warmup
 
-Before RFT, we may use SFT as a warmup step. Trinity-RFT supports adding SFT warmup stage before RFT by setting `stages` in the config file. The `sft_warmup_dataset` specifies the dataset used for SFT warmup, and `sft_warmup_steps` specifies the number of training steps for SFT warmup.
+Before RFT, we may use SFT as a warmup step. Trinity-RFT supports adding SFT warmup stage before RFT by setting `stages` in the config file. The `experience_buffer` specifies the dataset used for SFT warmup, and `total_steps` specifies the number of training steps for SFT warmup.
 
 ```yaml
 # Properly add the following configs in gsm8k.yaml
````

docs/sphinx_doc/source/tutorial/example_step_wise.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -121,15 +121,15 @@ buffer:
       workflow_args:
         max_env_steps: 30
         enable_progress_bar: false
-      default_workflow_type: 'step_wise_alfworld_workflow'
+    default_workflow_type: 'step_wise_alfworld_workflow'
   trainer_input:
     experience_buffer:
       name: alfworld_buffer
       storage_type: queue
       use_priority_queue: true
 explorer:
   max_repeat_times_per_runner: 1
-  runner_num: 32
+  runner_per_model: 32
   max_timeout: 3600
   rollout_model:
     enable_history: true
@@ -152,7 +152,7 @@ trainer:
   save_interval: 50
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 16384
+  max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
 ```
````

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 10 additions & 6 deletions
````diff
@@ -200,6 +200,7 @@ buffer:
   batch_size: 32
   train_batch_size: 256
   total_epochs: 100
+  total_steps: null
 
   explorer_input:
     taskset:
@@ -214,9 +215,6 @@ buffer:
     ...
     buffer_2:
       ...
-
-  default_workflow_type: 'math_workflow'
-  default_reward_fn_type: 'countdown_reward'
 ```
 
 - `batch_size`: Number of tasks used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*.
@@ -231,6 +229,9 @@ Defines the dataset(s) used by the explorer for training and evaluation.
 ```yaml
 buffer:
   explorer_input:
+    default_workflow_type: 'math_workflow'
+    default_eval_workflow_type: 'math_workflow'
+    default_reward_fn_type: 'countdown_reward'
     taskset:
       name: countdown_train
       storage_type: file
@@ -262,7 +263,10 @@ buffer:
 ```
 
 - `buffer.explorer_input.taskset`: Task dataset used for training exploration policies.
-- `buffer.explorer_input.eval_taskset`: List of task datasets used for evaluation.
+- `buffer.explorer_input.eval_tasksets`: List of task datasets used for evaluation.
+- `buffer.explorer_input.default_workflow_type`: Default workflow type for all task datasets under `explorer_input` if not specified at the dataset level.
+- `buffer.explorer_input.default_eval_workflow_type`: Default evaluation workflow type for all eval task datasets under `explorer_input` if not specified at the dataset level.
+- `buffer.explorer_input.default_reward_fn_type`: Default reward function type for all task datasets under `explorer_input` if not specified at the dataset level.
 
 The configuration for each task dataset is defined as follows:
 
@@ -413,7 +417,7 @@ trainer:
   save_strategy: "unrestricted"
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 16384
+  max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
   trainer_config: null
 ```
@@ -429,7 +433,7 @@ trainer:
   - `unrestricted`: No restrictions on saving operations; multiple nodes, processes, or threads are allowed to save the model simultaneously.
 - `grad_clip`: Gradient clipping for updates.
 - `use_dynamic_bsz`: Whether to use dynamic batch size.
-- `ppo_max_token_len_per_gpu`: The maximum number of tokens to be processed in forward and backward when updating the policy. Effective when `use_dynamic_bsz=true`.
+- `max_token_len_per_gpu`: The maximum number of tokens to be processed in forward and backward when updating the policy. Effective when `use_dynamic_bsz=true`.
 - `ulysses_sequence_parallel_size`: Sequence parallel size.
 - `trainer_config`: The trainer configuration provided inline.
 ---
````
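A minimal sketch of the fallback behavior these new `explorer_input` defaults introduce; the dataset names and the override shown are hypothetical, not from this commit:

```yaml
buffer:
  explorer_input:
    default_workflow_type: 'math_workflow'
    default_reward_fn_type: 'countdown_reward'
    taskset:
      name: countdown_train    # no workflow given here, so it falls back
                               # to the 'math_workflow' default above
    eval_tasksets:
      - name: custom_eval      # hypothetical dataset overriding the default
        default_workflow_type: 'my_custom_workflow'
```

Whether the dataset-level override key is spelled exactly like this is an assumption; the task dataset configuration section remains the authoritative reference for per-dataset field names.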

docs/sphinx_doc/source_zh/tutorial/example_async_mode.md

Lines changed: 6 additions & 6 deletions
````diff
@@ -39,14 +39,14 @@ buffer:
       response_key: 'answer'
       rollout_args:
         temperature: 1.0
-      default_workflow_type: 'math_workflow'
+    default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
       storage_type: queue
       path: 'sqlite:///gsm8k.db'
 explorer:
-  runner_num: 32
+  runner_per_model: 16
   rollout_model:
     engine_num: 4
 synchronizer:
@@ -86,7 +86,7 @@ buffer:
       response_key: 'answer'
       rollout_args:
         temperature: 1.0
-      default_workflow_type: 'math_workflow'
+    default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
@@ -98,7 +98,7 @@ synchronizer:
 trainer:
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 16384
+  max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
 ```
@@ -133,7 +133,7 @@ cluster: # important
   gpu_per_node: 8
 explorer:
   name: 'explorer_new' # important
-  runner_num: 64
+  runner_per_model: 8
   rollout_model:
     engine_num: 8
 buffer:
@@ -150,7 +150,7 @@ buffer:
       response_key: 'answer'
       rollout_args:
         temperature: 1.0
-      default_workflow_type: 'math_workflow'
+    default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
````

docs/sphinx_doc/source_zh/tutorial/example_reasoning_basic.md

Lines changed: 4 additions & 3 deletions
````diff
@@ -77,6 +77,7 @@ buffer:
       response_key: 'answer'
       rollout_args:
         temperature: 1.0
+    default_workflow_type: 'math_workflow'
     eval_tasksets:
       - name: gsm8k-eval
         storage_type: file
@@ -86,15 +87,15 @@ buffer:
         format:
           prompt_key: 'question'
           response_key: 'answer'
-    default_workflow_type: 'math_workflow'
+        default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
       storage_type: queue
       path: 'sqlite:///gsm8k.db'
 explorer:
   eval_interval: 50
-  runner_num: 16
+  runner_per_model: 16
   rollout_model:
     engine_num: 1
 synchronizer:
@@ -117,7 +118,7 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 
 ## Advanced Option: RFT with SFT Warmup
 
-Before RFT, we may use SFT as a warmup step. Trinity-RFT supports adding an SFT warmup stage by setting `stages` in the config file. The `sft_warmup_dataset` specifies the dataset used for SFT warmup, and `sft_warmup_steps` specifies the number of SFT warmup training steps.
+Before RFT, we may use SFT as a warmup step. Trinity-RFT supports adding an SFT warmup stage by setting `stages` in the config file. The `experience_buffer` specifies the dataset used for SFT warmup, and `total_steps` specifies the number of SFT warmup training steps.
 
 ```yaml
 # Properly add the following configs in gsm8k.yaml
````

docs/sphinx_doc/source_zh/tutorial/example_step_wise.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -119,15 +119,15 @@ buffer:
       workflow_args:
         max_env_steps: 30
         enable_progress_bar: false
-      default_workflow_type: 'step_wise_alfworld_workflow'
+    default_workflow_type: 'step_wise_alfworld_workflow'
   trainer_input:
     experience_buffer:
       name: alfworld_buffer
       storage_type: queue
       use_priority_queue: true
 explorer:
   max_repeat_times_per_runner: 1
-  runner_num: 32
+  runner_per_model: 16
   max_timeout: 3600
   rollout_model:
     enable_history: true
@@ -150,7 +150,7 @@ trainer:
   save_interval: 50
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 16384
+  max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
 ```
````

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 9 additions & 8 deletions
````diff
@@ -214,9 +214,6 @@ buffer:
     ...
     buffer_2:
       ...
-
-  default_workflow_type: 'math_workflow'
-  default_reward_fn_type: 'countdown_reward'
 ```
 
 - `batch_size`: Number of tasks used per training step. *Please do not multiply this value by `algorithm.repeat_times` manually*.
@@ -231,6 +228,9 @@ buffer:
 ```yaml
 buffer:
   explorer_input:
+    default_workflow_type: 'math_workflow'
+    default_eval_workflow_type: 'math_workflow'
+    default_reward_fn_type: 'countdown_reward'
     taskset:
       name: countdown_train
       storage_type: file
@@ -256,13 +256,14 @@ buffer:
       response_key: 'answer'
       rollout_args:
         temperature: 0.1
-    default_workflow_type: 'math_workflow'
-    default_reward_fn_type: 'countdown_reward'
     ...
 ```
 
 - `buffer.explorer_input.taskset`: Task dataset used for training the exploration policy.
-- `buffer.explorer_input.eval_taskset`: List of task datasets used for evaluation.
+- `buffer.explorer_input.eval_tasksets`: List of task datasets used for evaluation.
+- `buffer.explorer_input.default_workflow_type`: Default workflow type for all task datasets, used when not specified at the dataset level.
+- `buffer.explorer_input.default_eval_workflow_type`: Default workflow type for all evaluation task datasets, used when not specified at the dataset level.
+- `buffer.explorer_input.default_reward_fn_type`: Default reward function type for all task datasets, used when not specified at the dataset level.
 
 Each task dataset is configured as follows:
 
@@ -413,7 +414,7 @@ trainer:
   save_strategy: "unrestricted"
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 16384
+  max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
   trainer_config: null
 ```
@@ -429,7 +430,7 @@ trainer:
   - `unrestricted`: No restrictions on saving; multiple nodes, processes, or threads may save the model simultaneously.
 - `grad_clip`: Gradient clipping threshold.
 - `use_dynamic_bsz`: Whether to use dynamic batch size.
-- `ppo_max_token_len_per_gpu`: Maximum token length per GPU during training; effective when `use_dynamic_bsz=true`.
+- `max_token_len_per_gpu`: Maximum token length per GPU during training; effective when `use_dynamic_bsz=true`.
 - `ulysses_sequence_parallel_size`: Degree of sequence parallelism, i.e., the number of GPUs used to split a single sequence.
 - `trainer_config`: The trainer configuration provided inline.
````
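Viewed together, the hunks across these files perform three mechanical renames: `ppo_max_token_len_per_gpu` becomes `max_token_len_per_gpu`, `runner_num` becomes `runner_per_model`, and the workflow/reward defaults move under `buffer.explorer_input`. A hypothetical helper (not part of the repo; key names come from the hunks above, everything else is illustrative) that upgrades an old config dict accordingly:

```python
# Hypothetical migration sketch: map pre-merge Trinity-RFT config keys
# onto the names introduced in this commit.
import copy


def migrate_config(cfg: dict) -> dict:
    """Return a copy of cfg with pre-merge key names upgraded."""
    cfg = copy.deepcopy(cfg)

    # trainer.ppo_max_token_len_per_gpu -> trainer.max_token_len_per_gpu
    trainer = cfg.get("trainer", {})
    if "ppo_max_token_len_per_gpu" in trainer:
        trainer["max_token_len_per_gpu"] = trainer.pop("ppo_max_token_len_per_gpu")

    # explorer.runner_num -> explorer.runner_per_model (the templates above
    # also change the *values*, so carrying the old number over is only a
    # starting point, not an equivalent setup)
    explorer = cfg.get("explorer", {})
    if "runner_num" in explorer:
        explorer["runner_per_model"] = explorer.pop("runner_num")

    # buffer-level workflow/reward defaults move under buffer.explorer_input
    buffer = cfg.get("buffer", {})
    for key in ("default_workflow_type", "default_reward_fn_type"):
        if key in buffer:
            buffer.setdefault("explorer_input", {})[key] = buffer.pop(key)
    return cfg
```

This only renames keys; any semantic retuning (such as choosing a sensible `runner_per_model` for a given engine count) still has to be done by hand.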

Comments (0)