@@ -76,20 +73,20 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 
 ## Optional: RFT with SFT Warmup
 
-Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_steps > 0` and prepare the SFT data to `buffer.train_dataset.path=$DATASET_PATH/{sft_data}`.
+Before RFT, we may use SFT as a warmup step. We need to set `buffer.trainer_input.sft_warmup_steps > 0` and prepare the SFT data to `buffer.trainer_input.sft_warmup_dataset.path=$DATASET_PATH/{sft_data}`.
 
 ```yaml
 # Properly set the following configs in gsm8k.yaml
docs/sphinx_doc/source/tutorial/trinity_configs.md: 32 additions & 43 deletions
@@ -5,33 +5,36 @@ The following is the main config file for Trinity-RFT. Take `countdown.yaml` as
 ## Global Config
 
 ```yaml
+project: Trinity-RFT
+name: example
 mode: both
-global_config:
-  algorithm_type: ppo
-  total_epochs: 1
-  batch_size: 96
-  eval_interval: 1000
-  eval_on_latest_ckp: true
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 ```
 
+- `project`: The name of the project.
+- `name`: The name of the experiment.
 - `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
-- `global_config.algorithm_type`: The type of the algorithm, Support `ppo`, `grpo`, `opmd` and `dpo`.
-- `global_config.total_epochs`: The total number of epochs. It should be checked manually.
-- `global_config.batch_size`: The batch size used for training. It should be checked manually.
-- `global_config.eval_interval`: The interval steps between two evaluations. Default is `1000`.
-- `global_config.eval_on_latest_ckp`: Whether to evaluate on only the latest checkpoint or all the checkpoints in the path. Only valid in `bench` mode. Default is `true`.
+- `checkpoint_root_dir`: The root directory to save the checkpoints. Specifically, the generated checkpoints will be saved in `<checkpoint_root_dir>/<project>/<name>/`.
 
+## Algorithm
+
+```yaml
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 1
+```
+
+- `algorithm.algorithm_type`: The type of the algorithm. Supports `ppo`, `grpo`, `opmd` and `dpo`.
+- `algorithm.repeat_times`: The number of times to repeat each task. Used for GRPO-like algorithms. Default is `1`.
 
 ## Monitor
 
 ```yaml
 monitor:
-  project: "Trinity-RFT-countdown"
-  name: "qwen2.5-1.5B-countdown"
+  monitor_type: MonitorType.WANDB
 ```
 
-- `monitor.project`: The project name. It must be set manually.
-- `monitor.name`: The name of the experiment. It must be set manually.
+- `monitor.monitor_type`: The type of the monitor. For now, `MonitorType.WANDB` and `MonitorType.TENSORBOARD` are supported.
 
 
 ## Data Processing
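Review note on the hunk above: the new `checkpoint_root_dir` bullet states that checkpoints are saved under `<checkpoint_root_dir>/<project>/<name>/`. A minimal Python sketch of that path rule may help when checking an existing setup. It assumes the config has been loaded into a plain dict; `checkpoint_dir` is a hypothetical helper written for illustration and is not part of Trinity-RFT.

```python
import posixpath


def checkpoint_dir(config: dict) -> str:
    """Hypothetical helper: build <checkpoint_root_dir>/<project>/<name>."""
    return posixpath.join(
        config["checkpoint_root_dir"],
        config["project"],
        config["name"],
    )


# Values copied from the YAML example in the diff.
example = {
    "project": "Trinity-RFT",
    "name": "example",
    "mode": "both",
    "checkpoint_root_dir": "/PATH/TO/CHECKPOINT",
}

print(checkpoint_dir(example))
```

Since `project` and `name` now sit at the top level rather than under `monitor`, the same two keys determine both the experiment identity and the checkpoint location.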
@@ -69,16 +72,11 @@ The `model` configuration specifies the model used for training. It includes the
 - `model.model_path`: The path to the model checkpoint. It must be set manually.
 - `model.critic_model_path`: The path to the critic model checkpoint. If not set, the `model.critic_model_path` will be set to `model.model_path`.
-- `model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
-- `model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
-- `model.checkpoint_path`: The path to the checkpoint of the model. It must be set manually.
+
 
 ## Cluster
 
@@ -108,7 +106,7 @@ buffer:
         prompt_key: 'question'
         response_key: 'answer'
       rollout_args:
-        repeat_times: 1
+        n: 1
         temperature: 1.0
         logprobs: 0
     eval_tasksets: []
@@ -129,7 +127,7 @@ buffer:
 - `buffer.explorer_input.taskset.path`: The path to the taskset.
 - `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
 - `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
-- `buffer.explorer_input.taskset.rollout_args.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `1`.
+- `buffer.explorer_input.taskset.rollout_args.n`: The number of times to repeat each task. This field is automatically set to `algorithm.repeat_times`.
 - `buffer.explorer_input.taskset.rollout_args.temperature`: The temperature used in vLLM. Default is `1.0`.
 - `buffer.explorer_input.taskset.rollout_args.logprobs`: The logprobs used in vLLM. Default is `0`.
 - `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets. It is a list of tasksets which will be used for evaluation. It is empty by default.
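Review note on the hunk above: the renamed `rollout_args.n` field is documented as being set automatically from `algorithm.repeat_times`. The Python sketch below shows one way such a propagation could behave. It is an assumption about the behaviour described in the bullet, not Trinity-RFT's actual implementation; `sync_rollout_n` is a hypothetical name and the config is modelled as a nested dict.

```python
import copy


def sync_rollout_n(config: dict) -> dict:
    """Hypothetical helper: copy algorithm.repeat_times into taskset.rollout_args.n."""
    synced = copy.deepcopy(config)  # leave the caller's config untouched
    # Fall back to 1, the documented default of algorithm.repeat_times.
    repeat_times = synced.get("algorithm", {}).get("repeat_times", 1)
    taskset = synced["buffer"]["explorer_input"]["taskset"]
    taskset.setdefault("rollout_args", {})["n"] = repeat_times
    return synced


config = {
    "algorithm": {"algorithm_type": "grpo", "repeat_times": 8},
    "buffer": {
        "explorer_input": {
            "taskset": {"rollout_args": {"n": 1, "temperature": 1.0, "logprobs": 0}}
        }
    },
}

synced = sync_rollout_n(config)
print(synced["buffer"]["explorer_input"]["taskset"]["rollout_args"]["n"])
```

Under this reading, a value written by hand into `rollout_args.n` would be overwritten, so `algorithm.repeat_times` is the field to edit.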
@@ -143,22 +141,19 @@ buffer:
 
 ## Explorer
 
-The `explorer` configuration specifies the explorer configuration. It includes the type of the engine, the number of engines, the number of workflow runners, the tensor parallel size, whether to enable prefix caching, whether to enforce eager mode, the data type, the `temperature`, the `top-p`, the `top-k`, the `seed`, the `logprobs`, the number of times to repeat each task, whether to use Ray, the backend, the maximum number of pending requests, and the maximum number of waitingsteps.
+The `explorer` configuration specifies the explorer configuration. It includes the type of the engine, the number of engines, the number of workflow runners, the tensor parallel size, whether to enable prefix caching, whether to enforce eager mode, the data type, the `temperature`, the `top-p`, the `top-k`, the `seed`, the `logprobs`, the number of times to repeat each task, the maximum number of pending requests, and the maximum number of waiting steps.
 
 ```yaml
 explorer:
-  engine_type: vllm_async
-  engine_num: 2
   runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
-  use_ray: false
-  backend: 'nccl'
-  max_pending_requests: 32
-  max_waiting_steps: 4
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
 ```
 
 - `explorer.engine_type`: The type of the engine. Supports `vllm_async` and `vllm_sync`. Default is `vllm_async`.
@@ -169,10 +164,8 @@ explorer:
 - `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
 - `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
 - `explorer.seed`: The seed used in vLLM. Default is `42`.
-- `explorer.use_ray`: Whether to use Ray. Default is `False`.
-- `explorer.backend`: The backend used in vLLM. Default is `nccl`.
-- `explorer.max_pending_requests`: The maximum number of pending requests. Default is `32`.
-- `explorer.max_waiting_steps`: The maximum number of waiting steps. Default is `4`.
+- `explorer.rollout_model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
+- `explorer.rollout_model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
 
 ## Synchronizer
 
@@ -195,15 +188,11 @@ Support `nccl` and `checkpoint`, `nccl` represents that model weights in `explor