
Commit 8384b9c

Merge branch 'algorithm_dev' into refactor/data

# Conflicts:
#   docs/sphinx_doc/source/tutorial/example_data_functionalities.md
#   trinity/buffer/buffer.py
#   trinity/buffer/reader/file_reader.py
#   trinity/cli/launcher.py

2 parents: 2baed2e + 99a772a

File tree: 95 files changed (+4645 / -4243 lines)


README.md

Lines changed: 7 additions & 4 deletions
@@ -148,8 +148,11 @@ pip install -e .\[dev\]
 
 # Install flash-attn after all dependencies are installed
 # Note: flash-attn will take a long time to compile, please be patient.
-pip install flash-attn -v
-# Try the following command if you encounter errors during installation
+# for bash
+pip install -e .[flash_attn]
+# for zsh
+pip install -e .\[flash_attn\]
+# Try the following command if you encounter errors during flash-attn installation
 # pip install flash-attn -v --no-build-isolation
 ```
 
@@ -263,7 +266,7 @@ Then, for command-line users, run the RFT process with the following command:
 trinity run --config <config_path>
 ```
 
-> For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
+> For example, below is the command for fine-tuning Qwen2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
 > ```shell
 > trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 > ```
@@ -276,7 +279,7 @@ For more detailed examples about how to use Trinity-RFT, please refer to the fol
 + [Off-policy mode of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md)
 + [Asynchronous mode of RFT](./docs/sphinx_doc/source/tutorial/example_async_mode.md)
 + [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md)
-+ [Offline learning by DPO](./docs/sphinx_doc/source/tutorial/example_dpo.md)
++ [Offline learning by DPO or SFT](./docs/sphinx_doc/source/tutorial/example_dpo.md)
 + [Advanced data processing / human-in-the-loop](./docs/sphinx_doc/source/tutorial/example_data_functionalities.md)

docs/sphinx_doc/source/conf.py

Lines changed: 2 additions & 1 deletion
@@ -22,12 +22,13 @@
     "sphinx.ext.napoleon",
     "sphinx.ext.autosectionlabel",
     "myst_parser",
+    "sphinx.ext.mathjax",
 ]
 source_suffix = {
     ".rst": "restructuredtext",
     ".md": "markdown",
 }
-myst_enable_extensions = ["colon_fence"]
+myst_enable_extensions = ["colon_fence", "amsmath", "dollarmath"]
 
 # Prefix document path to section labels, otherwise autogenerated labels would
 # look like 'heading' rather than 'path/to/file:heading'
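For context on what the new extensions enable: with `sphinx.ext.mathjax` plus the `dollarmath` and `amsmath` MyST extensions, the tutorial markdown pages can embed TeX between dollar signs. A minimal illustration (not taken from the repo's docs):

```markdown
The objective $J(\theta)$ can now be written inline, or as a display block:

$$
J(\theta) = \mathbb{E}_{x \sim \mathcal{D}}\bigl[\, r_\theta(x) \,\bigr]
$$
```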

docs/sphinx_doc/source/index.rst

Lines changed: 11 additions & 2 deletions
@@ -14,16 +14,24 @@ Welcome to Trinity-RFT's documentation!
    :maxdepth: 1
    :glob:
    :hidden:
-   :caption: Tutorial
+   :caption: Examples
 
    tutorial/example_reasoning_basic.md
    tutorial/example_reasoning_advanced.md
    tutorial/example_async_mode.md
    tutorial/example_multi_turn.md
    tutorial/example_dpo.md
    tutorial/example_data_functionalities.md
-   tutorial/trinity_configs.md
+
+.. toctree::
+   :maxdepth: 2
+   :glob:
+   :hidden:
+   :caption: Guidelines
+
    tutorial/trinity_programming_guide.md
+   tutorial/trinity_configs.md
+   tutorial/example_mix_algo.md
 
 .. toctree::
    :maxdepth: 1
@@ -33,6 +41,7 @@ Welcome to Trinity-RFT's documentation!
    build_api/trinity.buffer
    build_api/trinity.explorer
    build_api/trinity.trainer
+   build_api/trinity.algorithm
    build_api/trinity.manager
    build_api/trinity.common
    build_api/trinity.utils

docs/sphinx_doc/source/main.md

Lines changed: 10 additions & 12 deletions
@@ -84,15 +84,18 @@ e.g., utilizing NCCL (when feasible) for model weight synchronization, sequence
 
 ## Getting started
 
-
-*Note: this project is currently under active development; comments and suggestions are welcome!*
-
+```{note}
+Note: This project is currently under active development; comments and suggestions are welcome!
+```
 
 
 
 ### Step 1: preparations
 
-
+Trinity-RFT requires
+Python version >= 3.10,
+CUDA version >= 12.4,
+and at least 2 GPUs.
 
 
 Installation from source (recommended):
@@ -146,11 +149,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
 docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
 ```
 
-Trinity-RFT requires
-Python version >= 3.10,
-CUDA version >= 12.4,
-and at least 2 GPUs.
-
 
 ### Step 2: prepare dataset and model
 
@@ -243,15 +241,15 @@ trinity run --config <config_path>
 
 
 
-For example, below is the command for fine-tuning Qwen-2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
+For example, below is the command for fine-tuning Qwen2.5-1.5B-Instruct on GSM8k dataset using GRPO algorithm:
 
 ```shell
 trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 ```
 
 
 
-More example config files can be found in `examples`.
+More example config files can be found in [`examples`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/).
 
 
 
@@ -260,7 +258,7 @@ For more detailed examples about how to use Trinity-RFT, please refer to the fol
 + [Off-policy mode of RFT](tutorial/example_reasoning_advanced.md)
 + [Asynchronous mode of RFT](tutorial/example_async_mode.md)
 + [Multi-turn tasks](tutorial/example_multi_turn.md)
-+ [Offline learning by DPO](tutorial/example_dpo.md)
++ [Offline learning by DPO or SFT](tutorial/example_dpo.md)
 + [Advanced data processing / human-in-the-loop](tutorial/example_data_functionalities.md)
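The requirements stated in this file (Python >= 3.10, CUDA >= 12.4, at least 2 GPUs) can be sanity-checked before installation with a few generic commands; this snippet is illustrative and not part of the commit:

```shell
python --version       # expect Python 3.10 or newer
nvcc --version         # expect CUDA release 12.4 or newer
nvidia-smi -L | wc -l  # expect a count of 2 or more GPUs
```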

docs/sphinx_doc/source/tutorial/example_async_mode.md

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Asynchronous RFT
 
-This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen-2.5-1.5B-Instruct model and GSM8K dataset.
+This example shows how to run RFT in a fully asynchronous mode with the GRPO algorithm, Qwen2.5-1.5B-Instruct model and GSM8K dataset.
 
 Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.

docs/sphinx_doc/source/tutorial/example_data_functionalities.md

Lines changed: 7 additions & 7 deletions
@@ -38,12 +38,12 @@ data_processor:
   # I/O buffers
   input_buffers:
     - name: 'raw_input'
-      path: 'openai/gsm8k'
+      path: /PATH/TO/GSM8K/
       storage_type: 'file'
       raw: true
   output_buffer:
     name: 'raw_output'
-    path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl'
+    path: /PATH/TO/OUTPUT/JSONL/FILE
     storage_type: 'file'
   # format mapping
   format:
@@ -72,12 +72,12 @@ data_processor:
   # I/O buffers
   input_buffers:
     - name: 'raw_input'
-      path: 'openai/gsm8k'
+      path: /PATH/TO/GSM8K/
      storage_type: 'file'
       raw: true
   output_buffer:
     name: 'raw_output'
-    path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl'
+    path: /PATH/TO/OUTPUT/JSONL/FILE
     storage_type: 'file'
   # format mapping
   format:
@@ -122,12 +122,12 @@ data_processor:
   # I/O buffers
   input_buffers:
     - name: 'raw_input'
-      path: 'openai/gsm8k'
+      path: /PATH/TO/GSM8K/
       storage_type: 'file'
       raw: true
   output_buffer:
     name: 'raw_output'
-    path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl'
+    path: /PATH/TO/OUTPUT/JSONL/FILE
     storage_type: 'file'
   # format mapping
   format:
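These hunks switch the input buffer from the `openai/gsm8k` Hub id to a local directory placeholder. One way to materialize such a local copy, assuming the `huggingface-cli` tool from `huggingface_hub` is available (the target path is a placeholder, not defined by this commit):

```shell
# download the GSM8K dataset repository into a local folder
pip install -U huggingface_hub
huggingface-cli download openai/gsm8k --repo-type dataset --local-dir /PATH/TO/GSM8K/
```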
@@ -217,7 +217,7 @@ data_processor:
 
 Here you can set the basic information for the example dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training, which is similar to the example above.
 
-For this example, we assume that you are somehow familiar with the basic usage of Data-Juicer, so we need to prepare a Data-Juicer data processing recipe in `tests/test_configs/human_annotator_test_dj_cfg.yaml` that includes an OP of `human_preference_annotation_mapper`. For example:
+For this example, we assume that you are somehow familiar with the basic usage of Data-Juicer, so we need to prepare a Data-Juicer data processing recipe in [`tests/test_configs/human_annotator_test_dj_cfg.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/tests/test_configs/human_annotator_test_dj_cfg.yaml) that includes an OP of `human_preference_annotation_mapper`. For example:
 
 ```yaml
 project_name: 'demo-human-annotator'

docs/sphinx_doc/source/tutorial/example_dpo.md

Lines changed: 54 additions & 13 deletions
@@ -1,12 +1,12 @@
-# Offline DPO
+# Offline DPO and SFT
 
-This example describes DPO based on the Qwen-2.5-1.5B-Instruct model and [Human-like-DPO-dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset).
+This example describes DPO and SFT based on the Qwen2.5-1.5B-Instruct model.
 
 ## Step 1: Model and Data Preparation
 
 ### Model Preparation
 
-Download the Qwen-2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
+Download the Qwen2.5-1.5B-Instruct model to the local directory `$MODEL_PATH/Qwen2.5-1.5B-Instruct`:
 
 ```shell
 # Using Modelscope
@@ -20,7 +20,7 @@ More details of model downloading are referred to [ModelScope](https://modelscop
 
 ### Data Preparation
 
-Download the Human-Like-DPO-Dataset dataset to the local directory `$DATASET_PATH/human_like_dpo_dataset`:
+For DPO, we download the [Human-like-DPO-dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset) to the local directory `$DATASET_PATH/human_like_dpo_dataset`:
 
 ```shell
 # Using Modelscope
@@ -34,9 +34,11 @@ More details of dataset downloading are referred to [ModelScope](https://modelsc
 
 Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pass the proper keys to the config.
 
-## Step 2: Setup Configuration and Run Experiment
+For SFT, we download the dataset to the local directory `/PATH/TO/SFT_DATASET/`, which usually contains message-based data.
 
-### Configuration
+## Step 2: Setup Configuration
+
+### Configuration for DPO
 
 We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following:
 
@@ -48,9 +50,12 @@ name: <experiment_name>
 mode: train
 algorithm:
   algorithm_type: dpo
+  kl_loss_fn: k1
+  kl_loss_fn_args:
+    kl_coef: 0.1 # value of beta in DPO
 checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
-  model_path: /PATH/TO/MODEL/
+  model_path: $MODEL_PATH/Qwen2.5-1.5B-Instruct
 cluster:
   node_num: 1
   gpu_per_node: 8
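For reference, the `kl_coef` introduced in this hunk plays the role of the temperature β in the standard DPO objective (Rafailov et al., 2023); the formula below is background, not part of the commit:

```latex
\mathcal{L}_{\mathrm{DPO}}(\theta) =
  -\,\mathbb{E}_{(x,\,y_w,\,y_l) \sim \mathcal{D}}
  \left[ \log \sigma\!\left(
      \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
    - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
  \right) \right]
```

A larger β keeps the learned policy closer to the reference model, which is why the value is exposed here through a KL-style coefficient.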
@@ -59,9 +64,9 @@ buffer:
   batch_size: 64
   trainer_input:
     experience_buffer:
-      name: dpo_buffer
+      name: human_like_dpo
       storage_type: file
-      path: /PATH/TO/DATASET/
+      path: $DATASET_PATH/human_like_dpo_dataset
       format:
         prompt_type: plaintext # plaintext/messages/chatpair
         prompt_key: prompt
@@ -70,14 +75,50 @@ buffer:
 trainer:
   trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
   save_interval: 30
-  actor_use_kl_loss: True
-  actor_kl_loss_coef: 0.1 # value of beta in DPO
 ```
 
-### Run the Experiment
+### Configuration for SFT
+
+We set the `algorithm_type` as `sft` to run SFT process. Then we modify the config file `sft.yaml` with the following changes:
+
+```yaml
+project: <project_name>
+name: <experiment_name>
+mode: train
+algorithm:
+  algorithm_type: sft
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/
+cluster:
+  node_num: 1
+  gpu_per_node: 2
+buffer:
+  total_epochs: 5
+  batch_size: 64
+  trainer_input:
+    experience_buffer:
+      name: <sft_dataset_name>
+      storage_type: file
+      path: /PATH/TO/SFT_DATASET/
+      split: train
+      format:
+        prompt_type: messages
+        messages_key: messages
+trainer:
+  trainer_config_path: /PATH/TO/TRAIN_CONFIG_YAML/
+  save_interval: 50
+```
+
+## Step 3: Run the Experiment
 
-Run RFT process with the following command:
+Run DPO process with the following command:
 
 ```shell
 trinity run --config examples/dpo_humanlike/dpo.yaml
 ```
+or, for SFT:
+
+```shell
+trinity run --config /PATH/TO/sft.yaml
+```
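As a rough sketch of the "message-based data" the SFT configuration above expects: with `prompt_type: messages` and `messages_key: messages`, each record presumably carries a chat-style message list. The exact schema is an assumption inferred from those keys, not something defined in this commit; a hypothetical one-line JSONL file could be created like this:

```shell
# write a single hypothetical SFT record in the messages format (field names assumed)
mkdir -p /PATH/TO/SFT_DATASET
cat > /PATH/TO/SFT_DATASET/train.jsonl << 'EOF'
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is 2 + 2?"}, {"role": "assistant", "content": "2 + 2 = 4."}]}
EOF
```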
