# Add example_async_mode #28
**Merged**

## Commits
All 11 commits are by hiyuchang:

- fc390af add example_async_mode.md
- 241d8bd add same space for queueactor
- cf7cbad Merge branch 'main' into dev/async
- b5f6e9c add curve for async
- d1fccef add curge figure
- 922aa03 fix sqlite database path format
- 3fc6b2e add running command for async
- 469875c add namespace
- b959f12 fix async config
- 56ebd60 change md link
- e2c1dab polish sentences for async tutorial
**File: `docs/sphinx_doc/source/tutorial/example_async_mode.md`** (new file)
# A quick example for asynchronous mode

This example shows how to run RFT in asynchronous mode with the Qwen2.5-1.5B-Instruct model and the GSM8K dataset.

Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.

For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
The main difference between them is that `trainer.yaml` sets `mode: train`, while `explorer.yaml` sets `mode: explore`.
In addition, the following parameters must be configured consistently in both files.
The synchronization frequency is every `sync_iteration_interval * batch_size` tasks.
```yaml
data:
  batch_size: <batch_size>
# The same checkpoint path
model:
  checkpoint_path: /PATH/TO/CHECKPOINT

# The same database path
buffer:
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'

synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: <sync_iteration_interval>
```
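For instance, with the values used in this example (`batch_size: 96` and `sync_iteration_interval: 10`, see the configs below), the explorer and trainer synchronize model weights every 96 * 10 = 960 tasks.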

You can run this example with the following command:
```bash
bash examples/async_gsm8k/run.sh
```
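The script starts the explorer and trainer as background processes and tees their output to `explorer.log` and `trainer.log` (see `run.sh` below). A minimal sketch for watching both runs, assuming the logs are written to the directory you launched from:

```bash
# Follow both logs at once; Ctrl-C stops tailing, not the training runs
tail -f explorer.log trainer.log
```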

The reward curve in asynchronous mode is shown below.

![async]()
**File: `examples/async_gsm8k/README.md`** (new file)
# Asynchronous mode on GSM8K dataset

This example shows the usage of GRPO on the GSM8K dataset in asynchronous mode.

For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_async_mode.md).

The config files are [`trainer.yaml`](trainer.yaml), [`explorer.yaml`](explorer.yaml), and [`verl_config.yaml`](verl_config.yaml).

You can run this example with the following command:

```bash
bash examples/async_gsm8k/run.sh
```
**File: `examples/async_gsm8k/explorer.yaml`** (new file)
```yaml
mode: explore
data:
  # basic info
  dataset_path: /PATH/TO/DATASET/
  subset_name: ''
  train_split: 'train'
  eval_split: 'test'
  format_config:
    prompt_key: 'question'
    response_key: 'answer'
  # downstream loading related
  total_epochs: 20
  batch_size: 96
  default_workflow_type: 'math_workflow'
model:
  model_path: /PATH/TO/MODEL/
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k'
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  max_retry_times: 3
  max_retry_interval: 1
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'
explorer:
  engine_type: vllm_async
  engine_num: 2
  runner_num: 32
  tensor_parallel_size: 1
  enable_prefix_caching: false
  enforce_eager: true
  dtype: bfloat16
  temperature: 1.0
  seed: 42
  logprobs: 0
  repeat_times: 8
  use_ray: false
  backend: 'nccl'
  max_pending_requests: 32
  max_waiting_steps: 4
synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: 10
trainer:
  trainer_type: 'verl'
  algorithm_type: ppo
  trainer_config_path: examples/async_gsm8k/verl_config.yaml
  sft_warmup_iteration: 0 # Set to integer to enable sft warmup
  eval_interval: 10
monitor:
  cache_root_dir: ""
  project: "Trinity-RFT-gsm8k"
  name: "async-qwen2.5-1.5B-gsm8k"
```
**File: `examples/async_gsm8k/run.sh`** (new file)
```bash
#!/bin/bash
# Launch the explorer in the background, logging to explorer.log
trinity run --config examples/async_gsm8k/explorer.yaml 2>&1 | tee explorer.log &
# Give the explorer time to start up before launching the trainer
sleep 30
trinity run --config examples/async_gsm8k/trainer.yaml 2>&1 | tee trainer.log &
```
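The two processes exchange experience through the shared SQLite-backed queue at `sqlite:///gsm8k.db`. An optional, quick way to confirm the buffer database has been created, assuming the `sqlite3` CLI is installed and the run was started from the repository root:

```bash
# List the tables in the buffer database (table names depend on Trinity-RFT internals)
sqlite3 gsm8k.db '.tables'
```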
**File: `examples/async_gsm8k/trainer.yaml`** (new file)
```yaml
mode: train
data:
  # basic info
  dataset_path: /PATH/TO/DATASET/
  subset_name: ''
  train_split: 'train'
  eval_split: 'test'
  format_config:
    prompt_key: 'question'
    response_key: 'answer'
  # downstream loading related
  total_epochs: 20
  batch_size: 96
  default_workflow_type: 'math_workflow'
model:
  model_path: /PATH/TO/MODEL/
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: ""
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  max_retry_times: 3
  max_retry_interval: 1
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'
explorer:
  engine_type: vllm_async
  engine_num: 2
  runner_num: 32
  tensor_parallel_size: 1
  enable_prefix_caching: false
  enforce_eager: true
  dtype: bfloat16
  temperature: 1.0
  seed: 42
  logprobs: 0
  repeat_times: 8
  use_ray: false
  backend: 'nccl'
  max_pending_requests: 32
  max_waiting_steps: 4
synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: 10
trainer:
  trainer_type: 'verl'
  algorithm_type: ppo
  trainer_config_path: examples/async_gsm8k/verl_config.yaml
  sft_warmup_iteration: 0 # Set to integer to enable sft warmup
  eval_interval: 10
monitor:
  cache_root_dir: ""
  project: "Trinity-RFT-gsm8k"
  name: "async-qwen2.5-1.5B-gsm8k"
```
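Apart from `mode` and `model.checkpoint_path`, `trainer.yaml` is intended to match `explorer.yaml` exactly. A simple sanity check, assuming both files sit at the paths above:

```bash
# Should report only the mode and checkpoint_path lines
diff examples/async_gsm8k/explorer.yaml examples/async_gsm8k/trainer.yaml
```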
**File: `examples/async_gsm8k/verl_config.yaml`** (new file)
```yaml
data:
  tokenizer: null
  train_files: placeholder
  val_files: placeholder
  prompt_key: prompt
  max_prompt_length: 256
  max_response_length: 1024
  train_batch_size: 256
  val_batch_size: null
  return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
  return_raw_chat: False
  shuffle: True
  filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts can be time-consuming; disable this and set truncation='left'
  truncation: error
  image_key: images

actor_rollout_ref:
  hybrid_engine: True
  model:
    path: /PATH/TO/MODEL/
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: True # False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_mini_batch_size: 128
    # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: 4
    use_dynamic_bsz: True # False
    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    clip_ratio: 0.2
    entropy_coeff: 0.001
    use_kl_loss: True # True for GRPO
    kl_loss_coef: 0.001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-5
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      # min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
    # --- below: opmd ---
    alg_type: ppo # ppo / opmd / pairwise_opmd
    tau: 0.000 # strength of regularization w.r.t. old / ref policy
    opmd_baseline: mean # mean / logavgexp, applicable to opmd
    use_uid: False # True / False, applicable to pairwise_opmd
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 16
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: vllm
    temperature: 1.0
    use_fire_sampling: False # https://arxiv.org/abs/2410.21236
    prompt_length: ${data.max_prompt_length} # not used for open-source models
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.4
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_model_len: null
    max_num_seqs: 1024
    # log_prob_micro_batch_size: 8 # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 4
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: True
    enable_chunked_prefill: True # could get higher throughput
    # for hf rollout
    do_sample: True
    # number of responses (i.e. num sample times)
    n: 8 # > 1 for grpo

critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    # min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program
  model:
    path: /PATH/TO/MODEL/
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: True
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: 64
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5

reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
      fsdp_size: -1
  # micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
  # micro_batch_size_per_gpu: 2 # set a number
  # max_length: null
  ulysses_sequence_parallel_size: 1 # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}

custom_reward_function:
  path: null
  name: compute_score

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: grpo
  kl_penalty: kl # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

trainer:
  balance_batch: True
  total_epochs: 10
  # total_training_steps: null
  project_name: rft_example_gsm8k
  experiment_name: cys-qwen2_1.5b_rollout8_grpo_kl0.001_lr1e-5
  logger: [ 'console', 'wandb' ]
  val_generations_to_log_to_wandb: 0
  nnodes: 1
  n_gpus_per_node: 2
  save_freq: 100
  # auto: find the last ckpt to resume; if none is found, start from scratch
  resume_mode: auto # options: auto / disable / resume_path
  test_freq: 5
  critic_warmup: 0
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
  val_before_train: False
```
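Note how the verl config lines up with the GRPO setup described in the README: `algorithm.adv_estimator` is `grpo`, `actor.use_kl_loss` is enabled with `kl_loss_coef: 0.001`, and `rollout.n: 8` matches `repeat_times: 8` in the explorer and trainer configs, so each prompt collects eight responses for the group-relative advantage estimate.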