Commit c389bce

Reorganize config (#46)
1 parent 3830dd3 commit c389bce


43 files changed (+737, -728 lines)

docs/sphinx_doc/source/tutorial/example_async_mode.md

Lines changed: 6 additions & 8 deletions
````diff
@@ -7,17 +7,15 @@ Trinity-RFT supports an asynchronous mode by running the trainer and explorer in
 For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
 The main difference between them is that in `trainer.yaml` we set `mode=train`, while in `explorer.yaml` we set `mode=explore`.
 In addition, we need to configure the following parameters in both files.
-The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks.
+The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks.
 
 ```yaml
-global_config:
-  batch_size: <batch_size>
-# The same checkpoint path
-model:
-  checkpoint_path: /PATH/TO/CHECKPOINT
+project: tutorial
+name: async_mode_example
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 
-# The same data_base path
 buffer:
+  batch_size: <batch_size>
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
@@ -26,7 +24,7 @@ buffer:
 
 synchronizer:
   sync_method: 'checkpoint'
-  sync_iteration_interval: <sync_iteration_interval>
+  sync_interval: <sync_interval>
 ```
 
 
 You may run this example with the following command:
````
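As the renamed `sync_interval` setting suggests, the synchronization cadence is simply the product of the two config values. A minimal sketch of that arithmetic (the numbers are this tutorial's example values, not defaults):

```python
# Hypothetical values; the real ones come from your trainer.yaml / explorer.yaml.
batch_size = 96      # buffer.batch_size
sync_interval = 10   # synchronizer.sync_interval

# Explorer and trainer weights are synchronized once every
# sync_interval * batch_size completed tasks.
tasks_per_sync = sync_interval * batch_size
print(tasks_per_sync)  # 960
```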

docs/sphinx_doc/source/tutorial/example_dpo.md

Lines changed: 5 additions & 8 deletions
````diff
@@ -45,6 +45,8 @@ We run the experiment in a train mode, as there is no Explorer. To enable this m
 ```yaml
 # In dpo.yaml
 mode: train
+algorithm:
+  algorithm_type: dpo
 synchronizer:
   sync_method: 'checkpoint'
 buffer:
@@ -56,14 +58,9 @@ buffer:
         prompt_key: <prompt_key>
         chosen_key: <chosen_key>
         rejected_key: <rejected_key>
-global_config:
-  algorithm_type: dpo
-
-# In train_dpo.yaml
-actor_rollout_ref:
-  actor:
-    use_kl_loss: True
-    kl_loss_coef: 0.1 # value of beta in DPO
+trainer:
+  actor_use_kl_loss: True
+  actor_kl_loss_coef: 0.1 # value of beta in DPO
 ```
 
 ### Run the Experiment
````
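The comment `value of beta in DPO` refers to the temperature in the standard DPO objective. A self-contained sketch of that objective (an illustration of the math, not Trinity-RFT's actual implementation; the log-probabilities below are invented numbers):

```python
import math

def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Standard DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_ratio - rejected_ratio)
    return -math.log(1.0 / (1.0 + math.exp(-logits)))  # -log(sigmoid(logits))

# Illustrative numbers: the policy prefers the chosen response more than the
# reference does, so the loss falls below log(2) (~0.693, the value at logits=0).
loss = dpo_loss(-10.0, -12.0, -11.0, -11.0, beta=0.1)
```

A larger `beta` (here `actor_kl_loss_coef`) sharpens the preference margin; a smaller one keeps the policy closer to the reference.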

docs/sphinx_doc/source/tutorial/example_reasoning_basic.md

Lines changed: 15 additions & 18 deletions
````diff
@@ -53,16 +53,13 @@ We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinit
 
 ```yaml
 # In gsm8k.yaml
-explorer:
+algorithm:
+  algorithm_type: grpo / ppo
   repeat_times: {number of rollouts for each task}
 
-# In train_gsm8k.yaml
-actor_rollout_ref:
-  actor:
-    use_kl_loss: True (for GRPO) / False (for PPO)
-    kl_loss_coef: 0.001
-algorithm:
-  adv_estimator: grpo (for GRPO) / gae (for PPO)
+trainer:
+  actor_use_kl_loss: True (for GRPO) / False (for PPO)
+  actor_kl_loss_coef: 0.001
 ```
 
 ### Run the Experiment
````
````diff
@@ -76,20 +73,20 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 
 ## Optional: RFT with SFT Warmup
 
-Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_steps > 0` and prepare the SFT data to `buffer.train_dataset.path=$DATASET_PATH/{sft_data}`.
+Before RFT, we may use SFT as a warmup step. We need to set `buffer.trainer_input.sft_warmup_steps > 0` and prepare the SFT data to `buffer.trainer_input.sft_warmup_dataset.path=$DATASET_PATH/{sft_data}`.
 
 ```yaml
 # Properly set the following configs in gsm8k.yaml
 buffer:
-  sft_warmup_dataset:
-    storage_type: file
-    path: <$DATASET_PATH/{sft_data}>
-    format:
-      prompt_type: <prompt_type> # messages/plaintext/chatpair
-      prompt_key: <prompt_key>
-      response_key: <response_key>
-trainer:
-  sft_warmup_steps: 10
+  trainer_input:
+    sft_warmup_dataset:
+      storage_type: file
+      path: <$DATASET_PATH/{sft_data}>
+      format:
+        prompt_type: <prompt_type> # messages/plaintext/chatpair
+        prompt_key: <prompt_key>
+        response_key: <response_key>
+    sft_warmup_steps: 10
 ```
 
 The following command runs SFT and RFT in sequence:
````
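`algorithm.repeat_times` controls how many rollouts form each group in GRPO-like algorithms. A minimal sketch of the standard GRPO-style group-normalized advantage computed over those rollouts (illustrative only, with invented binary rewards; not Trinity-RFT's exact code):

```python
def grpo_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Normalize rewards within one group of repeat_times rollouts of the same task."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

# One task, repeat_times = 4 rollouts, rewarded 1.0 when the answer is correct.
advs = grpo_advantages([1.0, 0.0, 0.0, 1.0])
```

Because the baseline is the group mean rather than a learned value function, GRPO needs several rollouts per task to produce a useful signal, which is why `repeat_times > 1` matters here.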

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 32 additions & 43 deletions
````diff
@@ -5,33 +5,36 @@ The following is the main config file for Trinity-RFT. Take `countdown.yaml` as
 ## Global Config
 
 ```yaml
+project: Trinity-RFT
+name: example
 mode: both
-global_config:
-  algorithm_type: ppo
-  total_epochs: 1
-  batch_size: 96
-  eval_interval: 1000
-  eval_on_latest_ckp: true
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 ```
 
+- `project`: The name of the project.
+- `name`: The name of the experiment.
 - `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
-- `global_config.algorithm_type`: The type of the algorithm. Supports `ppo`, `grpo`, `opmd` and `dpo`.
-- `global_config.total_epochs`: The total number of epochs. It should be checked manually.
-- `global_config.batch_size`: The batch size used for training. It should be checked manually.
-- `global_config.eval_interval`: The interval steps between two evaluations. Default is `1000`.
-- `global_config.eval_on_latest_ckp`: Whether to evaluate only the latest checkpoint or all checkpoints in the path. Only valid in `bench` mode. Default is `true`.
+- `checkpoint_root_dir`: The root directory to save the checkpoints. Specifically, the generated checkpoints will be saved in `<checkpoint_root_dir>/<project>/<name>/`.
 
+## Algorithm
+
+```yaml
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 1
+```
+
+- `algorithm.algorithm_type`: The type of the algorithm. Supports `ppo`, `grpo`, `opmd` and `dpo`.
+- `algorithm.repeat_times`: The number of times to repeat each task. Used for GRPO-like algorithms. Default is `1`.
 
 ## Monitor
 
 ```yaml
 monitor:
-  project: "Trinity-RFT-countdown"
-  name: "qwen2.5-1.5B-countdown"
+  monitor_type: MonitorType.WANDB
 ```
 
-- `monitor.project`: The project name. It must be set manually.
-- `monitor.name`: The name of the experiment. It must be set manually.
+- `monitor.monitor_type`: The type of the monitor. For now, `MonitorType.WANDB` and `MonitorType.TENSORBOARD` are supported.
 
 
 ## Data Processing
````
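The new top-level `project`, `name`, and `checkpoint_root_dir` fields replace the old `monitor.project`/`monitor.name` pair, and checkpoints land under `<checkpoint_root_dir>/<project>/<name>/`. A small sketch of that path rule (values are taken from the example snippet above, not defaults; `posixpath` is used so the result is deterministic):

```python
import posixpath

# Top-level fields from the example config above.
config = {
    "project": "Trinity-RFT",
    "name": "example",
    "checkpoint_root_dir": "/PATH/TO/CHECKPOINT",
}

# Checkpoints are written under <checkpoint_root_dir>/<project>/<name>/.
checkpoint_dir = posixpath.join(
    config["checkpoint_root_dir"], config["project"], config["name"]
)
print(checkpoint_dir)  # /PATH/TO/CHECKPOINT/Trinity-RFT/example
```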
````diff
@@ -69,16 +72,11 @@ The `model` configuration specifies the model used for training. It includes the
 model:
   model_path: '/PATH/TO/MODEL/CHECKPOINT/'
   critic_model_path: ''
-  max_prompt_tokens: 256
-  max_response_tokens: 1024
-  checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown'
 ```
 
 - `model.model_path`: The path to the model checkpoint. It must be set manually.
 - `model.critic_model_path`: The path to the critic model checkpoint. If not set, it defaults to `model.model_path`.
-- `model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
-- `model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
-- `model.checkpoint_path`: The path to the checkpoint of the model. It must be set manually.
+
 
 ## Cluster
 
````
````diff
@@ -108,7 +106,7 @@ buffer:
         prompt_key: 'question'
         response_key: 'answer'
       rollout_args:
-        repeat_times: 1
+        n: 1
         temperature: 1.0
         logprobs: 0
     eval_tasksets: []
````
````diff
@@ -129,7 +127,7 @@ buffer:
 - `buffer.explorer_input.taskset.path`: The path to the taskset.
 - `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
 - `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
-- `buffer.explorer_input.taskset.rollout_args.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `1`.
+- `buffer.explorer_input.taskset.rollout_args.n`: The number of times to repeat each task. This field is automatically set to `algorithm.repeat_times`.
 - `buffer.explorer_input.taskset.rollout_args.temperature`: The temperature used in vLLM. Default is `1.0`.
 - `buffer.explorer_input.taskset.rollout_args.logprobs`: The logprobs used in vLLM. Default is `0`.
 - `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets, a list of tasksets used for evaluation; it is empty by default.
````
````diff
@@ -143,22 +141,19 @@ buffer:
 
 ## Explorer
 
-The `explorer` configuration specifies the explorer. It includes the type of the engine, the number of engines, the number of workflow runners, the tensor parallel size, whether to enable prefix caching, whether to enforce eager mode, the data type, the `temperature`, the `top-p`, the `top-k`, the `seed`, the `logprobs`, the number of times to repeat each task, whether to use Ray, the backend, the maximum number of pending requests, and the maximum number of waiting steps.
+The `explorer` configuration specifies the explorer. It includes the type of the engine, the number of engines, the number of workflow runners, the tensor parallel size, whether to enable prefix caching, whether to enforce eager mode, the data type, the `temperature`, the `top-p`, the `top-k`, the `seed`, the `logprobs`, the number of times to repeat each task, the maximum number of pending requests, and the maximum number of waiting steps.
 
 ```yaml
 explorer:
-  engine_type: vllm_async
-  engine_num: 2
   runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
-  use_ray: false
-  backend: 'nccl'
-  max_pending_requests: 32
-  max_waiting_steps: 4
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
 ```
 
 - `explorer.engine_type`: The type of the engine. Supports `vllm_async` and `vllm_sync`. Default is `vllm_async`.
````
````diff
@@ -169,10 +164,8 @@
 - `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
 - `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
 - `explorer.seed`: The seed used in vLLM. Default is `42`.
-- `explorer.use_ray`: Whether to use Ray. Default is `False`.
-- `explorer.backend`: The backend used in vLLM. Default is `nccl`.
-- `explorer.max_pending_requests`: The maximum number of pending requests. Default is `32`.
-- `explorer.max_waiting_steps`: The maximum number of waiting steps. Default is `4`.
+- `explorer.rollout_model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
+- `explorer.rollout_model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
 
 ## Synchronizer
 
````
````diff
@@ -195,15 +188,11 @@ Support `nccl` and `checkpoint`, `nccl` represents that model weights in `explor
 trainer:
   trainer_type: 'verl'
   trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
-  sft_warmup_steps: 0
-  eval_interval: 1000
   save_interval: 100
 ```
 
 - `trainer.trainer_type`: The backend of the trainer. Only `verl` is supported.
 - `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually.
-- `trainer.sft_warmup_steps`: The number of steps to warm up the model. Default is `0`.
-- `trainer.eval_interval`: The interval steps between two evaluations. Default is `1000`.
 - `trainer.save_interval`: The interval steps between two checkpoints. Default is `100`.
 
 ### veRL Trainer Configuration
````

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -116,7 +116,7 @@ class ExampleWorkflow(Workflow):
                     "content": f"Question:\n{self.question}",
                 }
             ],
-            n=self.task.rollout_args.repeat_times,
+            n=self.task.rollout_args.n,
             temperature=self.task.rollout_args.temperature,
         )
         reward: float = self.calculate_reward(response.response_text, self.answer)
````
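The rename from `rollout_args.repeat_times` to `rollout_args.n` matches vLLM-style sampling parameters, where `n` is the number of completions per prompt. A hypothetical sketch of how a workflow might consume these fields (`RolloutArgs` and `fake_chat` are inventions for illustration, not Trinity-RFT APIs):

```python
from dataclasses import dataclass

@dataclass
class RolloutArgs:
    n: int = 1                 # completions per prompt (previously repeat_times)
    temperature: float = 1.0
    logprobs: int = 0

def fake_chat(messages: list[dict], n: int, temperature: float) -> list[str]:
    # Stand-in for the real chat call; returns n placeholder responses.
    return [f"response-{i}" for i in range(n)]

args = RolloutArgs(n=4, temperature=1.0)
responses = fake_chat(
    [{"role": "user", "content": "Question:\n1+1=?"}],
    n=args.n,
    temperature=args.temperature,
)
```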

examples/async_gsm8k/explorer.yaml

Lines changed: 17 additions & 23 deletions
````diff
@@ -1,18 +1,20 @@
+project: "Trinity-RFT-gsm8k"
+name: "async-qwen2.5-1.5B-gsm8k"
 mode: explore
-global_config:
-  total_epochs: 20
-  batch_size: 96
-  eval_interval: 10
+checkpoint_root_dir: '/PATH/TO/CHECKPOINT/'
+algorithm:
   algorithm_type: grpo
+  repeat_times: 8
 model:
   model_path: /PATH/TO/MODEL/
   max_prompt_tokens: 256
   max_response_tokens: 1024
-  checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k'
 cluster:
   node_num: 1
   gpu_per_node: 8
 buffer:
+  total_epochs: 20
+  batch_size: 96
   max_retry_times: 3
   max_retry_interval: 1
   explorer_input:
@@ -25,7 +27,6 @@ buffer:
       prompt_key: 'question'
       response_key: 'answer'
     rollout_args:
-      repeat_times: 8
      temperature: 1.0
      logprobs: 0
    default_workflow_type: 'math_workflow'
@@ -35,26 +36,19 @@ buffer:
      storage_type: queue
      path: 'sqlite:///gsm8k.db'
 explorer:
-  engine_type: vllm_async
-  engine_num: 2
+  eval_interval: 10
   runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
-  use_ray: false
-  backend: 'nccl'
-  max_pending_requests: 32
-  max_waiting_steps: 4
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
 synchronizer:
   sync_method: 'checkpoint'
-  sync_iteration_interval: 10
+  sync_interval: 10
 trainer:
   trainer_type: 'verl'
   trainer_config_path: examples/async_gsm8k/verl_config.yaml
-  sft_warmup_steps: 0 # Set to integer to enable sft warmup
-monitor:
-  cache_root_dir: ""
-  project: "Trinity-RFT-gsm8k"
-  name: "async-qwen2.5-1.5B-gsm8k"
````
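In async mode the trainer and explorer run as separate processes that find each other through shared `project`, `name`, and `checkpoint_root_dir` values, so a quick consistency check across the two files can catch mismatches early. A hedged sketch (the parser is hand-rolled to avoid a YAML dependency and only reads top-level scalar keys; key names follow this commit):

```python
def parse_top_level(yaml_text: str) -> dict:
    """Naive parser for top-level scalar keys; enough for a sanity check.

    A real check would use a YAML loader instead.
    """
    out = {}
    for line in yaml_text.splitlines():
        if line and not line.startswith((" ", "#")) and ":" in line:
            key, _, value = line.partition(":")
            out[key.strip()] = value.strip().strip("'\"")
    return out

explorer_cfg = parse_top_level(
    'project: "Trinity-RFT-gsm8k"\n'
    'name: "async-qwen2.5-1.5B-gsm8k"\n'
    "mode: explore\n"
    "checkpoint_root_dir: '/PATH/TO/CHECKPOINT/'\n"
)
trainer_cfg = parse_top_level(
    'project: "Trinity-RFT-gsm8k"\n'
    'name: "async-qwen2.5-1.5B-gsm8k"\n'
    "mode: train\n"
    "checkpoint_root_dir: '/PATH/TO/CHECKPOINT/'\n"
)

# The two files must agree on everything except mode.
shared_keys = ("project", "name", "checkpoint_root_dir")
mismatches = [k for k in shared_keys if explorer_cfg[k] != trainer_cfg[k]]
```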
