Binary file added docs/sphinx_doc/assets/async-curve.png
39 changes: 39 additions & 0 deletions docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -0,0 +1,39 @@
# A quick example for asynchronous mode

This example shows how to run RFT in asynchronous mode with the Qwen2.5-1.5B-Instruct model and the GSM8K dataset.

Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.

For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
The main difference between them is that in `trainer.yaml` we set `mode=train`, while in `explorer.yaml` we set `mode=explore`.
In addition, the following parameters need to be configured consistently in both files, so that the trainer and explorer share the same model checkpoints and the same experience buffer.
The synchronization frequency is defined as every `sync_iteration_interval * batch_size` tasks.
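For instance, with `batch_size: 96` and `sync_iteration_interval: 10` (the values used in the example configs under `examples/async_gsm8k/`), the model weights are synchronized once every 960 tasks.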

```yaml
data:
  batch_size: <batch_size>
# The same checkpoint path
model:
  checkpoint_path: /PATH/TO/CHECKPOINT

# The same database path for the experience buffer
buffer:
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'

synchronizer:
  sync_method: 'offline'
  sync_iteration_interval: <sync_iteration_interval>
```

You can run this example with the following command:

```bash
bash examples/async_gsm8k/run.sh
```
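
The script launches the explorer and the trainer as two separate processes (the comments below are added for clarity):

```bash
#!/bin/bash
# launch the explorer in the background and log its output
trinity run --config examples/async_gsm8k/explorer.yaml 2>&1 | tee explorer.log &
# give the explorer a head start before launching the trainer
sleep 30
# launch the trainer in the background and log its output
trinity run --config examples/async_gsm8k/trainer.yaml 2>&1 | tee trainer.log &
```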

The results of running in asynchronous mode are shown below.

![async](../../assets/async-curve.png)
16 changes: 1 addition & 15 deletions docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
@@ -1,4 +1,4 @@
# Example: off-policy / asynchronous RFT mode
# Example: off-policy RFT mode


Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) and show some advanced features provided by Trinity-RFT, namely, off-policy or asynchronous RFT mode.
@@ -35,17 +35,3 @@ A similar performance boost is shown at step 21, which leads to a converged score


![opmd](../../assets/opmd-curve.png)





## Asynchronous mode


Trinity-RFT supports the asynchronous and decoupled mode of RFT, where explorer and trainer act independently and asynchronously.
To run this mode, the explorer and trainer need to be launched separately, with the `mode` parameter in the config file set to `explore` and `train` respectively.



*We are still testing this mode more thoroughly. A concrete example is coming soon!*
13 changes: 13 additions & 0 deletions examples/async_gsm8k/README.md
@@ -0,0 +1,13 @@
# Asynchronous mode on GSM8K dataset

This example shows how to run GRPO on the GSM8K dataset in asynchronous mode.

For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_async_mode.md).

The config files for this example are [`trainer.yaml`](trainer.yaml), [`explorer.yaml`](explorer.yaml), and [`verl_config.yaml`](verl_config.yaml).

You can run this example with the following command:

```bash
bash examples/async_gsm8k/run.sh
```
60 changes: 60 additions & 0 deletions examples/async_gsm8k/explorer.yaml
@@ -0,0 +1,60 @@
mode: explore
data:
  # basic info
  dataset_path: /PATH/TO/DATASET/
  subset_name: ''
  train_split: 'train'
  eval_split: 'test'
  format_config:
    prompt_key: 'question'
    response_key: 'answer'
  # downstream loading related
  total_epochs: 20
  batch_size: 96
  default_workflow_type: 'math_workflow'
model:
  model_path: /PATH/TO/MODEL/
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k'
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  max_retry_times: 3
  max_retry_interval: 1
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'
explorer:
  engine_type: vllm_async
  engine_num: 2
  runner_num: 32
  tensor_parallel_size: 1
  enable_prefix_caching: false
  enforce_eager: true
  dtype: bfloat16
  temperature: 1.0
  top_p: 1.0
  top_k: -1
  seed: 42
  logprobs: 0
  repeat_times: 8
  use_ray: false
  backend: 'nccl'
  max_pending_requests: 32
  max_waiting_steps: 4
synchronizer:
  sync_method: 'offline'
  sync_iteration_interval: 10
trainer:
  trainer_type: 'verl'
  algorithm_type: ppo
  trainer_config_path: examples/async_gsm8k/verl_config.yaml
  sft_warmup_iteration: 0 # set to a positive integer to enable SFT warmup
  eval_interval: 10
monitor:
  cache_root_dir: ""
  project: "Trinity-RFT-gsm8k"
  name: "async-qwen2.5-1.5B-gsm8k"
4 changes: 4 additions & 0 deletions examples/async_gsm8k/run.sh
@@ -0,0 +1,4 @@
#!/bin/bash
trinity run --config examples/async_gsm8k/explorer.yaml 2>&1 | tee explorer.log &
sleep 30
trinity run --config examples/async_gsm8k/trainer.yaml 2>&1 | tee trainer.log &
60 changes: 60 additions & 0 deletions examples/async_gsm8k/trainer.yaml
@@ -0,0 +1,60 @@
mode: train
data:
  # basic info
  dataset_path: /PATH/TO/DATASET/
  subset_name: ''
  train_split: 'train'
  eval_split: 'test'
  format_config:
    prompt_key: 'question'
    response_key: 'answer'
  # downstream loading related
  total_epochs: 20
  batch_size: 96
  default_workflow_type: 'math_workflow'
model:
  model_path: /PATH/TO/MODEL/
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: ""
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  max_retry_times: 3
  max_retry_interval: 1
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'
explorer:
  engine_type: vllm_async
  engine_num: 2
  runner_num: 32
  tensor_parallel_size: 1
  enable_prefix_caching: false
  enforce_eager: true
  dtype: bfloat16
  temperature: 1.0
  top_p: 1.0
  top_k: -1
  seed: 42
  logprobs: 0
  repeat_times: 8
  use_ray: false
  backend: 'nccl'
  max_pending_requests: 32
  max_waiting_steps: 4
synchronizer:
  sync_method: 'offline'
  sync_iteration_interval: 10
trainer:
  trainer_type: 'verl'
  algorithm_type: ppo
  trainer_config_path: examples/async_gsm8k/verl_config.yaml
  sft_warmup_iteration: 0 # set to a positive integer to enable SFT warmup
  eval_interval: 10
monitor:
  cache_root_dir: ""
  project: "Trinity-RFT-gsm8k"
  name: "async-qwen2.5-1.5B-gsm8k"
186 changes: 186 additions & 0 deletions examples/async_gsm8k/verl_config.yaml
@@ -0,0 +1,186 @@
data:
  tokenizer: null
  train_files: placeholder
  val_files: placeholder
  prompt_key: prompt
  max_prompt_length: 256
  max_response_length: 1024
  train_batch_size: 256
  val_batch_size: null
  return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False
  shuffle: True
  filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts can be time-consuming; you should disable this and set `truncation='left'`
  truncation: error
  image_key: images

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: /PATH/TO/MODEL/
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: True # False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_mini_batch_size: 128
    # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: 4
    use_dynamic_bsz: True # False
    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    clip_ratio: 0.2
    entropy_coeff: 0.001
    use_kl_loss: True # True for GRPO
    kl_loss_coef: 0.001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-5
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      # min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
    # --- below: opmd ---
    alg_type: ppo # ppo / opmd / pairwise_opmd
    tau: 0.000 # strength of regularization w.r.t. old / ref policy
    opmd_baseline: mean # mean / logavgexp, applicable to opmd
    use_uid: False # True / False, applicable to pairwise_opmd
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 16
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    use_fire_sampling: False # https://arxiv.org/abs/2410.21236
    prompt_length: ${data.max_prompt_length} # not used for opensource
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.4
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_model_len: null
    max_num_seqs: 1024
    # log_prob_micro_batch_size: 8 # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 4
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: True
    enable_chunked_prefill: True # could get higher throughput
    # for hf rollout
    do_sample: True
    # number of responses (i.e. num sample times)
    n: 8 # > 1 for grpo

critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    # min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program
  model:
    path: /PATH/TO/MODEL/
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: True
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: 64
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5

reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
      fsdp_size: -1
  # micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
  # micro_batch_size_per_gpu: 2 # set a number
  # max_length: null
  ulysses_sequence_parallel_size: 1 # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}

custom_reward_function:
  path: null
  name: compute_score

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: grpo
  kl_penalty: kl # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

trainer:
  balance_batch: True
  total_epochs: 10
  # total_training_steps: null
  project_name: rft_example_gsm8k
  experiment_name: cys-qwen2_1.5b_rollout8_grpo_kl0.001_lr1e-5
  logger: [ 'console', 'wandb' ]
  val_generations_to_log_to_wandb: 0
  nnodes: 1
  n_gpus_per_node: 2
  save_freq: 100
  # auto: find the last ckpt to resume. If can't find, start from scratch
  resume_mode: auto # auto, disable, or resume_path
  test_freq: 5
  critic_warmup: 0
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
  val_before_train: False
1 change: 1 addition & 0 deletions trinity/buffer/reader/queue_reader.py
@@ -21,6 +21,7 @@ def __init__(self, meta: DatasetConfig, config: BufferConfig):
        self.config = config
        self.queue = QueueActor.options(
            name=f"queue-{meta.name}",
            namespace="Trinity-RFT",
            get_if_exists=True,
        ).remote(meta, config)
