Skip to content

Commit b45ac40

Browse files
authored
Add example_async_mode (#28)
1 parent 50aba82 commit b45ac40

File tree

13 files changed

+367
-17
lines changed

13 files changed

+367
-17
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ More example config files can be found in `examples`.
260260

261261
For more detailed examples about how to use Trinity-RFT, please refer to the following tutorials:
262262
+ [A quick example with GSM8k](./docs/sphinx_doc/source/tutorial/example_reasoning_basic.md);
263-
+ [Off-policy / asynchronous modes of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md);
263+
+ [Off-policy mode of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md);
264+
+ [Asynchronous mode of RFT](./docs/sphinx_doc/source/tutorial/example_async_mode.md);
264265
+ [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md);
265266
+ [Data processing pipelines](./docs/sphinx_doc/source/tutorial/example_data_functionalities.md);
266267
+ [Offline learning by DPO](./docs/sphinx_doc/source/tutorial/example_dpo.md).
51.8 KB
Loading

docs/sphinx_doc/source/main.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ More example config files can be found in `examples`.
240240

241241
For more detailed examples about how to use Trinity-RFT, please refer to the following documents:
242242
+ [A quick example with GSM8k](tutorial/example_reasoning_basic.md);
243-
+ [Off-policy / asynchronous modes of RFT](tutorial/example_reasoning_advanced.md);
243+
+ [Off-policy mode of RFT](tutorial/example_reasoning_advanced.md);
244+
+ [Asynchronous mode of RFT](tutorial/example_async_mode.md);
244245
+ [Multi-turn tasks](tutorial/example_multi_turn.md);
245246
+ [Data processing pipelines](tutorial/example_data_functionalities.md);
246247
+ [Offline learning by DPO](tutorial/example_dpo.md).
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# A quick example for asynchronous mode
2+
3+
This example shows how to run RFT in asynchronous mode with the GRPO algorithm, Qwen-2.5-1.5B-Instruct model and GSM8K dataset.
4+
5+
Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.
6+
7+
For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
8+
The main difference between them is that in `trainer.yaml` we set `mode=train`, while in `explorer.yaml` we set `mode=explore`.
9+
In addition, we need to configure the following parameters in both files.
10+
The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks.
11+
12+
```yaml
13+
data:
14+
batch_size: <batch_size>
15+
# The same checkpoint path
16+
model:
17+
checkpoint_path: /PATH/TO/CHECKPOINT
18+
19+
# The same data_base path
20+
buffer:
21+
train_dataset:
22+
name: gsm8k_buffer
23+
storage_type: queue
24+
path: 'sqlite:///gsm8k.db'
25+
26+
synchronizer:
27+
sync_method: 'checkpoint'
28+
sync_iteration_interval: <sync_iteration_interval>
29+
```
30+
31+
You may run this example with the following command:
32+
33+
```bash
34+
bash examples/async_gsm8k/run.sh
35+
```
36+
37+
The following plot shows the learning curve of GRPO in the asynchronous mode.
38+
> This result should be regarded merely as a baseline, since GRPO is supposed to be an on-policy algorithm.
39+
> We are continuously investigating other RL algorithms (e.g., [OPMD](./example_reasoning_advanced.md)) in the asynchronous mode.
40+
41+
![async](../../assets/async-curve.png)

docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Example: off-policy / asynchronous RFT mode
1+
# Example: off-policy RFT mode
22

33

44
Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) and show some advanced features provided by Trinity-RFT, namely, off-policy or asynchronous RFT mode.
@@ -35,17 +35,3 @@ A similar performance boost is shown at step 21, which leads to a converged scor
3535

3636

3737
![opmd](../../assets/opmd-curve.png)
38-
39-
40-
41-
42-
43-
## Asynchronous mode
44-
45-
46-
Trinity-RFT supports the asynchronous and decoupled mode of RFT, where explorer and trainer act independently and asynchronously.
47-
To run this mode, the explorer and trainer need to be launched separately, with the `mode` parameter in the config file set to `explore` and `train` respectively.
48-
49-
50-
51-
*We are still testing this mode more thoroughly. A concrete example is coming soon!*

examples/async_gsm8k/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Asynchronous mode on GSM8K dataset
2+
3+
This example shows the usage of GRPO on the GSM8K dataset in an asynchronous mode.
4+
5+
For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_async_mode.md).
6+
7+
The config files are located in [`trainer.yaml`](trainer.yaml), [`explorer.yaml`](explorer.yaml), and [`verl_config.yaml`](verl_config.yaml).
8+
9+
You can run this example by the following command:
10+
11+
```bash
12+
bash examples/async_gsm8k/run.sh
13+
```

examples/async_gsm8k/explorer.yaml

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
mode: explore
2+
data:
3+
# basic info
4+
dataset_path: /PATH/TO/DATASET/
5+
subset_name: ''
6+
train_split: 'train'
7+
eval_split: 'test'
8+
format_config:
9+
prompt_key: 'question'
10+
response_key: 'answer'
11+
# downstream loading related
12+
total_epochs: 20
13+
batch_size: 96
14+
default_workflow_type: 'math_workflow'
15+
model:
16+
model_path: /PATH/TO/MODEL/
17+
max_prompt_tokens: 256
18+
max_response_tokens: 1024
19+
checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k'
20+
cluster:
21+
node_num: 1
22+
gpu_per_node: 8
23+
buffer:
24+
max_retry_times: 3
25+
max_retry_interval: 1
26+
train_dataset:
27+
name: gsm8k_buffer
28+
storage_type: queue
29+
path: 'sqlite:///gsm8k.db'
30+
explorer:
31+
engine_type: vllm_async
32+
engine_num: 2
33+
runner_num: 32
34+
tensor_parallel_size: 1
35+
enable_prefix_caching: false
36+
enforce_eager: true
37+
dtype: bfloat16
38+
temperature: 1.0
39+
seed: 42
40+
logprobs: 0
41+
repeat_times: 8
42+
use_ray: false
43+
backend: 'nccl'
44+
max_pending_requests: 32
45+
max_waiting_steps: 4
46+
synchronizer:
47+
sync_method: 'checkpoint'
48+
sync_iteration_interval: 10
49+
trainer:
50+
trainer_type: 'verl'
51+
algorithm_type: ppo
52+
trainer_config_path: examples/async_gsm8k/verl_config.yaml
53+
sft_warmup_iteration: 0 # Set to integer to enable sft warmup
54+
eval_interval: 10
55+
monitor:
56+
cache_root_dir: ""
57+
project: "Trinity-RFT-gsm8k"
58+
name: "async-qwen2.5-1.5B-gsm8k"

examples/async_gsm8k/run.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
trinity run --config examples/async_gsm8k/explorer.yaml 2>&1 | tee explorer.log &
3+
sleep 30
4+
trinity run --config examples/async_gsm8k/trainer.yaml 2>&1 | tee trainer.log &

examples/async_gsm8k/trainer.yaml

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
mode: train
2+
data:
3+
# basic info
4+
dataset_path: /PATH/TO/DATASET/
5+
subset_name: ''
6+
train_split: 'train'
7+
eval_split: 'test'
8+
format_config:
9+
prompt_key: 'question'
10+
response_key: 'answer'
11+
# downstream loading related
12+
total_epochs: 20
13+
batch_size: 96
14+
default_workflow_type: 'math_workflow'
15+
model:
16+
model_path: /PATH/TO/MODEL/
17+
max_prompt_tokens: 256
18+
max_response_tokens: 1024
19+
checkpoint_path: ""
20+
cluster:
21+
node_num: 1
22+
gpu_per_node: 8
23+
buffer:
24+
max_retry_times: 3
25+
max_retry_interval: 1
26+
train_dataset:
27+
name: gsm8k_buffer
28+
storage_type: queue
29+
path: 'sqlite:///gsm8k.db'
30+
explorer:
31+
engine_type: vllm_async
32+
engine_num: 2
33+
runner_num: 32
34+
tensor_parallel_size: 1
35+
enable_prefix_caching: false
36+
enforce_eager: true
37+
dtype: bfloat16
38+
temperature: 1.0
39+
seed: 42
40+
logprobs: 0
41+
repeat_times: 8
42+
use_ray: false
43+
backend: 'nccl'
44+
max_pending_requests: 32
45+
max_waiting_steps: 4
46+
synchronizer:
47+
sync_method: 'checkpoint'
48+
sync_iteration_interval: 10
49+
trainer:
50+
trainer_type: 'verl'
51+
algorithm_type: ppo
52+
trainer_config_path: examples/async_gsm8k/verl_config.yaml
53+
sft_warmup_iteration: 0 # Set to integer to enable sft warmup
54+
eval_interval: 10
55+
monitor:
56+
cache_root_dir: ""
57+
project: "Trinity-RFT-gsm8k"
58+
name: "async-qwen2.5-1.5B-gsm8k"

0 commit comments

Comments
 (0)