
Commit 1362f35

update config files
1 parent a04ca48 commit 1362f35

File tree: 23 files changed, +201 −215 lines changed

docs/sphinx_doc/source/tutorial/example_async_mode.md

Lines changed: 3 additions & 3 deletions

@@ -10,10 +10,10 @@ In addition, we need to configure the following parameters in both files.
 The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks.
 
 ```yaml
-model:
-  checkpoint_path: /PATH/TO/CHECKPOINT
+project: tutorial
+name: async_mode_example
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 
-# The same data_base path
 buffer:
   batch_size: <batch_size>
   trainer_input:
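For concreteness, a minimal sketch of the shared section after this change; the `batch_size` and `sync_iteration_interval` values below are illustrative, not taken from this file:

```yaml
# Sketch only: the concrete values are assumptions used for the arithmetic in the comment.
project: tutorial
name: async_mode_example
checkpoint_root_dir: /PATH/TO/CHECKPOINT

buffer:
  batch_size: 32  # with sync_iteration_interval: 10, weights sync every 10 * 32 = 320 tasks
```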

docs/sphinx_doc/source/tutorial/example_reasoning_basic.md

Lines changed: 10 additions & 9 deletions

@@ -79,19 +79,20 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 
 ## Optional: RFT with SFT Warmup
 
-Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_steps > 0` and prepare the SFT data to `buffer.train_dataset.path=$DATASET_PATH/{sft_data}`.
+Before RFT, we may use SFT as a warmup step. We need to set `buffer.trainer_input.sft_warmup_steps > 0` and prepare the SFT data to `buffer.trainer_input.sft_warmup_dataset.path=$DATASET_PATH/{sft_data}`.
 
 ```yaml
 # Properly set the following configs in gsm8k.yaml
 buffer:
-  sft_warmup_dataset:
-    storage_type: file
-    path: <$DATASET_PATH/{sft_data}>
-    format:
-      prompt_type: <prompt_type> # messages/plaintext/chatpair
-      prompt_key: <prompt_key>
-      response_key: <response_key>
-  sft_warmup_steps: 10
+  trainer_input:
+    sft_warmup_dataset:
+      storage_type: file
+      path: <$DATASET_PATH/{sft_data}>
+      format:
+        prompt_type: <prompt_type> # messages/plaintext/chatpair
+        prompt_key: <prompt_key>
+        response_key: <response_key>
+    sft_warmup_steps: 10
 ```
 
 The following command runs SFT and RFT in sequence:
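A filled-in version of the relocated warmup block might look like the sketch below; the path and key names are hypothetical placeholders, not values from this commit:

```yaml
# All concrete values here are hypothetical placeholders.
buffer:
  trainer_input:
    sft_warmup_dataset:
      storage_type: file
      path: /data/gsm8k/sft_warmup.jsonl  # hypothetical path
      format:
        prompt_type: messages   # or plaintext / chatpair
        prompt_key: prompt      # hypothetical key name
        response_key: response  # hypothetical key name
    sft_warmup_steps: 10
```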

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 11 additions & 18 deletions

@@ -29,7 +29,6 @@ monitor:
 - `monitor.monitor_type`: The type of the monitor. For now, `MonitorType.WANDB` and `MonitorType.TENSORBOARD` are supported.
 
 
-
 ## Data Processing
 
 <!-- The `data` configuration specifies the data used for training. It includes the total number of epochs, the batch size, the path to the dataset, the default workflow type, the default reward function type, and the format configuration. -->
@@ -65,16 +64,11 @@ The `model` configuration specifies the model used for training. It includes the
 model:
   model_path: '/PATH/TO/MODEL/CHECKPOINT/'
   critic_model_path: ''
-  max_prompt_tokens: 256
-  max_response_tokens: 1024
-  checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown'
 ```
 
 - `model.model_path`: The path to the model checkpoint. It must be set manually.
 - `model.critic_model_path`: The path to the critic model checkpoint. If not set, the `model.critic_model_path` will be set to `model.model_path`.
-- `model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
-- `model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
-- `model.checkpoint_path`: The path to the checkpoint of the model. It must be set manually.
+
 
 ## Cluster
 
@@ -143,14 +137,15 @@ The `explorer` configuration specifies the explorer configuration. It includes t
 
 ```yaml
 explorer:
-  engine_type: vllm_async
-  engine_num: 2
   runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
 ```
 
 - `explorer.engine_type`: The type of the engine, Support `vllm_async` and `vllm_sync`. Default is `vllm_async`.
@@ -161,6 +156,8 @@ explorer:
 - `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
 - `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
 - `explorer.seed`: The seed used in vLLM. Default is `42`.
+- `explorer.rollout_model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
+- `explorer.rollout_model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
 
 ## Synchronizer
 
@@ -183,15 +180,11 @@ Support `nccl` and `checkpoint`, `nccl` represents that model weights in `explor
 trainer:
   trainer_type: 'verl'
   trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
-  sft_warmup_steps: 0
-  eval_interval: 1000
   save_interval: 100
 ```
 
 - `trainer.trainer_type`: The backend of the trainer, Only `verl` is supported.
 - `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually.
-- `trainer.sft_warmup_steps`: The number of steps to warm up the model. Default is `0`.
-- `trainer.eval_interval`: The interval steps between two evaluations. Default is `1000`.
 - `trainer.save_interval`: The interval steps between two checkpoints. Default is `100`.
 
 ### veRL Trainer Configuration
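Taken together, this file documents the schema change the rest of the commit applies: per-engine rollout settings move from flat keys under `explorer` into a nested `explorer.rollout_model`, `model.checkpoint_path` is replaced by a top-level `checkpoint_root_dir`, and SFT warmup moves from `trainer` to `buffer.trainer_input`. A minimal sketch of the new layout, with placeholder values:

```yaml
checkpoint_root_dir: /PATH/TO/CHECKPOINT  # replaces model.checkpoint_path
model:
  model_path: /PATH/TO/MODEL
explorer:
  runner_num: 32
  rollout_model:               # previously flat keys under explorer
    engine_type: vllm_async
    engine_num: 2
    max_prompt_tokens: 2048    # previously model.max_prompt_tokens
    max_response_tokens: 2048  # previously model.max_response_tokens
buffer:
  trainer_input:
    sft_warmup_steps: 0        # previously trainer.sft_warmup_steps
```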

examples/async_gsm8k/explorer.yaml

Lines changed: 9 additions & 11 deletions

@@ -2,11 +2,11 @@ project: "Trinity-RFT-gsm8k"
 name: "async-qwen2.5-1.5B-gsm8k"
 mode: explore
 algorithm_type: grpo
+checkpoint_root_dir: 'checkpoints/qwen2.5-1.5B-gsm8k'
 model:
   model_path: /PATH/TO/MODEL/
   max_prompt_tokens: 256
   max_response_tokens: 1024
-  checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k'
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -36,20 +36,18 @@ buffer:
   path: 'sqlite:///gsm8k.db'
 explorer:
   eval_interval: 10
-  engine_type: vllm_async
-  engine_num: 2
   runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
 synchronizer:
   sync_method: 'checkpoint'
   sync_iteration_interval: 10
 trainer:
   trainer_type: 'verl'
   trainer_config_path: examples/async_gsm8k/verl_config.yaml
-  sft_warmup_steps: 0 # Set to integer to enable sft warmup
-monitor:
-  cache_root_dir: ""

examples/async_gsm8k/trainer.yaml

Lines changed: 10 additions & 12 deletions

@@ -2,11 +2,11 @@ project: "Trinity-RFT-gsm8k"
 name: "async-qwen2.5-1.5B-gsm8k"
 mode: train
 algorithm_type: grpo
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 model:
-  model_path: /PATH/TO/MODEL/
+  model_path: /PATH/TO/MODEL
   max_prompt_tokens: 256
   max_response_tokens: 1024
-  checkpoint_path: ""
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -35,20 +35,18 @@ buffer:
   path: 'sqlite:///gsm8k.db'
 explorer:
   eval_interval: 10
-  engine_type: vllm_async
-  engine_num: 2
   runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
 synchronizer:
   sync_method: 'checkpoint'
   sync_iteration_interval: 10
 trainer:
   trainer_type: 'verl'
   trainer_config_path: examples/async_gsm8k/verl_config.yaml
-  sft_warmup_steps: 0 # Set to integer to enable sft warmup
-monitor:
-  cache_root_dir: ""
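Since `sync_method` is `checkpoint` here, the trainer presumably writes checkpoints that the explorer reloads, and the tutorial above says to configure the shared parameters "in both files". A sketch of the keys that, under that reading, should agree between `explorer.yaml` and `trainer.yaml`:

```yaml
# Assumption: these keys are meant to be identical in explorer.yaml and trainer.yaml.
project: "Trinity-RFT-gsm8k"
name: "async-qwen2.5-1.5B-gsm8k"
checkpoint_root_dir: /PATH/TO/CHECKPOINT  # trainer writes here, explorer reads from here
buffer:
  # nesting abbreviated; both files point the queue at path: 'sqlite:///gsm8k.db'
synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: 10
```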

examples/dpo_humanlike/dpo.yaml

Lines changed: 2 additions & 4 deletions

@@ -2,11 +2,11 @@ project: "dpo_example"
 name: "trinity_dpo"
 mode: train
 algorithm_type: dpo
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 model:
-  model_path: '/PATH/TO/MODEL/CHECKPOINT/' # NOTE
+  model_path: '/PATH/TO/MODEL' # NOTE
   max_prompt_tokens: 1792
   max_response_tokens: 256
-  checkpoint_path: 'checkpoints/trinity_dpo'
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -33,5 +33,3 @@ trainer:
   trainer_type: 'verl'
   trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
   save_interval: 30
-monitor:
-  cache_root_dir: ""

examples/grpo_alfworld/alfworld.yaml

Lines changed: 14 additions & 15 deletions

@@ -1,11 +1,9 @@
 project: "ALFWORLD"
 name: "ALFWORLD_RFT"
 algorithm_type: grpo
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/ALFWORLD_RFT/
 model:
-  model_path: '/PATH/TO/MODEL/CHECKPOINT/'
-  max_prompt_tokens: 4096
-  max_response_tokens: 16384
-  checkpoint_path: 'checkpoints/ALFWORLD_RFT'
+  model_path: /PATH/TO/MODEL/
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -32,16 +30,19 @@ buffer:
   storage_type: queue
   path: 'sqlite:///alfworld.db'
 explorer:
-  engine_type: vllm_async
-  engine_num: 2
   runner_num: 32
-  tensor_parallel_size: 2
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
-  gpu_memory_utilization: 0.7
-  enable_chunked_prefill: true
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 2
+    enable_prefix_caching: false
+    enforce_eager: true
+    max_prompt_tokens: 4096
+    max_response_tokens: 16384
+    dtype: bfloat16
+    seed: 42
+    gpu_memory_utilization: 0.7
+    enable_chunked_prefill: true
 synchronizer:
   sync_method: 'nccl'
   sync_interval: 8
@@ -50,5 +51,3 @@ trainer:
   trainer_type: 'verl'
   trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml'
   save_interval: 10
-monitor:
-  cache_root_dir: ""
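This example raises `tensor_parallel_size` to 2 and uses a much longer context than the other configs. Under the (assumed) accounting that each vLLM engine occupies `tensor_parallel_size` GPUs, rollout takes 2 * 2 = 4 of the node's 8 GPUs:

```yaml
# GPU accounting sketch; the per-engine GPU model is an assumption.
cluster:
  gpu_per_node: 8
explorer:
  rollout_model:
    engine_num: 2
    tensor_parallel_size: 2      # assumed: 2 engines * TP 2 = 4 GPUs for rollout
    gpu_memory_utilization: 0.7  # leaves headroom for the 4096 + 16384 token context
```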

examples/grpo_gsm8k/gsm8k.yaml

Lines changed: 11 additions & 16 deletions

@@ -1,6 +1,7 @@
 project: "Trinity-RFT-gsm8k"
 name: "qwen2.5-1.5B-gsm8k"
 algorithm_type: grpo
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 data_processor:
   # basic info
   source_data_path: 'openai/gsm8k'
@@ -17,9 +18,6 @@ data_processor:
 
 model:
   model_path: '/PATH/TO/MODEL/'
-  max_prompt_tokens: 256
-  max_response_tokens: 1024
-  checkpoint_path: ""
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -61,27 +59,24 @@ buffer:
   # name: warmup_data
   # storage_type: file
   # path: '/PATH/TO/WARMUP_DATA/'
-  # kwargs:
-  #   prompt_type: plaintext
 explorer:
   eval_interval: 50
-  engine_type: vllm_async
-  engine_num: 2
   runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    max_prompt_tokens: 256
+    max_response_tokens: 1024
+    seed: 42
 synchronizer:
   sync_method: 'nccl'
   sync_interval: 2
   sync_timeout: 1200
 trainer:
   trainer_type: 'verl'
   trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
-  sft_warmup_steps: 0 # Set to integer to enable sft warmup
   save_interval: 100
-  # get_exp_strategy: 'LFU'
-monitor:
-  cache_root_dir: ""
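This is the config that the reasoning tutorial above launches; per the hunk header in `example_reasoning_basic.md`, the command is:

```
trinity run --config examples/grpo_gsm8k/gsm8k.yaml
```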

examples/grpo_math/math.yaml

Lines changed: 11 additions & 13 deletions

@@ -1,11 +1,9 @@
 project: grpo_math
 name: grpo_math_example
 algorithm_type: grpo
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
   model_path: /PATH/TO/MODEL/
-  max_prompt_tokens: 1024
-  max_response_tokens: 3072
-  checkpoint_path: /PATH/TO/CHECKPOINT/
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -34,22 +32,22 @@ buffer:
   path: 'sqlite:///math.db'
 explorer:
   eval_interval: 10
-  engine_type: vllm_async
-  engine_num: 2
   runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    max_prompt_tokens: 1024
+    max_response_tokens: 3072
+    seed: 42
 synchronizer:
   sync_method: 'nccl'
   sync_interval: 2
   sync_timeout: 1200
 trainer:
   trainer_type: 'verl'
   trainer_config_path: 'examples/grpo_math/train_math.yaml'
-  sft_warmup_steps: 0 # Set to integer to enable sft warmup
   save_interval: 100
-monitor:
-  cache_root_dir: ""
