Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b52809f
* prepare the initial config files for exp pipeline
HYLcool Jun 24, 2025
1430035
+ add basic reward shaping func
HYLcool Jun 24, 2025
061407c
Merge branch 'main' into feat/exp_pipeline
HYLcool Jun 25, 2025
d8e9331
Merge branch 'main' into feat/exp_pipeline
HYLcool Jun 25, 2025
78da769
- remove common.schema
HYLcool Jun 26, 2025
04f64aa
* allow async exp pipeline
HYLcool Jun 26, 2025
fe0407f
Merge branch 'main' into feat/exp_pipeline
HYLcool Jun 26, 2025
56dd112
+ add more logs
HYLcool Jun 26, 2025
510b2af
+ add buffer check and sync for experience pipeline
HYLcool Jun 26, 2025
f1f6ba0
* set several default values for format config
HYLcool Jun 26, 2025
f78b6e7
* convert experience to dict before converting to dataset
HYLcool Jun 26, 2025
979ab5a
* fix conversion bugs in dataset
HYLcool Jun 26, 2025
e359179
* fix bugs
HYLcool Jun 26, 2025
d9d4773
* update configs of exp_pipeline
HYLcool Jun 27, 2025
d9501cf
+ init ray in the same namespace for data processor
HYLcool Jun 27, 2025
d16f0a8
* update example docs for experience pipeline
HYLcool Jun 27, 2025
d5e46f3
* after pre-commit
HYLcool Jun 27, 2025
a1cdc7f
Merge branch 'main' into feat/exp_pipeline
HYLcool Jun 27, 2025
a1b3b01
Merge branch 'main' into feat/exp_pipeline
HYLcool Jun 30, 2025
55800f6
* fix dataset buffer logics and tests
HYLcool Jun 30, 2025
c10dd93
* update ray init method
HYLcool Jun 30, 2025
17c91aa
* ignore dj configs when checking example validation
HYLcool Jun 30, 2025
062722f
* move data processor related funcs to data/utils.py
HYLcool Jun 30, 2025
974f3ab
* after pre-commit
HYLcool Jun 30, 2025
60abb01
+ add missing docs
HYLcool Jun 30, 2025
266ba19
+ fix typo and add infos about how to set api keys.
HYLcool Jun 30, 2025
471a93d
* after pre-commit
HYLcool Jun 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 96 additions & 4 deletions docs/sphinx_doc/source/tutorial/example_data_functionalities.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ In this example, you will learn how to apply the data processor of Trinity-RFT t
2. how to configure the data processor
3. what the data processor can do

Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md).
Before getting started, you need to prepare the main environment of Trinity-RFT according to the [installation section of the README file](../main.md),
and store the base url and api key in the environment variables `OPENAI_BASE_URL` and `OPENAI_API_KEY` for some agentic or API-model usages if necessary.

### Data Preparation

Expand Down Expand Up @@ -103,8 +104,6 @@ If you are familiar with Data-Juicer, you will realize that Data-Juicer provides
# This is a Data-Juicer data processing recipe
project_name: 'gsm-8k-difficulty'

export_path: '/path/to/the/result/processed-dataset.jsonl'

process:
- llm_difficulty_score_filter:
api_or_hf_model: "qwen2.5-72b-instruct" # use "qwen2.5-72b-instruct" to calculate the difficulty scores.
Expand Down Expand Up @@ -143,7 +142,7 @@ And you can set the `clean_strategy` to 'iterative' to get a better dataset.



All config items in the `data` section can be found [here](trinity_configs.md). A prepared config file for this example of GSM-8K can be found in [the config file of gsm8k](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml).
All config items in the `data` section can be found [here](trinity_configs.md). A prepared config file for this example of GSM-8K can be found in [the config file of gsm8k](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml).



Expand All @@ -167,6 +166,99 @@ trinity run --config <Trinity-RFT_config_path>

If you follow the steps above, Trinity-RFT will send a request to the data processor server, the data active iterator will be activated, compute difficulty scores for each sample in the raw dataset, and rank the dataset according to difficulty scores. After that, the data processor server stores the result dataset into the output buffer, when exploring begins, it will load the prepared dataset and continue the downstream steps.

## Example: Data Processor for Experience Pipeline

In this example, you will learn how to apply the data processor of Trinity-RFT to reshape rewards of experiences after exploring. This example takes GSM-8K dataset as the example dataset to figure out how to reshape rewards of experiences from the explorer before sent to the trainer from a view of the quality of generated responses.

Before getting started, you need to prepare the main environment of Trinity-RFT and start server for the data processor according to the first subsection in the previous example.

### Configure the Data Processor

In this example, assume that you need to add an extra reward item to the experiences outputted by the explorer, which access the quality scores of the experiences. So you can set the `experience_pipeline` config like the following example:

```yaml
data_processor:
data_processor_url: 'http://127.0.0.1:5005/data_processor'
# experience pipeline related
experience_pipeline:
# I/O buffers
input_buffers:
- name: gsm8k_exp_output
output_buffer:
name: reshaped_gsm8k_exp_input
# format mapping
format:
reward_key: 'reward' # the key name of the reward in the experience
# data active iterator related
dj_config_path: 'examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml'
clean_strategy: 'iterative'
# reward shaping
reward_shaping:
- stats_key: 'llm_quality_score'
op_type: ADD
weight: 1.0

# the buffer config
buffer:
...
explorer_output:
name: gsm8k_exp_output
storage_type: queue
path: 'sqlite:///gsm8k_exp_output.db'
trainer_input:
experience_buffer:
name: reshaped_gsm8k_exp_input
storage_type: queue
path: 'sqlite:///reshaped_gsm8k_exp_input.db'
```

Here you can set the input/output buffers for the experience pipeline, and some other items about reward shaping:

+ `data_processor_url`: the URL of the data processor service, which is started in the previous step.
+ `experience_pipeline`: the configs for the experience pipeline. Experience pipeline is used to process the experiences outputted by the explorer, such as reward shaping, data filtering and augmentation. It consists of several inner configs:
+ `input_buffers`: the input buffers for the experience pipeline. It usually loads from the explorer output buffer, so we need to specify the `explorer_output` in the `buffer` config, and here we only need to specify the name that is aligned with the `explorer_output`. It allows multiple input buffers, but for now, we only need to specify one.
+ `output_buffer`: the output buffer for the experience pipeline. It usually writes results to the input buffer of trainer, so we only need to the specify the buffer name that is aligned with the `trainer_input` in the `buffer` config.
+ `format`: some dataset format config items, which are used to map original data field names to unified ones. Here we only need to specify the field name to store the original reward information.
+ `reward_shaping`: the method to reshape the reward. Usually we use some stats computed by operators in Data-Juicer as new reward items. It's a list that allows multiple methods to reshape rewards. Each item in the list has the following config items:
+ `stats_key`: which stats to use as the new reward item.
+ `op_type`: the operator to apply the new reward item to the original reward. For now, ["ADD", "SUB", "MUL", "DIV"] are supported.
+ `weight`: the weight of the new reward item.

In addition, there are several config items related to the data active iterator in `experience_pipeline` part, which is used to compute stats used to reshape rewards. This part is similar to the `task_pipeline` part in the previous example. The Data-Juicer config used here is:
```yaml
# This is a Data-Juicer data processing recipe
project_name: 'gsm-8k-experience-quality'

np: 32

process:
- llm_quality_score_filter:
api_or_hf_model: "qwen2.5-32b-instruct" # use "qwen2.5-32b-instruct" to calculate the quality scores.
min_score: 0.0
input_keys: ["prompt_text", "prompt_text"] # set input_keys and field_names to the existing key names in gsm-8k. Here calculating the difficulty scores according to both questions and answers.
field_names: ["prompt", "response"]
```

All config items in the `data` section can be found [here](trinity_configs.md). A prepared config file for this example of GSM-8K can be found in [the config file of gsm8k](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml).

### Exploring & Training
After preparing the config files of Trinity-RFT, you can start your ray cluster and run the RFT process including the data active iterator part with the following commands:

```shell
# start the ray cluster
# on master node
ray start --head
# on worker nodes
ray start --address=<master_address>

# run RFT
trinity run --config <Trinity-RFT_config_path>
```

If you follow the steps above, Trinity-RFT will send a request to the data processor server and prepare the experience pipeline.
It will watch the explorer output buffer. Once there is a new batch of experience, the data processor will compute stats for the experience and reshape the rewards. Then it writes the reshaped experience to the trainer input buffer for training.


## Example: Human in the Loop
Sometimes, you might need to involve human feedbacks for some raw data. In this example, you will learn how to annotate raw data to get a better dataset before training. This example takes an example Q&A dataset and tries to select the chosen and rejected ones for DPO method.

Expand Down
7 changes: 7 additions & 0 deletions examples/grpo_gsm8k_experience_pipeline/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# GRPO on GSM8K dataset with Experience Pipeline

This example shows the usage of GRPO on the GSM8K dataset, with a experience pipeline to reshape the rewards of experiences while training.

For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_data_functionalities.md).

The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).
11 changes: 11 additions & 0 deletions examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# This is a Data-Juicer data processing recipe
project_name: 'gsm-8k-experience-quality'

np: 32

process:
- llm_quality_score_filter:
api_or_hf_model: "qwen2.5-32b-instruct" # use "qwen2.5-32b-instruct" to calculate the quality scores.
min_score: 0.0
input_keys: ["prompt_text", "prompt_text"] # set input_keys and field_names to the existing key names in gsm-8k. Here calculating the difficulty scores according to both questions and answers.
field_names: ["prompt", "response"]
89 changes: 89 additions & 0 deletions examples/grpo_gsm8k_experience_pipeline/gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
project: "Trinity-RFT-gsm8k-experience-pipeline"
name: "qwen2.5-1.5B-gsm8k-experience-pipeline"
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: grpo
repeat_times: 8
data_processor:
data_processor_url: 'http://127.0.0.1:5005/data_processor'
# experience pipeline related
experience_pipeline:
# I/O buffers
input_buffers:
- name: gsm8k_exp_output
output_buffer:
name: reshaped_gsm8k_exp_input
# format mapping
format:
reward_key: 'reward' # the key name of the reward in the experience
# data active iterator related
dj_config_path: 'examples/grpo_gsm8k_experience_pipeline/dj_scoring_exp.yaml'
clean_strategy: 'iterative'
# reward shaping
reward_shaping:
- stats_key: 'llm_quality_score'
op_type: ADD
weight: 1.0

model:
model_path: /PATH/TO/MODEL/
max_prompt_tokens: 256
max_response_tokens: 1024
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 96
max_retry_times: 3
max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'train'
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'test'
format:
prompt_key: 'question'
response_key: 'answer'
default_workflow_type: 'math_workflow'
explorer_output:
name: gsm8k_exp_output
storage_type: queue
path: 'sqlite:///gsm8k_exp_output.db'
trainer_input:
experience_buffer:
name: reshaped_gsm8k_exp_input
storage_type: queue
path: 'sqlite:///reshaped_gsm8k_exp_input.db'
explorer:
eval_interval: 50
runner_num: 32
rollout_model:
engine_type: vllm_async
engine_num: 2
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 1
sync_timeout: 1200
trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/grpo_gsm8k_experience_pipeline/train_gsm8k.yaml'
save_interval: 100
50 changes: 50 additions & 0 deletions examples/grpo_gsm8k_experience_pipeline/train_gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
actor_rollout_ref:
hybrid_engine: True
model:
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: True # False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 128
ppo_micro_batch_size_per_gpu: 4
use_dynamic_bsz: True # False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size_per_gpu: 16
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size

trainer:
balance_batch: True
# total_training_steps: null
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
val_before_train: False
4 changes: 3 additions & 1 deletion tests/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def test_all_examples_are_valid(self):
for example_name in os.listdir(example_dir):
for filename in os.listdir(os.path.join(example_dir, example_name)):
if filename.endswith(".yaml") and not (
filename.startswith("train_") or filename.startswith("verl_")
filename.startswith("train_")
or filename.startswith("verl_")
or filename.startswith("dj_")
):
print(f"Checking config: {filename}")
config_path = os.path.join(example_dir, example_name, filename)
Expand Down
2 changes: 1 addition & 1 deletion tests/common/experience_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import torch

from trinity.buffer.schema.sql_schema import ExperienceModel
from trinity.common.experience import Experience, Experiences
from trinity.common.schema import ExperienceModel

db_url = os.path.join(os.path.dirname(__file__), "tmp", "test.db")
dataset_path = os.path.join(os.path.dirname(__file__), "data")
Expand Down
2 changes: 1 addition & 1 deletion tests/data/core/formatter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_init(self):
self.assertEqual(formatter.config.solution_key, "solution")
self.assertEqual(formatter.config.chat_template, "User: {}\nAssistant: ")
# test for default configs
self.assertEqual(formatter.config.reward_key, "")
self.assertEqual(formatter.config.reward_key, "reward")
self.assertEqual(formatter.config.chosen_key, "chosen")
self.assertEqual(formatter.config.rejected_key, "rejected")
self.assertEqual(formatter.config.label_key, "")
Expand Down
Loading