4 changes: 2 additions & 2 deletions README.md
@@ -266,7 +266,7 @@ Then, for command-line users, run the RFT process with the following command:
trinity run --config <config_path>
```

> For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
> For example, below is the command for fine-tuning Qwen2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
> ```shell
> trinity run --config examples/grpo_gsm8k/gsm8k.yaml
> ```
@@ -279,7 +279,7 @@ For more detailed examples about how to use Trinity-RFT, please refer to the fol
+ [Off-policy mode of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md)
+ [Asynchronous mode of RFT](./docs/sphinx_doc/source/tutorial/example_async_mode.md)
+ [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
+ [Offline learning by DPO](./docs/sphinx_doc/source/tutorial/example_dpo.md)
+ [Offline learning by DPO or SFT](./docs/sphinx_doc/source/tutorial/example_dpo.md)
+ [Advanced data processing / human-in-the-loop](./docs/sphinx_doc/source/tutorial/example_data_functionalities.md)


22 changes: 10 additions & 12 deletions docs/sphinx_doc/source/main.md
@@ -84,15 +84,18 @@ e.g., utilizing NCCL (when feasible) for model weight synchronization, sequence

## Getting started


*Note: this project is currently under active development; comments and suggestions are welcome!*

```{note}
Note: This project is currently under active development; comments and suggestions are welcome!
```



### Step 1: preparations


Trinity-RFT requires
Python version >= 3.10,
CUDA version >= 12.4,
and at least 2 GPUs.


Installation from source (recommended):
@@ -146,11 +149,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
```

Trinity-RFT requires
Python version >= 3.10,
CUDA version >= 12.4,
and at least 2 GPUs.


### Step 2: prepare dataset and model

@@ -243,15 +241,15 @@ trinity run --config <config_path>



For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
For example, below is the command for fine-tuning Qwen2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:

```shell
trinity run --config examples/grpo_gsm8k/gsm8k.yaml
```



More example config files can be found in `examples`.
More example config files can be found in [`examples`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/).



@@ -260,7 +258,7 @@ For more detailed examples about how to use Trinity-RFT, please refer to the fol
+ [Off-policy mode of RFT](tutorial/example_reasoning_advanced.md)
+ [Asynchronous mode of RFT](tutorial/example_async_mode.md)
+ [Multi-turn tasks](tutorial/example_multi_turn.md)
+ [Offline learning by DPO](tutorial/example_dpo.md)
+ [Offline learning by DPO or SFT](tutorial/example_dpo.md)
+ [Advanced data processing / human-in-the-loop](tutorial/example_data_functionalities.md)


2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -1,6 +1,6 @@
# Asynchronous RFT

This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen-2.5-1.5B-Instruct model and GSM8K dataset.
This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen2.5-1.5B-Instruct model and GSM8K dataset.

Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.
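
A rough sketch of how the two processes might be configured is shown below; the field names follow the configuration guide, but the values are placeholders rather than the exact settings used in this example:

```yaml
# Illustrative sketch only, not the exact settings of this example.
# Explorer process: generates experiences.
mode: explore
synchronizer:
  sync_method: checkpoint   # load the latest trainer checkpoint from disk
  sync_interval: 10

# The trainer process would use a matching config with `mode: train`.
```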

12 changes: 6 additions & 6 deletions docs/sphinx_doc/source/tutorial/example_data_functionalities.md
@@ -26,14 +26,14 @@ python scripts/start_servers.py

### Configure the Data Module

Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data` section in the config file.
Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data_processor` section in the config file.

In this example, assume that you need to rank all math questions and their corresponding answers by difficulty. You can set the config items as in the following example:

```yaml
data_processor:
# basic info
source_data_path: '/path/to/gsm8k'
source_data_path: /PATH/TO/GSM8K/
load_kwargs:
split: 'train' # only need the train split
format: # set the field mappings
@@ -58,7 +58,7 @@ If you are not familiar with Data-Juicer, the data module provides a natural-lan
```yaml
data_processor:
# basic info
source_data_path: '/path/to/gsm8k'
source_data_path: /PATH/TO/GSM8K/
load_kwargs:
split: 'train' # only need the train split
format: # set the field mappings
@@ -100,7 +100,7 @@ After preparing the Data-Juicer data processing recipe, you can set the `dj_conf
```yaml
data_processor:
# basic info
source_data_path: '/path/to/gsm8k'
source_data_path: /PATH/TO/GSM8K/
load_kwargs:
split: 'train' # only need the train split
format: # set the field mappings
@@ -165,7 +165,7 @@ python scripts/start_servers.py

### Configure the Data Module

Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data` section in the config file.
Trinity-RFT uses a unified config file to manage all config items. For the data module, you need to focus on the `data_processor` section in the config file.

In this example, assume that you need to rank all math questions and their corresponding answers by difficulty. You can set the config items as in the following example:

@@ -187,7 +187,7 @@ data_processor:

Here you set the basic information for the example dataset, the database used to store the result dataset, and a few other items for downstream dataset loading in exploring and training, similar to the example above.

For this example, we assume that you are somehow familiar with the basic usage of Data-Juicer, so we need to prepare a Data-Juicer data processing recipe in `tests/test_configs/human_annotator_test_dj_cfg.yaml` that includes an OP of `human_preference_annotation_mapper`. For example:
For this example, we assume that you are somehow familiar with the basic usage of Data-Juicer, so we need to prepare a Data-Juicer data processing recipe in [`tests/test_configs/human_annotator_test_dj_cfg.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/tests/test_configs/human_annotator_test_dj_cfg.yaml) that includes an OP of `human_preference_annotation_mapper`. For example:

```yaml
project_name: 'demo-human-annotator'
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -1,12 +1,12 @@
# Offline DPO and SFT

This example describes DPO and SFT based on the Qwen-2.5-1.5B-Instruct model.
This example describes DPO and SFT based on the Qwen2.5-1.5B-Instruct model.

## Step 1: Model and Data Preparation

### Model Preparation

Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
Download the Qwen2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:

```shell
# Using Modelscope
14 changes: 13 additions & 1 deletion docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -25,9 +25,15 @@ The first term corresponds to the standard GRPO objective, which aims to maximiz
We prompt a powerful LLM to generate responses with the CoT process for some pre-defined questions. The collected data are viewed as experiences from an expert. We store them in a `jsonl` file `expert_data.jsonl` with the following format:

```json
{"question": "What is the average of 4, 6, and 8?","response": "I add the numbers together and divide by the count: 4 + 6 + 8 = 18, divided by 3 gives 6. The answer is 6."}
{
"messages": [
{ "role": "system", "content": <system_prompt> },
{ "role": "user", "content": "What is the sum of 4 and 12?" },
{ "role": "assistant", "content": "<think>thinking process...</think>\n<answer>16</answer>" } ]
},
...
```
The path to expert data is passed to `buffer.trainer_input.sft_warmup_dataset` for later use.
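
A minimal sketch of such a buffer entry is shown below; the exact nesting and available fields should be checked against the full example config, and the path is a placeholder:

```yaml
# Hypothetical sketch: point the SFT warmup input at the expert data file.
buffer:
  trainer_input:
    sft_warmup_dataset:
      storage_type: file
      path: /PATH/TO/expert_data.jsonl
```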


## Step 1: Define the Algorithm
@@ -296,3 +302,9 @@ algorithm:
read_batch_size_expert: 64
read_batch_size_usual: 192
```

With the above configurations, the experiment can be run with the following command:

```bash
trinity run --config examples/mix_math/mix_math.yaml
```
10 changes: 5 additions & 5 deletions docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -15,8 +15,8 @@ To run the ALFworld and WebShop env, you need to setup the corresponding environ
- WebShop is a simulated online shopping environment where AI agents learn to shop based on user requirements. The platform allows agents to browse products, compare options, and make purchase decisions, mimicking real-world e-commerce interactions.

You may refer to their original repositories to complete the setup.
- For ALFworld, refer to: https://github.com/alfworld/alfworld
- For WebShop, refer to: https://github.com/princeton-nlp/WebShop
- For ALFWorld, refer to the [ALFWorld](https://github.com/alfworld/alfworld) repository.
- For WebShop, refer to the [WebShop](https://github.com/princeton-nlp/WebShop) repository.

### Data Preparation
Our dataset follows the format of the Hugging Face `datasets` library, so we need to convert our environment dataset accordingly.
@@ -36,7 +36,7 @@ The task is described as an environment instead of a single prompt.

## Step 2: Config preparation and run the experiment

You can refer to `example_reasoning_basic` to setup the config and others. The default config files are [`alfworld.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_alfworld/alfworld.yaml) and [`webshop.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_webshop/webshop.yaml), respectively.
You can refer to [Quick Start](./example_reasoning_basic.md) to setup the config and others. The default config files are [`alfworld.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_alfworld/alfworld.yaml) and [`webshop.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_webshop/webshop.yaml), respectively.
You may revise the configurations as needed and run the experiment!

```bash
@@ -104,7 +104,7 @@ class AlfworldWorkflow(MultiTurnWorkflow):
...
```

and include them in the init files in `trinity/common/workflows/__init__.py`
and include it in the init file `trinity/common/workflows/__init__.py`

```diff
# -*- coding: utf-8 -*-
@@ -120,7 +120,7 @@ and include them in the init files in `trinity/common/workflows/__init__.py`
]
```

Then you are all set! It should be pretty simple😄, and both environments converge.
Then you are all set! It should be pretty simple😄, and the training processes in both environments converge.

![](../../assets/alfworld_reward_curve.png)
![](../../assets/webshop_reward_curve.png)
8 changes: 7 additions & 1 deletion docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -37,6 +37,12 @@ pip install flash-attn -v
# pip install flash-attn -v --no-build-isolation
```

Installation using pip:

```shell
pip install trinity-rft
```

Installation from docker:

We provide a Dockerfile for Trinity-RFT.
@@ -60,7 +66,7 @@ docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path

**Model Preparation.**

Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
Download the Qwen2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:

```bash
# Using Modelscope
37 changes: 19 additions & 18 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -8,7 +8,7 @@ The configuration for **Trinity-RFT** is defined in a `YAML` file and organized

```yaml
project: Trinity-RFT
name: tutorial
name: example
mode: both
checkpoint_root_dir: /PATH/TO/CHECKPOINT

@@ -78,7 +78,7 @@ Specifies the algorithm type and its related hyperparameters.
```yaml
algorithm:
algorithm_type: grpo
repeat_times: 1
repeat_times: 8

# The following parameters are optional
# If not specified, they will automatically be set based on the `algorithm_type`
@@ -89,12 +89,11 @@ algorithm:
entropy_loss_fn: "default"
```

- `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`.
- `repeat_times`: Number of times each task is repeated. Default is `1`. In `dpo`, this is automatically set to `2`.

- `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`, `sft`, `mix`.
- `repeat_times`: Number of times each task is repeated. Default is `1`. In `dpo`, this is automatically set to `2`. Some algorithms such as GRPO and OPMD require `repeat_times` > 1.
- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer.
- `advantage_fn`: The advantage function used for computing advantages.
- `kl_penalty_fn`: The KL penalty function used for computing KL penalty.
- `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
- `kl_loss_fn`: The KL loss function used for computing KL loss.
- `entropy_loss_fn`: The entropy loss function used for computing entropy loss.

@@ -111,8 +110,8 @@ monitor:
```

- `monitor_type`: Type of monitoring system. Options:
- `wandb`: Logs to Weights & Biases. Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
- `tensorboard`: Logs to TensorBoard. Files are saved under `<checkpoint_root_dir>/<project>/<name>/monitor/tensorboard`.
- `wandb`: Logs to [Weights & Biases](https://docs.wandb.ai/quickstart/). Requires logging in and setting `WANDB_API_KEY`. Project and run names match the `project` and `name` fields in global configs.
- `tensorboard`: Logs to [TensorBoard](https://www.tensorflow.org/tensorboard). Files are saved under `<checkpoint_root_dir>/<project>/<name>/monitor/tensorboard`.
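
For instance, when the `wandb` option above is selected, the API key is typically exported before launching the run:

```shell
# Assumes an existing Weights & Biases account and API key.
export WANDB_API_KEY=<your_api_key>
trinity run --config <config_path>
```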

---

@@ -122,13 +121,13 @@ Defines the model paths and token limits.

```yaml
model:
model_path: '/PATH/TO/MODEL/CHECKPOINT/'
model_path: /PATH/TO/MODEL/
critic_model_path: ''
max_prompt_tokens: 4096
max_response_tokens: 16384
```

- `model_path`: Path to the model checkpoint being trained.
- `model_path`: Path to the model being trained.
- `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`.
- `max_prompt_tokens`: Maximum number of tokens allowed in input prompts.
- `max_response_tokens`: Maximum number of tokens allowed in generated responses.
@@ -175,8 +174,8 @@ buffer:
default_reward_fn_type: 'countdown_reward'
```

- `batch_size`: Number of samples used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*.
- `total_epochs`: Total number of training epochs. Not applicable for streaming datasets (e.g., queue-based buffers).
- `batch_size`: Number of tasks used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*.
- `total_epochs`: Total number of training epochs.

### Explorer Input

@@ -227,6 +226,8 @@ The configuration for each task dataset is defined as follows:
- For `file` storage type, the path is the path to the directory that contains the task dataset files.
- For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a sqlite database path here.
- For `sql` storage type, the path is the path to the sqlite database file.
- `subset_name`: The subset name of the task dataset. Default is `None`.
- `split`: The split of the task dataset. Default is `train`.
- `format`: Defines keys for prompts and responses in the dataset.
- `prompt_key`: Specifies which column in the dataset contains the prompt data.
- `response_key`: Specifies which column in the dataset contains the response data.
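
Putting these fields together, a hypothetical file-based task dataset entry could look like the sketch below; the column names are illustrative and depend on your dataset, and the surrounding nesting depends on where the dataset is used:

```yaml
# Hypothetical sketch of a single task dataset entry.
storage_type: file
path: /PATH/TO/TASK/DATASET
subset_name: main   # optional, defaults to None
split: train        # optional, defaults to train
format:
  prompt_key: question
  response_key: answer
```
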
@@ -302,9 +303,9 @@ synchronizer:
```

- `sync_method`: Method of synchronization. Options:
- `nccl`: Uses NCCL for fast synchronization.
- `checkpoint`: Loads latest model from disk.
- `sync_interval`: Interval (in steps) between synchronizations.
- `nccl`: Uses NCCL for fast synchronization. Supported for `both` mode.
- `checkpoint`: Loads latest model from disk. Supported for `train`, `explore`, or `bench` mode.
- `sync_interval`: Interval (in steps) of model weight synchronization between trainer and explorer.
- `sync_timeout`: Timeout duration for synchronization.

---
@@ -324,7 +325,7 @@ trainer:
- `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
- `save_interval`: Frequency (in steps) at which to save model checkpoints.
- `trainer_config_path`: The path to the trainer configuration file.
- `train_config`: The configuration of the trainer. Only one needs to be set for `trainer.trainer_config` and `trainer.trainer_config_path`
- `trainer_config`: The trainer configuration provided inline. Only one of `trainer_config_path` and `trainer_config` should be specified.
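
As an illustration of this mutual exclusivity (paths and values below are placeholders):

```yaml
# Either reference an external trainer config file...
trainer:
  trainer_type: verl
  save_interval: 100
  trainer_config_path: /PATH/TO/TRAINER/CONFIG.yaml
# ...or embed the trainer configuration inline via `trainer_config`, but not both.
```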

---

@@ -334,7 +335,7 @@ Configures preprocessing and data cleaning pipelines.

```yaml
data_processor:
source_data_path: '/PATH/TO/DATASET'
source_data_path: /PATH/TO/DATASET
load_kwargs:
split: 'train'
format:
@@ -345,7 +346,7 @@ data_processor:
db_url: 'postgresql://{username}@localhost:5432/{db_name}'
```

- `source_data_path`: Path to the raw dataset.
- `source_data_path`: Path to the task dataset.
- `load_kwargs`: Arguments passed to HuggingFace’s `load_dataset()`.
- `dj_config_path`: Path to Data-Juicer configuration for cleaning.
- `clean_strategy`: Strategy for iterative data cleaning.