diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md
index a1b09d873f..a565145d83 100644
--- a/docs/sphinx_doc/source/tutorial/example_async_mode.md
+++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -1,32 +1,104 @@
-# A quick example for asynchronous mode
+# Asynchronous RFT
 
-This example shows how to run RFT in asynchronous mode with the GRPO algorithm, Qwen-2.5-1.5B-Instruct model and GSM8K dataset.
+This example shows how to run RFT in fully asynchronous mode with the GRPO algorithm, the Qwen-2.5-1.5B-Instruct model, and the GSM8K dataset.
 
 Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.
-For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
-The main difference between them is that in `trainer.yaml` we set `mode=train`, while in `explorer.yaml` we set `mode=explore`.
-In addition, we need to configure the following parameters in both files.
+For this purpose, we prepare two main config files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml).
+The main difference between them is that in `explorer.yaml` we set `mode` to `explore`, while in `trainer.yaml` we set `mode` to `train`.
 The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks.
 
-```yaml
-project: tutorial
-name: async_mode_example
-checkpoint_root_dir: /PATH/TO/CHECKPOINT
+Suppose we have a node with 8 GPUs; we use 4 GPUs for the trainer and 4 for the explorer.
+Some important settings of `explorer.yaml` are listed below:
+```yaml
+project:
+name:
+mode: explore
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 8
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:
+  node_num: 1
+  gpu_per_node: 4
 buffer:
-  batch_size:
+  total_epochs: 1
+  batch_size: 96
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      split: train
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+      rollout_args:
+        temperature: 1.0
+    default_workflow_type: 'math_workflow'
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
       storage_type: queue
       path: 'sqlite:///gsm8k.db'
+explorer:
+  eval_interval: 10
+  runner_num: 32
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 4
+synchronizer:
+  sync_method: 'checkpoint'
+  sync_interval: 10
+trainer:
+  trainer_config_path: examples/async_gsm8k/verl_config.yaml
+```
+
+Some important settings of `trainer.yaml` are listed below:
+```yaml
+project:
+name:
+mode: train
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 8
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:
+  node_num: 1
+  gpu_per_node: 4
+buffer:
+  total_epochs: 1
+  batch_size: 96
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+      rollout_args:
+        temperature: 1.0
+    default_workflow_type: 'math_workflow'
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+      path: 'sqlite:///gsm8k.db'
 synchronizer:
   sync_method: 'checkpoint'
-  sync_interval:
+  sync_interval: 10
+trainer:
+  trainer_config_path: examples/async_gsm8k/verl_config.yaml
 ```
+
 You may run this example with the following command:
 
 ```bash
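# A minimal sketch (an assumption, not the example's verbatim command): since the
# explorer and the trainer run as separate processes, launch each one with its own
# config file, e.g. the explorer in the background and the trainer in the foreground,
# reusing the `trinity run --config <file>` entry point shown in the GSM8K example.
trinity run --config examples/async_gsm8k/explorer.yaml &
trinity run --config examples/async_gsm8k/trainer.yaml
```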
diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md
index 953af399e5..a6f70f5e62 100644
--- a/docs/sphinx_doc/source/tutorial/example_dpo.md
+++ b/docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -1,4 +1,4 @@
-# Example: Run DPO on Human-Like-DPO-Dataset
+# Offline DPO
 
 This example describes DPO based on the Qwen-2.5-1.5B-Instruct model and [Human-like-DPO-dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset).
 
@@ -40,25 +40,36 @@ Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pa
 We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following:
 
-We run the experiment in a train mode, as there is no Explorer. To enable this mode, we config `mode` to `train` and set `sync_method` to `checkpoint`.
+We run the experiment in train mode, as there is no Explorer. To enable this mode, we set `mode` to `train` and pass the data path to the trainer.
 
 ```yaml
-# In dpo.yaml
+project:
+name:
 mode: train
 algorithm:
   algorithm_type: dpo
-synchronizer:
-  sync_method: 'checkpoint'
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:
+  node_num: 1
+  gpu_per_node: 8
 buffer:
-  train_dataset:
-    storage_type: file
-    path: <$DATASET_PATH/human_like_dpo_dataset>
-    format:
-      prompt_type: # messages/plaintext
-      prompt_key:
-      chosen_key:
-      rejected_key:
+  total_epochs: 2
+  batch_size: 64
+  trainer_input:
+    experience_buffer:
+      name: dpo_buffer
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      format:
+        prompt_type: plaintext # plaintext/messages/chatpair
+        prompt_key: prompt
+        chosen_key: chosen
+        rejected_key: rejected
 trainer:
+  trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
+  save_interval: 30
   actor_use_kl_loss: True
   actor_kl_loss_coef: 0.1 # value of beta in DPO
 ```
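For concreteness, a single record in this plaintext format could look like the sketch below; only the key names (`prompt`, `chosen`, `rejected`) come from the config above, while the field values and the file name are purely illustrative assumptions.

```bash
# Hypothetical one-line JSONL record matching prompt_type: plaintext with the
# prompt/chosen/rejected keys configured above; the values are invented.
cat > /PATH/TO/DATASET/sample.jsonl <<'EOF'
{"prompt": "How was your day?", "chosen": "It went really well, thanks for asking! How about yours?", "rejected": "Fine."}
EOF
```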
diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
index d70528b6ed..46cc4ab32e 100644
--- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md
+++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -1,4 +1,4 @@
-# Example: Multi-Turn RFT
+# Multi-Turn RFT
 
 In Trinity-RFT, we support Agentic RL with multiple rounds of interaction with environments.
 
diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
index 68ce684b81..a80032bc12 100644
--- a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
+++ b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
@@ -1,4 +1,4 @@
-# Example: off-policy RFT mode
+# Off-Policy RFT
 
 Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) and show some advanced features provided by Trinity-RFT, namely, off-policy or asynchronous RFT mode.
 
@@ -12,8 +12,7 @@ Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) a
 As an experimental feature of Trinity-RFT, we develop an embarrassingly simple off-policy RL algorithm, termed OPMD (Online Policy Mirror Descent, inspired by [Kimi k1.5](https://arxiv.org/abs/2501.12599)). The algorithm design and analysis can be found in this [technical report](../../assets/opmd.pdf).
-
-
+The config files are [`opmd_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/opmd_gsm8k/opmd_gsm8k.yaml) and [`train_opmd_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/opmd_gsm8k/train_opmd_gsm8k.yaml).
 
 To try out the OPMD algorithm:
 ```shell
diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
index 1cda68fc50..dc45994e98 100644
--- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
+++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -1,6 +1,59 @@
-# A quick example with GSM8k
+# Quick Start
+
+This tutorial provides a quick-start guide for running RFT with Trinity-RFT.
+
+## Step 0: Environment Preparation
+
+Minimal environment requirements:
+
+- GPUs: At least 2 GPUs
+- CUDA: Version >= 12.4
+- Python: Version >= 3.10
+
+```shell
+# Pull the source code from GitHub
+git clone https://github.com/modelscope/Trinity-RFT
+cd Trinity-RFT
+
+# Create a new environment using Conda or venv
+# Option 1: Conda
+conda create -n trinity python=3.10
+conda activate trinity
+
+# Option 2: venv
+python3.10 -m venv .venv
+source .venv/bin/activate
+
+# Install the package in editable mode
+# for bash
+pip install -e .[dev]
+# for zsh
+pip install -e .\[dev\]
+
+# Install flash-attn after all dependencies are installed
+# Note: flash-attn will take a long time to compile, please be patient.
+pip install flash-attn -v
+# Try the following command if you encounter errors during installation
+# pip install flash-attn -v --no-build-isolation
+```
+
+Installation from Docker:
+
+We provide a Dockerfile for Trinity-RFT.
+
+```shell
+git clone https://github.com/modelscope/Trinity-RFT
+cd Trinity-RFT
+
+# build the docker image
+# Note: you can edit the Dockerfile to customize the environment
+# e.g., use pip mirrors or set an API key
+docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
+
+# run the docker image
+docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest
+```
 
-This example shows how to run RFT with the Qwen-2.5-1.5B-Instruct model and GSM8K dataset.
 
 ## Step 1: Model and Data Preparation
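As a rough sketch of this step (the download tools, repository IDs, and target directories below are assumptions, not part of the example configs), the model and dataset could be fetched like this:

```bash
# Hypothetical download commands; adjust the destinations to match the
# model_path and dataset path placeholders used in the configs below.
modelscope download --model Qwen/Qwen2.5-1.5B-Instruct --local_dir /PATH/TO/MODEL/
huggingface-cli download --repo-type dataset openai/gsm8k --local-dir /PATH/TO/DATASET/
```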
@@ -37,31 +90,71 @@ More details on dataset downloading are referred to [ModelScope](https://modelsc
 
 ### Synchronous Mode of Trinity-RFT
 
-We run the experiment in a synchronous mode where the Explorer and Trainer operate in turn. To enable this mode, we config `mode` to `both` (default) and set `sync_interval` properly. A smaller value of `sync_interval` makes the training closer to an on-policy setup.
+We run the experiment in a synchronous mode where the Explorer and Trainer operate in turn. To enable this mode, we set `mode` to `both` (default) and set `sync_interval` properly. A smaller value of `sync_interval` makes the training closer to an on-policy setup; here we set `sync_interval` to 1.
 
-```yaml
-mode: both
-synchronizer:
-  sync_method: 'nccl'
-  sync_interval: 2
-```
+### Use GRPO Algorithm
 
-### Use GRPO or PPO Algorithm
-
-We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml) and [`train_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/train_gsm8k.yaml) for this experiment. Some important setups are listed in the following:
+We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml) and [`train_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/train_gsm8k.yaml) for this experiment. Some important settings of `gsm8k.yaml` are listed below:
 
 ```yaml
-# In gsm8k.yaml
+project:
+name:
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 algorithm:
-  algorithm_type: grpo / ppo
-  repeat_times: {number of rollouts for each task}
-
+  algorithm_type: grpo
+  repeat_times: 8
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:
+  node_num: 1
+  gpu_per_node: 2
+buffer:
+  total_epochs: 1
+  batch_size: 128
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: <$DATASET_PATH/gsm8k>
+      subset_name: 'main'
+      split: 'train'
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+      rollout_args:
+        temperature: 1.0
+    eval_tasksets:
+    - name: gsm8k-eval
+      storage_type: file
+      path: <$DATASET_PATH/gsm8k>
+      subset_name: 'main'
+      split: 'test'
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+    default_workflow_type: 'math_workflow'
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+      path: 'sqlite:///gsm8k.db'
+explorer:
+  eval_interval: 50
+  runner_num: 16
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 1
+synchronizer:
+  sync_method: 'nccl'
+  sync_interval: 1
 trainer:
-  actor_use_kl_loss: True (fro GRPO) / False (for PPO)
-  actort_kl_loss_coef: 0.001
+  trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
+  save_interval: 100
+
 ```
+
 ### Run the Experiment
 
 Run the RFT process with the following command:
 
@@ -76,7 +169,7 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 
 Before RFT, we may use SFT as a warmup step. We need to set `buffer.trainer_input.sft_warmup_steps > 0` and prepare the SFT data to `buffer.trainer_input.sft_warmup_dataset.path=$DATASET_PATH/{sft_data}`.
 ```yaml
-# Properly set the following configs in gsm8k.yaml
+# Properly add the following configs in gsm8k.yaml
 buffer:
   trainer_input:
     sft_warmup_dataset:
diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml
index a05b2ebfcf..f0c5891536 100644
--- a/examples/async_gsm8k/explorer.yaml
+++ b/examples/async_gsm8k/explorer.yaml
@@ -1,7 +1,7 @@
 project: "Trinity-RFT-gsm8k"
 name: "async-qwen2.5-1.5B-gsm8k"
 mode: explore
-checkpoint_root_dir: '/PATH/TO/CHECKPOINT/'
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 algorithm:
   algorithm_type: grpo
   repeat_times: 8
@@ -11,9 +11,9 @@ model:
   max_response_tokens: 1024
 cluster:
   node_num: 1
-  gpu_per_node: 8
+  gpu_per_node: 4
 buffer:
-  total_epochs: 20
+  total_epochs: 1
   batch_size: 96
   max_retry_times: 3
   max_retry_interval: 1
@@ -40,7 +40,7 @@ explorer:
   runner_num: 32
   rollout_model:
     engine_type: vllm_async
-    engine_num: 2
+    engine_num: 4
     tensor_parallel_size: 1
     enable_prefix_caching: false
     enforce_eager: true
diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml
index d259cb7ca0..00135be059 100644
--- a/examples/async_gsm8k/trainer.yaml
+++ b/examples/async_gsm8k/trainer.yaml
@@ -1,19 +1,19 @@
 project: "Trinity-RFT-gsm8k"
 name: "async-qwen2.5-1.5B-gsm8k"
 mode: train
-checkpoint_root_dir: /PATH/TO/CHECKPOINT
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 algorithm:
   algorithm_type: grpo
   repeat_times: 8
 model:
-  model_path: /PATH/TO/MODEL
+  model_path: /PATH/TO/MODEL/
   max_prompt_tokens: 256
   max_response_tokens: 1024
 cluster:
   node_num: 1
-  gpu_per_node: 8
+  gpu_per_node: 4
 buffer:
-  total_epochs: 20
+  total_epochs: 1
   batch_size: 96
   max_retry_times: 3
   max_retry_interval: 1
@@ -35,17 +35,6 @@ buffer:
       name: gsm8k_buffer
       storage_type: queue
       path: 'sqlite:///gsm8k.db'
-explorer:
-  eval_interval: 10
-  runner_num: 32
-  rollout_model:
-    engine_type: vllm_async
-    engine_num: 2
-    tensor_parallel_size: 1
-    enable_prefix_caching: false
-    enforce_eager: true
-    dtype: bfloat16
-    seed: 42
 synchronizer:
   sync_method: 'checkpoint'
   sync_interval: 10
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml
index 4648d6e493..8cd3dbe0c8 100644
--- a/examples/dpo_humanlike/dpo.yaml
+++ b/examples/dpo_humanlike/dpo.yaml
@@ -3,16 +3,16 @@ name: "trinity_dpo"
 mode: train
 algorithm:
   algorithm_type: dpo
-checkpoint_root_dir: /PATH/TO/CHECKPOINT
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
-  model_path: '/PATH/TO/MODEL'
-  max_prompt_tokens: 1792
-  max_response_tokens: 256
+  model_path: /PATH/TO/MODEL
+  max_prompt_tokens: 512
+  max_response_tokens: 1024
 cluster:
   node_num: 1
   gpu_per_node: 8
 buffer:
-  total_epochs: 20
+  total_epochs: 2
   batch_size: 32
   max_retry_times: 3
   max_retry_interval: 1
@@ -20,7 +20,7 @@ buffer:
     experience_buffer:
       name: dpo_buffer
       storage_type: file
-      path: '/PATH/TO/DATASET/'
+      path: /PATH/TO/DATASET/
      format:
        prompt_type: plaintext # plaintext/messages/chatpair
        prompt_key: prompt
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index de2c1d2d9e..2a87ef288b 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -19,7 +19,7 @@ data_processor:
   db_url: ''
 
 model:
-  model_path: '/PATH/TO/MODEL/'
+  model_path: /PATH/TO/MODEL/
   max_prompt_tokens: 256
   max_response_tokens: 1024
 cluster:
@@ -77,7 +77,7 @@ explorer:
   seed: 42
 synchronizer:
   sync_method: 'nccl'
-  sync_interval: 2
+  sync_interval: 1
   sync_timeout: 1200
 trainer:
   trainer_type: 'verl'
diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml
index daa1fd5fb7..5d3b16c2cc 100644
--- a/examples/grpo_math/math.yaml
+++ b/examples/grpo_math/math.yaml
@@ -47,7 +47,7 @@ explorer:
   seed: 42
 synchronizer:
   sync_method: 'nccl'
-  sync_interval: 2
+  sync_interval: 1
   sync_timeout: 1200
 trainer:
   trainer_type: 'verl'