Commit 22879e4

Merge branch 'modelscope:main' into main
2 parents: a6d0b6f + 7964099

File tree

12 files changed

+264
-6
lines changed

12 files changed

+264
-6
lines changed
184 KB
Loading
184 KB
Loading
138 KB
Loading
485 KB
Loading

examples/grpo_vlm/vlm.yaml

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ buffer:
     taskset:
       name: geometry3k
       storage_type: file
-      path: hiyouga/geometry3k
+      path: ${oc.env:TRINITY_TASKSET_PATH,hiyouga/geometry3k}
       subset_name: 'default'
       split: 'train'
       format:
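
This change lets the taskset path be overridden through the `TRINITY_TASKSET_PATH` environment variable while keeping `hiyouga/geometry3k` as the fallback. A minimal sketch of how OmegaConf's `oc.env` resolver behaves for such a value (the standalone config object and the local path below are illustrative only, not part of the repo):

```python
import os
from omegaconf import OmegaConf

cfg = OmegaConf.create({"path": "${oc.env:TRINITY_TASKSET_PATH,hiyouga/geometry3k}"})

# With the variable unset, the default after the comma is used.
os.environ.pop("TRINITY_TASKSET_PATH", None)
print(cfg.path)   # -> hiyouga/geometry3k

# Exporting the variable overrides the default, e.g. to point at a local copy of the dataset.
os.environ["TRINITY_TASKSET_PATH"] = "/data/geometry3k"   # hypothetical local path
print(cfg.path)   # -> /data/geometry3k
```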

examples/mix_chord/mix_chord.yaml

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ buffer:
       name: SFT_data
       storage_type: file
       schema_type: sft
-      path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts}
+      path: ${oc.env:TRINITY_SFT_DATASET_PATH}
       split: 'train'
       format:
         prompt_type: messages

examples/mix_math/mix_math.yaml

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ buffer:
       name: math_sft
       storage_type: file
       schema_type: sft
-      path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts}
+      path: ${oc.env:TRINITY_SFT_DATASET_PATH}
       split: 'train'
       format:
         prompt_type: messages
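
Both `mix_chord.yaml` and `mix_math.yaml` now drop the `open-r1/Mixture-of-Thoughts` default, so `TRINITY_SFT_DATASET_PATH` must be set before these configs can resolve. A minimal sketch of the resulting behavior (standalone config object for illustration only):

```python
import os
from omegaconf import OmegaConf
from omegaconf.errors import InterpolationResolutionError

cfg = OmegaConf.create({"path": "${oc.env:TRINITY_SFT_DATASET_PATH}"})

# Without a default, resolution fails if the environment variable is not set.
os.environ.pop("TRINITY_SFT_DATASET_PATH", None)
try:
    print(cfg.path)
except InterpolationResolutionError as err:
    print(f"TRINITY_SFT_DATASET_PATH must be set: {err}")

# Point it at your own SFT dataset (local path or HF dataset id) before launching training.
os.environ["TRINITY_SFT_DATASET_PATH"] = "open-r1/Mixture-of-Thoughts"   # e.g. the former default
print(cfg.path)   # -> open-r1/Mixture-of-Thoughts
```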

examples/mix_vlm/README.md

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# MIX algorithm with VLM

This is an example of using the [MIX](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md) algorithm with the Qwen2.5-VL-3B-Instruct model.

> [!NOTE]
> This feature is experimental and subject to change in future releases.

The specific requirements are:

```yaml
vllm>=0.9.1,<0.10.0
transformers<4.53.0
qwen_vl_utils
```

## Prepare the SFT Dataset
We use the [geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k) dataset for training; we generate the [SFT dataset](https://huggingface.co/datasets/datajuicer/geometry_sft) by prompting the Qwen2.5-VL-32B-Instruct model on the validation set. Note that this dataset only showcases the format of SFT data in this example, as shown below:
```json
{
  "problem": "<image>Find $x$ so that $m || n$.",
  "response": "To determine the value of $ x $ ... Answer:\n\\[\n\\boxed{63}\n\\]",
  "images": [<image>]
}
```
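
For a quick look at the data, a minimal sketch of loading this showcase SFT dataset with the `datasets` library (field names follow the record shown above; the `train` split is the one referenced in `mix_vlm.yaml`):

```python
from datasets import load_dataset

# Loads the showcase SFT data used by the auxiliary sft_dataset buffer in mix_vlm.yaml.
sft = load_dataset("datajuicer/geometry_sft", split="train")

example = sft[0]
print(example["problem"])          # prompt text containing an <image> placeholder
print(example["response"][:200])   # response ending with a \boxed{...} answer
print(len(example["images"]))      # images attached to the problem
```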

The config file is located in [`mix_vlm.yaml`](mix_vlm.yaml). To get better performance, feel free to try out different algorithm hyperparameters!

## Run the Example

Run the following command to start the training:
```bash
trinity run --config examples/mix_vlm/mix_vlm.yaml
```

The reward curve is shown below:

![](../../docs/sphinx_doc/assets/mix_vlm_reward.png)

examples/mix_vlm/mix_vlm.yaml

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
project: "Trinity-RFT"
name: "mix_vlm"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
  algorithm_type: mix_chord
  repeat_times: 8
  optimizer:
    lr: 1e-6
  kl_loss_fn_args:
    kl_coef: 0.0
  entropy_loss_fn: mix
  sample_strategy_args:
    expert_data_ratio: 0.20
  policy_loss_fn_args:
    mu_warmup_steps: 200
    mu_decay_steps: 400
    mu_peak: 0.1
    mu_valley: 0.1
    enable_phi_function: false
    clip_range: 0.2
    sft_loss_agg_mode: "token-mean"
    use_dynamic_bsz: true
    ppo_mini_batch_size: 320  # 320 = 256 + 64
    ppo_micro_batch_size_per_gpu: 4
    ngpus_trainer: 4
    train_batch_size_expert: 64
    train_batch_size_usual: 256  # 32 batchsize * 8 repeat times
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
  max_response_tokens: 10240
  max_model_len: 11264
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 4
  batch_size: 32
  train_batch_size: 320
  explorer_input:
    taskset:
      name: geometry3k
      storage_type: file
      path: ${oc.env:TRINITY_TASKSET_PATH,hiyouga/geometry3k}
      subset_name: 'default'
      split: 'train'
      format:
        prompt_key: 'problem'
        response_key: 'answer'
        image_key: 'images'
      rollout_args:
        temperature: 1.0
        logprobs: 0
      workflow_args:
        with_think: true
    eval_tasksets: []  # you can add your own eval tasksets here
    default_workflow_type: 'simple_mm_workflow'
    default_reward_fn_type: 'math_boxed_reward'
  trainer_input:
    experience_buffer:
      name: experience_buffer
      storage_type: queue
    auxiliary_buffers:
      sft_dataset:
        total_epochs: 25
        name: geometry_sft
        storage_type: file
        schema_type: sft
        path: datajuicer/geometry_sft
        split: 'train'
        format:
          prompt_type: plaintext
          prompt_key: 'problem'
          response_key: 'response'
          image_key: 'images'
explorer:
  eval_interval: 10
  runner_per_model: 8
  rollout_model:
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
synchronizer:
  sync_method: 'nccl'
  sync_interval: 1
  sync_timeout: 1200
trainer:
  save_interval: 50
  grad_clip: 1.0
  use_dynamic_bsz: true
  max_token_len_per_gpu: 11264
  ulysses_sequence_parallel_size: 2
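
The batch-size comments in this config encode two relations that should stay consistent if you change `buffer.batch_size`, `algorithm.repeat_times`, or `train_batch_size_expert`. A quick sanity check, mirroring those comments (pure arithmetic, not code from the repo):

```python
# Values taken from the config above.
batch_size = 32                  # buffer.batch_size
repeat_times = 8                 # algorithm.repeat_times
train_batch_size_expert = 64

train_batch_size_usual = batch_size * repeat_times                        # 32 * 8 = 256
ppo_mini_batch_size = train_batch_size_usual + train_batch_size_expert    # 256 + 64 = 320

assert train_batch_size_usual == 256
assert ppo_mini_batch_size == 320    # also matches buffer.train_batch_size
```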
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
# Example: PPO on Countdown dataset with experience replay

In this example, we follow the main settings in [`ppo_countdown`](../ppo_countdown/README.md)
and demonstrate the **experience replay** mechanism in Trinity-RFT.


### Motivations

One motivation for experience replay is that it is often desirable to improve learning efficiency by reusing rollout samples for multiple training steps, especially in scenarios where rollout (with agent-environment interaction) is slow or expensive.
Moreover, experience replay offers a straightforward way to fill pipeline bubbles in the trainer (caused by discrepancies between the explorer's and trainer's speeds) with useful computation, improving hardware utilization for the disaggregated architecture adopted by Trinity (and many other RL systems).

### Implementation and configuration

The priority queue buffer in Trinity offers seamless support for experience replay.
Whenever a batch of highest-priority samples is retrieved from the buffer,
a **priority function** updates their priority scores and decides which ones should be put back into the buffer (after `reuse_cooldown_time` seconds have passed) for replay.
Users of Trinity can implement and register their own customized priority functions,
which can then be selected by setting the `priority_fn` field in the yaml config.

We present an example config file in [`countdown.yaml`](./countdown.yaml),
where 1 GPU is allocated to the explorer and 6 GPUs to the trainer,
simulating a scenario where agent-environment interaction is slow and rollout data is scarce.
Important config parameters for experience replay include (see the sketch after this list):
* `buffer.trainer_input.experience_buffer.storage_type`: set to `queue`
* `buffer.trainer_input.experience_buffer.replay_buffer`
  * `enable`: set to `true` to enable the priority queue buffer
  * `reuse_cooldown_time`: delay (in seconds) before a sample is put back into the buffer; must be set explicitly
  * `priority_fn`: name of the priority function
  * `priority_fn_args`: additional args for the priority function
* `synchronizer.sync_style`: set to `dynamic_by_explorer`, which allows the trainer to run more training steps as long as the priority queue buffer is non-empty
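
A minimal sketch of how this part of the config could be assembled programmatically; the field names follow the list above, while the concrete values (cooldown time and priority-function args) are illustrative and not taken from `countdown.yaml`:

```python
from omegaconf import OmegaConf

replay_cfg = OmegaConf.create({
    "buffer": {
        "trainer_input": {
            "experience_buffer": {
                "storage_type": "queue",
                "replay_buffer": {
                    "enable": True,
                    "reuse_cooldown_time": 60,   # seconds; illustrative, but must be set explicitly
                    "priority_fn": "decay_limit_randomization",
                    "priority_fn_args": {"decay": 0.1, "sigma": 0.0, "use_count_limit": 3},  # illustrative
                },
            },
        },
    },
    "synchronizer": {"sync_style": "dynamic_by_explorer"},
})
print(OmegaConf.to_yaml(replay_cfg))
```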

The priority function used in this example is named `decay_limit_randomization`. Its logic (see the sketch below):
* The priority score is calculated as `model_version - decay * use_count`, i.e., fresher and less-used samples are prioritized;
* If `sigma` is non-zero, the priority score is further perturbed by random Gaussian noise with standard deviation `sigma`;
* A retrieved sample is put back into the buffer if and only if its use count has not exceeded `use_count_limit`.
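
Based purely on the three bullet points above, a minimal sketch of such a priority function; the actual `decay_limit_randomization` implementation in Trinity, its signature, default argument values, and the experience container are all likely to differ:

```python
import random
from dataclasses import dataclass, field


@dataclass
class ReplaySample:                 # hypothetical container, for illustration only
    payload: dict = field(default_factory=dict)
    use_count: int = 0
    priority: float = 0.0


def decay_limit_randomization_sketch(batch, model_version,
                                     decay=0.1, sigma=0.0, use_count_limit=3):
    """Update priorities of a retrieved batch and return the samples to put back for replay."""
    to_replay = []
    for sample in batch:
        sample.use_count += 1
        # Fresher (higher model_version) and less-used samples get higher priority.
        sample.priority = model_version - decay * sample.use_count
        if sigma > 0:
            # Optional Gaussian perturbation with standard deviation sigma.
            sample.priority += random.gauss(0.0, sigma)
        # Put a sample back only if its use count has not exceeded the limit.
        if sample.use_count <= use_count_limit:
            to_replay.append(sample)
    return to_replay
```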


### Experimental results

We run an experiment with this config and compare it with a baseline config that uses each rollout sample exactly once for training.
The first and second figures below (using rollout step and wall-clock time as the X-axis, respectively) confirm the benefits brought by experience replay (with default hyperparameters).
This is partly because more training steps can be taken, as shown in the third figure (where the X-axis represents rollout step).


<img src="../../docs/sphinx_doc/assets/example_experience_replay/exp_replay_X_explore_step.png" alt="score-vs-explore-step" width="600" />

<img src="../../docs/sphinx_doc/assets/example_experience_replay/exp_replay_X_time.png" alt="score-vs-wall-clock-time" width="600" />

<img src="../../docs/sphinx_doc/assets/example_experience_replay/exp_replay_model_version.png" alt="model-version" width="600" />
