10 changes: 4 additions & 6 deletions docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -10,14 +10,12 @@ In addition, we need to configure the following parameters in both files.
The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks; a worked example follows the snippet below.

```yaml
global_config:
  batch_size: <batch_size>
# The same checkpoint path
model:
  checkpoint_path: /PATH/TO/CHECKPOINT
project: tutorial
name: async_mode_example
checkpoint_root_dir: /PATH/TO/CHECKPOINT

# The same data_base path
buffer:
  batch_size: <batch_size>
  trainer_input:
    experience_buffer:
      name: gsm8k_buffer
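
# For illustration (values taken from the async_gsm8k example configs in this PR;
# adjust to your setup): with the settings sketched below, the explorer and trainer
# weights are synchronized every sync_iteration_interval * batch_size
# = 10 * 96 = 960 tasks.
#
# synchronizer:
#   sync_iteration_interval: 10
# buffer:
#   batch_size: 96
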
3 changes: 1 addition & 2 deletions docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -45,6 +45,7 @@ We run the experiment in train mode, as there is no Explorer. To enable this m
```yaml
# In dpo.yaml
mode: train
algorithm_type: dpo
synchronizer:
  sync_method: 'checkpoint'
buffer:
@@ -56,8 +57,6 @@ buffer:
prompt_key: <prompt_key>
chosen_key: <chosen_key>
rejected_key: <rejected_key>
global_config:
  algorithm_type: dpo

# In train_dpo.yaml
actor_rollout_ref:
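
# For readability, the resulting top-level layout of dpo.yaml after this change is
# sketched below (algorithm_type now sits at the top level rather than under
# global_config); only the keys shown in the snippet above are included.
#
# mode: train
# algorithm_type: dpo
# synchronizer:
#   sync_method: 'checkpoint'
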
27 changes: 15 additions & 12 deletions docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -53,8 +53,11 @@ We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinit

```yaml
# In gsm8k.yaml
explorer:
  repeat_times: {number of rollouts for each task}
buffer:
  explorer_input:
    taskset:
      rollout_args:
        n: {number of rollouts for each task}

# In train_gsm8k.yaml
actor_rollout_ref:
@@ -76,20 +79,20 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml

## Optional: RFT with SFT Warmup

Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_steps > 0` and point `buffer.train_dataset.path` to the SFT data at `$DATASET_PATH/{sft_data}`.
Before RFT, we may use SFT as a warmup step. We need to set `buffer.trainer_input.sft_warmup_steps > 0` and point `buffer.trainer_input.sft_warmup_dataset.path` to the SFT data at `$DATASET_PATH/{sft_data}`.

```yaml
# Properly set the following configs in gsm8k.yaml
buffer:
  sft_warmup_dataset:
    storage_type: file
    path: <$DATASET_PATH/{sft_data}>
    format:
      prompt_type: <prompt_type> # messages/plaintext/chatpair
      prompt_key: <prompt_key>
      response_key: <response_key>
trainer:
  sft_warmup_steps: 10
  trainer_input:
    sft_warmup_dataset:
      storage_type: file
      path: <$DATASET_PATH/{sft_data}>
      format:
        prompt_type: <prompt_type> # messages/plaintext/chatpair
        prompt_key: <prompt_key>
        response_key: <response_key>
    sft_warmup_steps: 10
```

The following command runs SFT and RFT in sequence:
67 changes: 24 additions & 43 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -6,32 +6,27 @@ The following is the main config file for Trinity-RFT. Take `countdown.yaml` as

```yaml
mode: both
global_config:
  algorithm_type: ppo
  total_epochs: 1
  batch_size: 96
  eval_interval: 1000
  eval_on_latest_ckp: true
project: Trinity-RFT
name: example
algorithm_type: ppo
checkpoint_root_dir: /PATH/TO/CHECKPOINT_DIR
```

- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
- `global_config.algorithm_type`: The type of the algorithm. Supported values are `ppo`, `grpo`, `opmd`, and `dpo`.
- `global_config.total_epochs`: The total number of epochs. It should be checked manually.
- `global_config.batch_size`: The batch size used for training. It should be checked manually.
- `global_config.eval_interval`: The interval steps between two evaluations. Default is `1000`.
- `global_config.eval_on_latest_ckp`: Whether to evaluate on only the latest checkpoint or all the checkpoints in the path. Only valid in `bench` mode. Default is `true`.
- `project`: The name of the project.
- `name`: The name of the experiment.
- `algorithm_type`: The type of the algorithm. Supported values are `ppo`, `grpo`, `opmd`, and `dpo`.
- `checkpoint_root_dir`: The root directory of the checkpoint.


## Monitor

```yaml
monitor:
  project: "Trinity-RFT-countdown"
  name: "qwen2.5-1.5B-countdown"
  monitor_type: MonitorType.WANDB
```

- `monitor.project`: The project name. It must be set manually.
- `monitor.name`: The name of the experiment. It must be set manually.
- `monitor.monitor_type`: The type of the monitor. For now, `MonitorType.WANDB` and `MonitorType.TENSORBOARD` are supported.


## Data Processing
@@ -69,16 +64,11 @@ The `model` configuration specifies the model used for training. It includes the
model:
  model_path: '/PATH/TO/MODEL/CHECKPOINT/'
  critic_model_path: ''
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown'
```

- `model.model_path`: The path to the model checkpoint. It must be set manually.
- `model.critic_model_path`: The path to the critic model checkpoint. If not set, it defaults to `model.model_path`.
- `model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
- `model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
- `model.checkpoint_path`: The path to the checkpoint of the model. It must be set manually.


## Cluster
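
The cluster settings are unchanged in this PR and therefore elided from the diff. For reference, a minimal sketch is shown below; the single-node, 8-GPU values are the ones used in the `async_gsm8k` examples later in this PR and are illustrative only.

```yaml
# Sketch of the (unchanged) cluster section; values are illustrative.
cluster:
  node_num: 1
  gpu_per_node: 8
```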

@@ -108,7 +98,7 @@ buffer:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        repeat_times: 1
        n: 1
        temperature: 1.0
        logprobs: 0
    eval_tasksets: []
@@ -129,7 +119,7 @@ buffer:
- `buffer.explorer_input.taskset.path`: The path to the taskset.
- `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
- `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
- `buffer.explorer_input.taskset.rollout_args.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `1`.
- `buffer.explorer_input.taskset.rollout_args.n`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `1`.
- `buffer.explorer_input.taskset.rollout_args.temperature`: The temperature used in vLLM. Default is `1.0`.
- `buffer.explorer_input.taskset.rollout_args.logprobs`: The logprobs used in vLLM. Default is `0`.
- `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets. It is a list of tasksets used for evaluation; it is empty by default.
@@ -143,22 +133,19 @@ buffer:

## Explorer

The `explorer` section configures the explorer. It includes the engine type, the number of engines, the number of workflow runners, the tensor parallel size, whether to enable prefix caching, whether to enforce eager mode, the data type, the `temperature`, the `top-p`, the `top-k`, the `seed`, the `logprobs`, the number of times to repeat each task, whether to use Ray, the backend, the maximum number of pending requests, and the maximum number of waiting steps.
The `explorer` section configures the explorer. It includes the engine type, the number of engines, the number of workflow runners, the tensor parallel size, whether to enable prefix caching, whether to enforce eager mode, the data type, the `temperature`, the `top-p`, the `top-k`, the `seed`, the `logprobs`, the number of times to repeat each task, the maximum number of pending requests, and the maximum number of waiting steps.

```yaml
explorer:
  engine_type: vllm_async
  engine_num: 2
  runner_num: 32
  tensor_parallel_size: 1
  enable_prefix_caching: false
  enforce_eager: true
  dtype: bfloat16
  seed: 42
  use_ray: false
  backend: 'nccl'
  max_pending_requests: 32
  max_waiting_steps: 4
  rollout_model:
    engine_type: vllm_async
    engine_num: 2
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
```

- `explorer.engine_type`: The type of the engine. Supported values are `vllm_async` and `vllm_sync`. Default is `vllm_async`.
@@ -169,10 +156,8 @@ explorer:
- `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
- `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
- `explorer.seed`: The seed used in vLLM. Default is `42`.
- `explorer.use_ray`: Whether to use Ray. Default is `False`.
- `explorer.backend`: The backend used in vLLM. Default is `nccl`.
- `explorer.max_pending_requests`: The maximum number of pending requests. Default is `32`.
- `explorer.max_waiting_steps`: The maximum number of waiting steps. Default is `4`.
- `explorer.rollout_model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
- `explorer.rollout_model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually; both fields are shown in the sketch below.
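
These two token-limit fields do not appear in the `explorer` snippet earlier in this section; a minimal sketch of where they sit is given below, using the illustrative `256`/`1024` values from the gsm8k examples in this PR.

```yaml
explorer:
  rollout_model:
    # Illustrative limits taken from the gsm8k examples; tune them to your task.
    max_prompt_tokens: 256
    max_response_tokens: 1024
```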

## Synchronizer

@@ -195,15 +180,11 @@ Supports `nccl` and `checkpoint`; `nccl` represents that model weights in `explor
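
The synchronizer snippet itself is elided from this diff; a minimal sketch is shown below, with values that follow the `async_gsm8k` examples in this PR and are illustrative only.

```yaml
# Sketch only; values follow the async_gsm8k examples.
synchronizer:
  sync_method: 'checkpoint'     # 'nccl' or 'checkpoint'
  sync_iteration_interval: 10   # weights sync every sync_iteration_interval * batch_size tasks
```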
trainer:
  trainer_type: 'verl'
  trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
  sft_warmup_steps: 0
  eval_interval: 1000
  save_interval: 100
```

- `trainer.trainer_type`: The backend of the trainer. Only `verl` is supported.
- `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually.
- `trainer.sft_warmup_steps`: The number of steps to warm up the model. Default is `0`.
- `trainer.eval_interval`: The interval steps between two evaluations. Default is `1000`.
- `trainer.save_interval`: The interval steps between two checkpoints. Default is `100`.

### veRL Trainer Configuration
@@ -116,7 +116,7 @@ class ExampleWorkflow(Workflow):
"content": f"Question:\n{self.question}",
}
],
n=self.task.rollout_args.repeat_times,
n=self.task.rollout_args.n,
temperature=self.task.rollout_args.temperature,
)
reward: float = self.calculate_reward(response.response_text, self.answer)
39 changes: 16 additions & 23 deletions examples/async_gsm8k/explorer.yaml
@@ -1,18 +1,18 @@
project: "Trinity-RFT-gsm8k"
name: "async-qwen2.5-1.5B-gsm8k"
mode: explore
global_config:
  total_epochs: 20
  batch_size: 96
  eval_interval: 10
  algorithm_type: grpo
algorithm_type: grpo
checkpoint_root_dir: 'checkpoints/qwen2.5-1.5B-gsm8k'
model:
  model_path: /PATH/TO/MODEL/
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k'
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 20
  batch_size: 96
  max_retry_times: 3
  max_retry_interval: 1
  explorer_input:
@@ -25,7 +25,7 @@ buffer:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        repeat_times: 8
        n: 8
        temperature: 1.0
        logprobs: 0
    default_workflow_type: 'math_workflow'
@@ -35,26 +35,19 @@ buffer:
      storage_type: queue
      path: 'sqlite:///gsm8k.db'
explorer:
  engine_type: vllm_async
  engine_num: 2
  eval_interval: 10
  runner_num: 32
  tensor_parallel_size: 1
  enable_prefix_caching: false
  enforce_eager: true
  dtype: bfloat16
  seed: 42
  use_ray: false
  backend: 'nccl'
  max_pending_requests: 32
  max_waiting_steps: 4
  rollout_model:
    engine_type: vllm_async
    engine_num: 2
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: 10
trainer:
  trainer_type: 'verl'
  trainer_config_path: examples/async_gsm8k/verl_config.yaml
  sft_warmup_steps: 0 # Set to integer to enable sft warmup
monitor:
  cache_root_dir: ""
  project: "Trinity-RFT-gsm8k"
  name: "async-qwen2.5-1.5B-gsm8k"
41 changes: 17 additions & 24 deletions examples/async_gsm8k/trainer.yaml
@@ -1,18 +1,18 @@
project: "Trinity-RFT-gsm8k"
name: "async-qwen2.5-1.5B-gsm8k"
mode: train
global_config:
  total_epochs: 20
  batch_size: 96
  eval_interval: 10
  algorithm_type: grpo
algorithm_type: grpo
checkpoint_root_dir: /PATH/TO/CHECKPOINT
model:
  model_path: /PATH/TO/MODEL/
  model_path: /PATH/TO/MODEL
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: ""
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 20
  batch_size: 96
  max_retry_times: 3
  max_retry_interval: 1
  explorer_input:
@@ -24,7 +24,7 @@ buffer:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        repeat_times: 8
        n: 8
        temperature: 1.0
        logprobs: 0
    default_workflow_type: 'math_workflow'
@@ -34,26 +34,19 @@ buffer:
      storage_type: queue
      path: 'sqlite:///gsm8k.db'
explorer:
  engine_type: vllm_async
  engine_num: 2
  eval_interval: 10
  runner_num: 32
  tensor_parallel_size: 1
  enable_prefix_caching: false
  enforce_eager: true
  dtype: bfloat16
  seed: 42
  use_ray: false
  backend: 'nccl'
  max_pending_requests: 32
  max_waiting_steps: 4
  rollout_model:
    engine_type: vllm_async
    engine_num: 2
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: 10
trainer:
  trainer_type: 'verl'
  trainer_config_path: examples/async_gsm8k/verl_config.yaml
  sft_warmup_steps: 0 # Set to integer to enable sft warmup
monitor:
  cache_root_dir: ""
  project: "Trinity-RFT-gsm8k"
  name: "async-qwen2.5-1.5B-gsm8k"