98 changes: 87 additions & 11 deletions docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -1,32 +1,108 @@
# A quick example for asynchronous mode
# Asynchronous RFT Mode

This example shows how to run RFT in asynchronous mode with the GRPO algorithm, Qwen-2.5-1.5B-Instruct model and GSM8K dataset.
This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen-2.5-1.5B-Instruct model and GSM8K dataset.

Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.

For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
The main difference between them is that in `trainer.yaml` we set `mode=train`, while in `explorer.yaml` we set `mode=explore`.
In addition, we need to configure the following parameters in both files.
For this purpose, we prepare two main config files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml).
The main difference between them is that in `explorer.yaml` we set `mode` to `explore`, while in `trainer.yaml` we set `mode` to `train`.
The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks.
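With the values used in this example (see the configs below), that amounts to one synchronization every 960 tasks; a minimal sketch of the relevant fields:

```yaml
# Relevant fields, as set in both explorer.yaml and trainer.yaml below
buffer:
  batch_size: 96
synchronizer:
  sync_method: 'checkpoint'
  sync_interval: 10
# Weights are synchronized every sync_interval * batch_size = 10 * 96 = 960 tasks
```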

```yaml
project: tutorial
name: async_mode_example
checkpoint_root_dir: /PATH/TO/CHECKPOINT
Suppose we have a node with 8 GPUs; we use 4 GPUs for the trainer and 4 GPUs for the explorer.
Some important settings in `explorer.yaml` are listed below:

```yaml
project: <project_name>
name: <experiment_name>
mode: explore
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: grpo
repeat_times: 8
model:
model_path: /PATH/TO/MODEL/
cluster:
node_num: 1
gpu_per_node: 4
buffer:
batch_size: <batch_size>
total_epochs: 1
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: /PATH/TO/DATASET/
split: train
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
logprobs: 0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
path: 'sqlite:///gsm8k.db'
explorer:
eval_interval: 10
runner_num: 32
rollout_model:
engine_type: vllm_async
engine_num: 4
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'checkpoint'
sync_interval: 10
trainer:
trainer_config_path: examples/async_gsm8k/verl_config.yaml
```
Some important settings in `trainer.yaml` are listed below:

```yaml
project: <project_name>
name: <experiment_name>
mode: train
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: grpo
repeat_times: 8
model:
model_path: /PATH/TO/MODEL/
cluster:
node_num: 1
gpu_per_node: 4
buffer:
total_epochs: 1
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: /PATH/TO/DATASET/
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
n: 8
temperature: 1.0
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
path: 'sqlite:///gsm8k.db'
synchronizer:
sync_method: 'checkpoint'
sync_interval: <sync_interval>
sync_interval: 10
trainer:
trainer_config_path: examples/async_gsm8k/verl_config.yaml
```
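Note that both files point `buffer.trainer_input.experience_buffer` at the same queue, which is how experiences produced by the explorer reach the trainer:

```yaml
# Shared in explorer.yaml and trainer.yaml: the explorer writes rollout
# experiences into this queue and the trainer consumes them from it
buffer:
  trainer_input:
    experience_buffer:
      name: gsm8k_buffer
      storage_type: queue
      path: 'sqlite:///gsm8k.db'
```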


You may run this example with the following command:

```bash
37 changes: 24 additions & 13 deletions docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -1,4 +1,4 @@
# Example: Run DPO on Human-Like-DPO-Dataset
# DPO Mode

This example describes how to run DPO with the Qwen-2.5-1.5B-Instruct model and the [Human-like-DPO-dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset).

@@ -40,25 +40,36 @@ Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pa

We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following:

We run the experiment in a train mode, as there is no Explorer. To enable this mode, we config `mode` to `train` and set `sync_method` to `checkpoint`.
We run the experiment in train mode, as there is no Explorer. To enable this mode, we set `mode` to `train` and pass the dataset path to the trainer.

```yaml
# In dpo.yaml
project: <project_name>
name: <experiment_name>
mode: train
algorithm:
algorithm_type: dpo
synchronizer:
sync_method: 'checkpoint'
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
model:
model_path: /PATH/TO/MODEL/
cluster:
node_num: 1
gpu_per_node: 8
buffer:
train_dataset:
storage_type: file
path: <$DATASET_PATH/human_like_dpo_dataset>
format:
prompt_type: <prompt_type> # messages/plaintext
prompt_key: <prompt_key>
chosen_key: <chosen_key>
rejected_key: <rejected_key>
total_epochs: 2
batch_size: 64
trainer_input:
experience_buffer:
name: dpo_buffer
storage_type: file
path: /PATH/TO/DATASET/
format:
prompt_type: plaintext # plaintext/messages/chatpair
prompt_key: prompt
chosen_key: chosen
rejected_key: rejected
trainer:
trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
save_interval: 30
actor_use_kl_loss: True
actor_kl_loss_coef: 0.1 # value of beta in DPO
```
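For reference, each record in the dataset file should expose the three keys configured above. A minimal sketch of one record, shown as YAML for readability (the values are purely illustrative, not taken from the actual dataset):

```yaml
# Hypothetical record; real entries come from the Human-Like-DPO-Dataset
prompt: "How was your weekend?"
chosen: "It was lovely, thanks for asking! I spent most of it hiking. How about yours?"
rejected: "As an AI, I do not have weekends or personal experiences."
```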
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -1,4 +1,4 @@
# Example: Multi-Turn RFT
# Multi-Turn RFT

In Trinity-RFT, we support Agentic RL with multiple rounds of environment interaction.

5 changes: 2 additions & 3 deletions docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
@@ -1,4 +1,4 @@
# Example: off-policy RFT mode
# Off-Policy RFT


Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) and show some advanced features provided by Trinity-RFT, namely, off-policy or asynchronous RFT mode.
@@ -12,8 +12,7 @@ Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) a

As an experimental feature of Trinity-RFT, we develop an embarrassingly simple off-policy RL algorithm, termed OPMD (Online Policy Mirror Descent, inspired by [Kimi k1.5](https://arxiv.org/abs/2501.12599)).
The algorithm design and analysis can be found in this [technical report](../../assets/opmd.pdf).


The config files are [`opmd_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/opmd_gsm8k/opmd_gsm8k.yaml) and [`train_opmd_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/opmd_gsm8k/train_opmd_gsm8k.yaml).

To try out the OPMD algorithm:
```shell
132 changes: 113 additions & 19 deletions docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -1,6 +1,59 @@
# A quick example with GSM8k
# Quick Start

This tutorial provides a quick-start guide for running RFT with Trinity-RFT.

## Step 0: Environment Preparation

Minimal environment requirements:

- GPUs: At least 2 GPUs
- CUDA: Version >= 12.4
- Python: Version >= 3.10

```shell
# Pull the source code from GitHub
git clone https://github.com/modelscope/Trinity-RFT
cd Trinity-RFT

# Create a new environment using Conda or venv
# Option 1: Conda
conda create -n trinity python=3.10
conda activate trinity

# Option 2: venv
python3.10 -m venv .venv
source .venv/bin/activate

# Install the package in editable mode
# for bash
pip install -e .[dev]
# for zsh
pip install -e .\[dev\]

# Install flash-attn after all dependencies are installed
# Note: flash-attn will take a long time to compile, please be patient.
pip install flash-attn -v
# Try the following command if you encounter errors during installation
# pip install flash-attn -v --no-build-isolation
```

Installation with Docker:

We provide a Dockerfile for Trinity-RFT.

```shell
git clone https://github.com/modelscope/Trinity-RFT
cd Trinity-RFT

# build the docker image
# Note: you can edit the dockerfile to customize the environment
# e.g., use pip mirrors or set api key
docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .

# run the docker image
docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
```

This example shows how to run RFT with the Qwen-2.5-1.5B-Instruct model and GSM8K dataset.

## Step 1: Model and Data Preparation

@@ -37,31 +90,72 @@ More details on dataset downloading are referred to [ModelScope](https://modelsc

### Synchronous Mode of Trinity-RFT

We run the experiment in a synchronous mode where the Explorer and Trainer operate in turn. To enable this mode, we config `mode` to `both` (default) and set `sync_interval` properly. A smaller value of `sync_interval` makes the training closer to an on-policy setup.
We run the experiment in a synchronous mode where the Explorer and Trainer operate in turn. To enable this mode, we set `mode` to `both` (the default) and choose `sync_interval` appropriately. A smaller `sync_interval` makes the training closer to on-policy; for example, we set `sync_interval` to 1 to approximate an on-policy setup.

```yaml
mode: both
synchronizer:
sync_method: 'nccl'
sync_interval: 2
```
### Use GRPO Algorithm

### Use GRPO or PPO Algorithm

We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml) and [`train_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/train_gsm8k.yaml) for this experiment. Some important setups are listed in the following:
We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml) and [`train_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/train_gsm8k.yaml) for this experiment. Some important settings in `gsm8k.yaml` are listed below:


```yaml
# In gsm8k.yaml
project: <project_name>
name: <experiment_name>
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: grpo / ppo
repeat_times: {number of rollouts for each task}

algorithm_type: grpo
repeat_times: 8
model:
model_path: /PATH/TO/MODEL/
cluster:
node_num: 1
gpu_per_node: 2
buffer:
total_epochs: 1
batch_size: 128
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: <$DATASET_PATH/gsm8k>
subset_name: 'main'
split: 'train'
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
n: 8
temperature: 1.0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
path: <$DATASET_PATH/gsm8k>
subset_name: 'main'
split: 'test'
format:
prompt_key: 'question'
response_key: 'answer'
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
path: 'sqlite:///gsm8k.db'
explorer:
eval_interval: 50
runner_num: 16
rollout_model:
engine_type: vllm_async
engine_num: 1
synchronizer:
sync_method: 'nccl'
sync_interval: 1
trainer:
  actor_use_kl_loss: True (for GRPO) / False (for PPO)
  actor_kl_loss_coef: 0.001
trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
save_interval: 100

```
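To switch the same setup from GRPO to PPO, a minimal sketch of the change (an assumption based on the GRPO/PPO notes above; PPO-specific options such as the critic settings in `train_gsm8k.yaml` may also need attention):

```yaml
# Sketch: change the algorithm type; the note above suggests
# actor_use_kl_loss: False for PPO
algorithm:
  algorithm_type: ppo
```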


### Run the Experiment

Run the RFT process with the following command:
@@ -76,7 +170,7 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
Before RFT, we may use SFT as a warmup step. To do so, set `buffer.trainer_input.sft_warmup_steps > 0` and point `buffer.trainer_input.sft_warmup_dataset.path` to the SFT data, e.g., `$DATASET_PATH/{sft_data}`.

```yaml
# Properly set the following configs in gsm8k.yaml
# Add the following configs to gsm8k.yaml
buffer:
trainer_input:
sft_warmup_dataset:
8 changes: 4 additions & 4 deletions examples/async_gsm8k/explorer.yaml
@@ -1,7 +1,7 @@
project: "Trinity-RFT-gsm8k"
name: "async-qwen2.5-1.5B-gsm8k"
mode: explore
checkpoint_root_dir: '/PATH/TO/CHECKPOINT/'
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: grpo
repeat_times: 8
@@ -11,9 +11,9 @@ model:
max_response_tokens: 1024
cluster:
node_num: 1
gpu_per_node: 8
gpu_per_node: 4
buffer:
total_epochs: 20
total_epochs: 1
batch_size: 96
max_retry_times: 3
max_retry_interval: 1
@@ -40,7 +40,7 @@ explorer:
runner_num: 32
rollout_model:
engine_type: vllm_async
engine_num: 2
engine_num: 4
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true