Commit d397909

Update grpo and dpo examples (#48)
1 parent c389bce commit d397909

10 files changed: +239 −75 lines changed

docs/sphinx_doc/source/tutorial/example_async_mode.md

Lines changed: 83 additions & 11 deletions
@@ -1,32 +1,104 @@
-# A quick example for asynchronous mode
+# Asynchronous RFT
 
-This example shows how to run RFT in asynchronous mode with the GRPO algorithm, Qwen-2.5-1.5B-Instruct model and GSM8K dataset.
+This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, the Qwen-2.5-1.5B-Instruct model, and the GSM8K dataset.
 
 Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.
 
-For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
-The main difference between them is that in `trainer.yaml` we set `mode=train`, while in `explorer.yaml` we set `mode=explore`.
-In addition, we need to configure the following parameters in both files.
+For this purpose, we prepare two main config files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml).
+The main difference between them is that in `explorer.yaml` we set `mode` to `explore`, while in `trainer.yaml` we set `mode` to `train`.
 The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks.
 
-```yaml
-project: tutorial
-name: async_mode_example
-checkpoint_root_dir: /PATH/TO/CHECKPOINT
+Suppose we have a node with 8 GPUs; we use 4 GPUs for the trainer and 4 GPUs for the explorer.
+Some important setups of `explorer.yaml` are listed in the following:
 
+```yaml
+project: <project_name>
+name: <experiment_name>
+mode: explore
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 8
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:
+  node_num: 1
+  gpu_per_node: 4
 buffer:
-  batch_size: <batch_size>
+  total_epochs: 1
+  batch_size: 96
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      split: train
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+      rollout_args:
+        temperature: 1.0
+    default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
       storage_type: queue
       path: 'sqlite:///gsm8k.db'
+explorer:
+  eval_interval: 10
+  runner_num: 32
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 4
+synchronizer:
+  sync_method: 'checkpoint'
+  sync_interval: 10
+trainer:
+  trainer_config_path: examples/async_gsm8k/verl_config.yaml
+```
+
+Some important setups of `trainer.yaml` are listed in the following:
 
+```yaml
+project: <project_name>
+name: <experiment_name>
+mode: train
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 8
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:
+  node_num: 1
+  gpu_per_node: 4
+buffer:
+  total_epochs: 1
+  batch_size: 96
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+      rollout_args:
+        temperature: 1.0
+    default_workflow_type: 'math_workflow'
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+      path: 'sqlite:///gsm8k.db'
 synchronizer:
   sync_method: 'checkpoint'
-  sync_interval: <sync_interval>
+  sync_interval: 10
+trainer:
+  trainer_config_path: examples/async_gsm8k/verl_config.yaml
 ```
 
+
 You may run this example with the following command:
 
 ```bash
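
With the values shown above, weight synchronization happens every `sync_interval * batch_size = 10 * 96 = 960` tasks. A minimal sketch of the two settings that control this cadence, assuming the config structure shown in the diff (the two files differ mainly in `mode`, explore vs. train, and share the same buffer and synchronizer settings):

```yaml
# Sketch only, based on the values in the diff above
synchronizer:
  sync_method: 'checkpoint'   # weights are passed between processes via saved checkpoints
  sync_interval: 10           # with batch_size: 96, one sync every 10 * 96 = 960 tasks
```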

docs/sphinx_doc/source/tutorial/example_dpo.md

Lines changed: 24 additions & 13 deletions
@@ -1,4 +1,4 @@
-# Example: Run DPO on Human-Like-DPO-Dataset
+# Offline DPO
 
 This example describes DPO based on the Qwen-2.5-1.5B-Instruct model and [Human-like-DPO-dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset).
 
@@ -40,25 +40,36 @@ Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pa
 
 We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following:
 
-We run the experiment in a train mode, as there is no Explorer. To enable this mode, we config `mode` to `train` and set `sync_method` to `checkpoint`.
+We run the experiment in train mode, as there is no Explorer. To enable this mode, we set `mode` to `train` and pass the data path to the trainer.
 
 ```yaml
-# In dpo.yaml
+project: <project_name>
+name: <experiment_name>
 mode: train
 algorithm:
   algorithm_type: dpo
-synchronizer:
-  sync_method: 'checkpoint'
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:
+  node_num: 1
+  gpu_per_node: 8
 buffer:
-  train_dataset:
-    storage_type: file
-    path: <$DATASET_PATH/human_like_dpo_dataset>
-    format:
-      prompt_type: <prompt_type> # messages/plaintext
-      prompt_key: <prompt_key>
-      chosen_key: <chosen_key>
-      rejected_key: <rejected_key>
+  total_epochs: 2
+  batch_size: 64
+  trainer_input:
+    experience_buffer:
+      name: dpo_buffer
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      format:
+        prompt_type: plaintext # plaintext/messages/chatpair
+        prompt_key: prompt
+        chosen_key: chosen
+        rejected_key: rejected
 trainer:
+  trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
+  save_interval: 30
   actor_use_kl_loss: True
   actor_kl_loss_coef: 0.1 # value of beta in DPO
 ```
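
For reference on the `actor_kl_loss_coef` comment above: in DPO this coefficient plays the role of β in the standard DPO objective,

$$
\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right],
$$

where $y_w$ and $y_l$ are the chosen and rejected responses for prompt $x$ and $\pi_{\mathrm{ref}}$ is the reference policy. The config above sets β = 0.1; a larger value penalizes divergence from the reference model more strongly.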

docs/sphinx_doc/source/tutorial/example_multi_turn.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Example: Multi-Turn RFT
+# Multi-Turn RFT
 
 In Trinity-RFT, we support Agentic RL with multiple rounds of interaction with environments.
 
docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md

Lines changed: 2 additions & 3 deletions
@@ -1,4 +1,4 @@
-# Example: off-policy RFT mode
+# Off-Policy RFT
 
 
 Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) and show some advanced features provided by Trinity-RFT, namely, off-policy or asynchronous RFT mode.
@@ -12,8 +12,7 @@ Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) a
 
 As an experimental feature of Trinity-RFT, we develop an embarrassingly simple off-policy RL algorithm, termed as OPMD (Online Policy Mirror Descent, inspired by [Kimi k1.5](https://arxiv.org/abs/2501.12599)).
 The algorithm design and analysis can be found in this [technical report](../../assets/opmd.pdf).
-
-
+The config files are [`opmd_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/opmd_gsm8k/opmd_gsm8k.yaml) and [`train_opmd_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/opmd_gsm8k/train_opmd_gsm8k.yaml).
 
 To try out the OPMD algorithm:
 ```shell
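
(The shell snippet above is cut off by the hunk boundary; by analogy with the GSM8K example later in this commit, the launch command is presumably `trinity run --config examples/opmd_gsm8k/opmd_gsm8k.yaml`.)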

docs/sphinx_doc/source/tutorial/example_reasoning_basic.md

Lines changed: 112 additions & 19 deletions
@@ -1,6 +1,59 @@
-# A quick example with GSM8k
+# Quick Start
+
+This tutorial provides a quick-start guide to running RFT with Trinity-RFT.
+
+## Step 0: Environment Preparation
+
+Minimal environment requirements:
+
+- GPUs: At least 2 GPUs
+- CUDA: Version >= 12.4
+- Python: Version >= 3.10
+
+```shell
+# Pull the source code from GitHub
+git clone https://github.com/modelscope/Trinity-RFT
+cd Trinity-RFT
+
+# Create a new environment using Conda or venv
+# Option 1: Conda
+conda create -n trinity python=3.10
+conda activate trinity
+
+# Option 2: venv
+python3.10 -m venv .venv
+source .venv/bin/activate
+
+# Install the package in editable mode
+# for bash
+pip install -e .[dev]
+# for zsh
+pip install -e .\[dev\]
+
+# Install flash-attn after all dependencies are installed
+# Note: flash-attn will take a long time to compile, please be patient.
+pip install flash-attn -v
+# Try the following command if you encounter errors during installation
+# pip install flash-attn -v --no-build-isolation
+```
+
+Installation from Docker:
+
+We provide a Dockerfile for Trinity-RFT.
+
+```shell
+git clone https://github.com/modelscope/Trinity-RFT
+cd Trinity-RFT
+
+# build the docker image
+# Note: you can edit the dockerfile to customize the environment
+# e.g., use pip mirrors or set api key
+docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
+
+# run the docker image
+docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
+```
 
-This example shows how to run RFT with the Qwen-2.5-1.5B-Instruct model and GSM8K dataset.
 
 ## Step 1: Model and Data Preparation
 
@@ -37,31 +90,71 @@ More details on dataset downloading are referred to [ModelScope](https://modelsc
 
 ### Synchronous Mode of Trinity-RFT
 
-We run the experiment in a synchronous mode where the Explorer and Trainer operate in turn. To enable this mode, we config `mode` to `both` (default) and set `sync_interval` properly. A smaller value of `sync_interval` makes the training closer to an on-policy setup.
+We run the experiment in a synchronous mode where the Explorer and Trainer operate in turn. To enable this mode, we set `mode` to `both` (default) and set `sync_interval` properly. A smaller value of `sync_interval` makes the training closer to an on-policy setup. For example, we set `sync_interval` to 1 to simulate an on-policy setup.
 
-```yaml
-mode: both
-synchronizer:
-  sync_method: 'nccl'
-  sync_interval: 2
-```
+### Use GRPO Algorithm
 
-### Use GRPO or PPO Algorithm
-
-We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml) and [`train_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/train_gsm8k.yaml) for this experiment. Some important setups are listed in the following:
+We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml) and [`train_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/train_gsm8k.yaml) for this experiment. Some important setups of `gsm8k.yaml` are listed in the following:
 
 
 ```yaml
-# In gsm8k.yaml
+project: <project_name>
+name: <experiment_name>
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 algorithm:
-  algorithm_type: grpo / ppo
-  repeat_times: {number of rollouts for each task}
-
+  algorithm_type: grpo
+  repeat_times: 8
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:
+  node_num: 1
+  gpu_per_node: 2
+buffer:
+  total_epochs: 1
+  batch_size: 128
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: <$DATASET_PATH/gsm8k>
+      subset_name: 'main'
+      split: 'train'
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+      rollout_args:
+        temperature: 1.0
+    eval_tasksets:
+    - name: gsm8k-eval
+      storage_type: file
+      path: <$DATASET_PATH/gsm8k>
+      subset_name: 'main'
+      split: 'test'
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+    default_workflow_type: 'math_workflow'
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+      path: 'sqlite:///gsm8k.db'
+explorer:
+  eval_interval: 50
+  runner_num: 16
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 1
+synchronizer:
+  sync_method: 'nccl'
+  sync_interval: 1
 trainer:
-  actor_use_kl_loss: True (fro GRPO) / False (for PPO)
-  actort_kl_loss_coef: 0.001
+  trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
+  save_interval: 100
+
 ```
 
+
 ### Run the Experiment
 
 Run the RFT process with the following command:
@@ -76,7 +169,7 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 Before RFT, we may use SFT as a warmup step. We need to set `buffer.trainer_input.sft_warmup_steps > 0` and prepare the SFT data to `buffer.trainer_input.sft_warmup_dataset.path=$DATASET_PATH/{sft_data}`.
 
 ```yaml
-# Properly set the following configs in gsm8k.yaml
+# Properly add the following configs in gsm8k.yaml
 buffer:
   trainer_input:
     sft_warmup_dataset:
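
The hunk above is truncated by the diff. A fuller sketch of what this SFT-warmup block might look like, assuming only the keys named in the paragraph above (`buffer.trainer_input.sft_warmup_steps` and `buffer.trainer_input.sft_warmup_dataset.path`) plus a `storage_type` field by analogy with the other file-based datasets:

```yaml
# Hypothetical sketch, not taken from the commit
buffer:
  trainer_input:
    sft_warmup_steps: 10               # any value > 0 enables SFT warmup; 10 is illustrative
    sft_warmup_dataset:
      storage_type: file               # assumed, by analogy with the other datasets
      path: $DATASET_PATH/{sft_data}   # the SFT data path named in the text above
```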

examples/async_gsm8k/explorer.yaml

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 project: "Trinity-RFT-gsm8k"
 name: "async-qwen2.5-1.5B-gsm8k"
 mode: explore
-checkpoint_root_dir: '/PATH/TO/CHECKPOINT/'
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 algorithm:
   algorithm_type: grpo
   repeat_times: 8
@@ -11,9 +11,9 @@ model:
   max_response_tokens: 1024
 cluster:
   node_num: 1
-  gpu_per_node: 8
+  gpu_per_node: 4
 buffer:
-  total_epochs: 20
+  total_epochs: 1
   batch_size: 96
   max_retry_times: 3
   max_retry_interval: 1
@@ -40,7 +40,7 @@ explorer:
   runner_num: 32
   rollout_model:
     engine_type: vllm_async
-    engine_num: 2
+    engine_num: 4
     tensor_parallel_size: 1
     enable_prefix_caching: false
     enforce_eager: true
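
A quick sanity check on the GPU budget implied by these changes: the explorer now claims 4 of the node's 8 GPUs (`gpu_per_node: 4`), and the 4 vLLM engines with `tensor_parallel_size: 1` account for exactly those 4 GPUs, leaving the other 4 for the trainer process. A sketch of the relevant fields, using the values from the diff above:

```yaml
# Explorer side of an 8-GPU node
cluster:
  node_num: 1
  gpu_per_node: 4            # 4 of 8 GPUs reserved for the explorer
explorer:
  rollout_model:
    engine_type: vllm_async
    engine_num: 4            # 4 engines * tensor_parallel_size 1 = 4 GPUs
    tensor_parallel_size: 1
```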
