diff --git a/README.md b/README.md
index 09988d6d81..9dc284bce4 100644
--- a/README.md
+++ b/README.md
@@ -260,7 +260,8 @@ More example config files can be found in `examples`.
 
 For more detailed examples about how to use Trinity-RFT, please refer to the following tutorials:
 + [A quick example with GSM8k](./docs/sphinx_doc/source/tutorial/example_reasoning_basic.md);
-+ [Off-policy / asynchronous modes of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md);
++ [Off-policy mode of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md);
++ [Asynchronous mode of RFT](./docs/sphinx_doc/source/tutorial/example_async_mode.md);
 + [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md);
 + [Data processing pipelines](./docs/sphinx_doc/source/tutorial/example_data_functionalities.md);
 + [Offline learning by DPO](./docs/sphinx_doc/source/tutorial/example_dpo.md).
diff --git a/docs/sphinx_doc/assets/async-curve.png b/docs/sphinx_doc/assets/async-curve.png
new file mode 100644
index 0000000000..74fee57be5
Binary files /dev/null and b/docs/sphinx_doc/assets/async-curve.png differ
diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md
index a7e6684219..59fcaf62f1 100644
--- a/docs/sphinx_doc/source/main.md
+++ b/docs/sphinx_doc/source/main.md
@@ -240,7 +240,8 @@ More example config files can be found in `examples`.
 
 For more detailed examples about how to use Trinity-RFT, please refer to the following documents:
 + [A quick example with GSM8k](tutorial/example_reasoning_basic.md);
-+ [Off-policy / asynchronous modes of RFT](tutorial/example_reasoning_advanced.md);
++ [Off-policy mode of RFT](tutorial/example_reasoning_advanced.md);
++ [Asynchronous mode of RFT](tutorial/example_async_mode.md);
 + [Multi-turn tasks](tutorial/example_multi_turn.md);
 + [Data processing pipelines](tutorial/example_data_functionalities.md);
 + [Offline learning by DPO](tutorial/example_dpo.md).
diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md
new file mode 100644
index 0000000000..a96ff29ff6
--- /dev/null
+++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -0,0 +1,50 @@
+# A quick example for asynchronous mode
+
+This example shows how to run RFT in asynchronous mode with the GRPO algorithm, the Qwen2.5-1.5B-Instruct model, and the GSM8K dataset.
+
+Trinity-RFT supports asynchronous mode by running the trainer and explorer in separate processes.
+
+For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
+The main difference between them is that in `trainer.yaml` we set `mode=train`, while in `explorer.yaml` we set `mode=explore`.
+In addition, the following parameters must be configured consistently in both files.
+The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks.
+For instance, with `batch_size: 96` and `sync_iteration_interval: 10` (the values used in `examples/async_gsm8k`), the weights are synchronized every 960 tasks.
+
+```yaml
+data:
+  batch_size:  # the same value in both files, e.g. 96
+# the same checkpoint path in both files
+model:
+  checkpoint_path: /PATH/TO/CHECKPOINT
+
+# the same database path in both files
+buffer:
+  train_dataset:
+    name: gsm8k_buffer
+    storage_type: queue
+    path: 'sqlite:///gsm8k.db'
+
+synchronizer:
+  sync_method: 'checkpoint'
+  sync_iteration_interval:  # e.g. 10
+```
+
+You may run this example with the following command:
+
+```bash
+bash examples/async_gsm8k/run.sh
+```
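+
+For reference, the script just launches the explorer and the trainer as two background processes, with a 30-second head start for the explorer:
+
+```bash
+trinity run --config examples/async_gsm8k/explorer.yaml 2>&1 | tee explorer.log &
+sleep 30
+trinity run --config examples/async_gsm8k/trainer.yaml 2>&1 | tee trainer.log &
+```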
+
+The following plot shows the learning curve of GRPO in asynchronous mode.
+> This result should be regarded merely as a baseline, since GRPO is supposed to be an on-policy algorithm.
+> We are continuously investigating other RL algorithms (e.g., [OPMD](./example_reasoning_advanced.md)) in asynchronous mode.
+
+![async](../../assets/async-curve.png)
diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
index d278ed22bc..9307ea9e9f 100644
--- a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
+++ b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
@@ -1,4 +1,4 @@
-# Example: off-policy / asynchronous RFT mode
+# Example: off-policy RFT mode
 
-Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) and show some advanced features provided by Trinity-RFT, namely, off-policy or asynchronous RFT mode.
+Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) and show an advanced feature provided by Trinity-RFT, namely, the off-policy RFT mode.
 
@@ -35,17 +35,3 @@ A similar performance boost is shown at step 21, which leads to a converged score
 
 ![opmd](../../assets/opmd-curve.png)
 
-
-
-
-
-
-## Asynchronous mode
-
-
-Trinity-RFT supports the asynchronous and decoupled mode of RFT, where explorer and trainer act independently and asynchronously.
-To run this mode, the explorer and trainer need to be launched separately, with the `mode` parameter in the config file set to `explore` and `train` respectively.
-
-
-
-*We are still testing this mode more thoroughly. A concrete example is coming soon!*
diff --git a/examples/async_gsm8k/README.md b/examples/async_gsm8k/README.md
new file mode 100644
index 0000000000..c73aeeaf59
--- /dev/null
+++ b/examples/async_gsm8k/README.md
@@ -0,0 +1,23 @@
+# Asynchronous mode on the GSM8K dataset
+
+This example shows the usage of GRPO on the GSM8K dataset in asynchronous mode.
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_async_mode.md).
+
+The config files for this example are [`trainer.yaml`](trainer.yaml), [`explorer.yaml`](explorer.yaml), and [`verl_config.yaml`](verl_config.yaml).
+
+You can run this example with the following command:
+
+```bash
+bash examples/async_gsm8k/run.sh
+```
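+
+`trainer.yaml` and `explorer.yaml` are nearly identical; the key difference is the `mode` field:
+
+```yaml
+# in trainer.yaml
+mode: train
+
+# in explorer.yaml
+mode: explore
+```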
diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml
new file mode 100644
index 0000000000..73f8669b10
--- /dev/null
+++ b/examples/async_gsm8k/explorer.yaml
@@ -0,0 +1,58 @@
+mode: explore
+data:
+  # basic info
+  dataset_path: /PATH/TO/DATASET/
+  subset_name: ''
+  train_split: 'train'
+  eval_split: 'test'
+  format_config:
+    prompt_key: 'question'
+    response_key: 'answer'
+  # downstream loading related
+  total_epochs: 20
+  batch_size: 96
+  default_workflow_type: 'math_workflow'
+model:
+  model_path: /PATH/TO/MODEL/
+  max_prompt_tokens: 256
+  max_response_tokens: 1024
+  checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k'
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  max_retry_times: 3
+  max_retry_interval: 1
+  train_dataset:
+    name: gsm8k_buffer
+    storage_type: queue
+    path: 'sqlite:///gsm8k.db'
+explorer:
+  engine_type: vllm_async
+  engine_num: 2
+  runner_num: 32
+  tensor_parallel_size: 1
+  enable_prefix_caching: false
+  enforce_eager: true
+  dtype: bfloat16
+  temperature: 1.0
+  seed: 42
+  logprobs: 0
+  repeat_times: 8
+  use_ray: false
+  backend: 'nccl'
+  max_pending_requests: 32
+  max_waiting_steps: 4
+synchronizer:
+  sync_method: 'checkpoint'
+  sync_iteration_interval: 10
+trainer:
+  trainer_type: 'verl'
+  algorithm_type: ppo
+  trainer_config_path: examples/async_gsm8k/verl_config.yaml
+  sft_warmup_iteration: 0 # set to a positive integer to enable SFT warmup
+  eval_interval: 10
+monitor:
+  cache_root_dir: ""
+  project: "Trinity-RFT-gsm8k"
+  name: "async-qwen2.5-1.5B-gsm8k"
diff --git a/examples/async_gsm8k/run.sh b/examples/async_gsm8k/run.sh
new file mode 100644
index 0000000000..ff9ad66bbc
--- /dev/null
+++ b/examples/async_gsm8k/run.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+trinity run --config examples/async_gsm8k/explorer.yaml 2>&1 | tee explorer.log &
+sleep 30
+trinity run --config examples/async_gsm8k/trainer.yaml 2>&1 | tee trainer.log &
diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml
new file mode 100644
index 0000000000..d2dc92503a
--- /dev/null
+++ b/examples/async_gsm8k/trainer.yaml
@@ -0,0 +1,58 @@
+mode: train
+data:
+  # basic info
+  dataset_path: /PATH/TO/DATASET/
+  subset_name: ''
+  train_split: 'train'
+  eval_split: 'test'
+  format_config:
+    prompt_key: 'question'
+    response_key: 'answer'
+  # downstream loading related
+  total_epochs: 20
+  batch_size: 96
+  default_workflow_type: 'math_workflow'
+model:
+  model_path: /PATH/TO/MODEL/
+  max_prompt_tokens: 256
+  max_response_tokens: 1024
+  checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k' # must match explorer.yaml for checkpoint-based weight sync
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  max_retry_times: 3
+  max_retry_interval: 1
+  train_dataset:
+    name: gsm8k_buffer
+    storage_type: queue
+    path: 'sqlite:///gsm8k.db'
+explorer:
+  engine_type: vllm_async
+  engine_num: 2
+  runner_num: 32
+  tensor_parallel_size: 1
+  enable_prefix_caching: false
+  enforce_eager: true
+  dtype: bfloat16
+  temperature: 1.0
+  seed: 42
+  logprobs: 0
+  repeat_times: 8
+  use_ray: false
+  backend: 'nccl'
+  max_pending_requests: 32
+  max_waiting_steps: 4
+synchronizer:
+  sync_method: 'checkpoint'
+  sync_iteration_interval: 10
+trainer:
+  trainer_type: 'verl'
+  algorithm_type: ppo
+  trainer_config_path: examples/async_gsm8k/verl_config.yaml
+  sft_warmup_iteration: 0 # set to a positive integer to enable SFT warmup
+  eval_interval: 10
+monitor:
+  cache_root_dir: ""
+  project: "Trinity-RFT-gsm8k"
+  name: "async-qwen2.5-1.5B-gsm8k"
diff --git a/examples/async_gsm8k/verl_config.yaml b/examples/async_gsm8k/verl_config.yaml
new file mode 100644
index 0000000000..268d61e0e5
--- /dev/null
+++ b/examples/async_gsm8k/verl_config.yaml
@@ -0,0 +1,183 @@
+data:
+  tokenizer: null
+  train_files: placeholder
+  val_files: placeholder
+  prompt_key: prompt
+  max_prompt_length: 256
+  max_response_length: 1024
+  train_batch_size: 256
+  val_batch_size: null
+  return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
+  return_raw_chat: False
+  shuffle: True
+  filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts can be time-consuming; disable this and set `truncation='left'` instead
+  truncation: error
+  image_key: images
+
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    path: /PATH/TO/MODEL/
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: True # False
+  actor:
+    strategy: fsdp # This is for backward-compatibility
+    ppo_mini_batch_size: 128
+    # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
+    ppo_micro_batch_size_per_gpu: 4
+    use_dynamic_bsz: True # False
+    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+    grad_clip: 1.0
+    clip_ratio: 0.2
+    entropy_coeff: 0.001
+    use_kl_loss: True # True for GRPO
+    kl_loss_coef: 0.001 # for grpo
+    kl_loss_type: low_var_kl # for grpo
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
+    optim:
+      lr: 1e-5
+      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+      # min_lr_ratio: null # only useful for warmup with cosine
+      warmup_style: constant # select from constant/cosine
+      total_training_steps: -1 # must be overridden by the program
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+    # --- below: opmd ---
+    tau: 0.000 # strength of regularization w.r.t. old / ref policy
+    opmd_baseline: mean # mean / logavgexp, applicable to opmd
+    use_uid: False # True / False, applicable to pairwise_opmd
+  ref:
+    fsdp_config:
+      param_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu
+    log_prob_micro_batch_size_per_gpu: 16
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+  rollout:
+    name: vllm
+    temperature: 1.0
+    use_fire_sampling: False # https://arxiv.org/abs/2410.21236
+    prompt_length: ${data.max_prompt_length} # not used for open-source models
+    response_length: ${data.max_response_length}
+    # for vllm rollout
+    dtype: bfloat16 # should align with FSDP
+    gpu_memory_utilization: 0.4
+    ignore_eos: False
+    enforce_eager: True
+    free_cache_engine: True
+    load_format: dummy_dtensor
+    tensor_model_parallel_size: 2
+    max_num_batched_tokens: 8192
+    max_model_len: null
+    max_num_seqs: 1024
+    # log_prob_micro_batch_size: 8 # will be deprecated, use log_prob_micro_batch_size_per_gpu
+    log_prob_micro_batch_size_per_gpu: 4
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    disable_log_stats: True
+    enable_chunked_prefill: True # could get higher throughput
+    # for hf rollout
+    do_sample: True
+    # number of responses (i.e. num sample times)
+    n: 8 # > 1 for grpo
+
+critic:
+  strategy: fsdp
+  optim:
+    lr: 1e-5
+    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+    # min_lr_ratio: null # only useful for warmup with cosine
+    warmup_style: constant # select from constant/cosine
+    total_training_steps: -1 # must be overridden by the program
+  model:
+    path: /PATH/TO/MODEL/
+    tokenizer_path: ${actor_rollout_ref.model.path}
+    override_config: { }
+    external_lib: ${actor_rollout_ref.model.external_lib}
+    enable_gradient_checkpointing: True
+    use_remove_padding: False
+    fsdp_config:
+      param_offload: False
+      optimizer_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      fsdp_size: -1
+  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
+  # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
+  ppo_micro_batch_size_per_gpu: 64
+  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
+  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
+  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
+  ulysses_sequence_parallel_size: 1 # sp size
+  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+  shuffle: ${actor_rollout_ref.actor.shuffle}
+  grad_clip: 1.0
+  cliprange_value: 0.5
+
+reward_model:
+  enable: False
+  strategy: fsdp
+  model:
+    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
+    path: ~/models/FsfairX-LLaMA3-RM-v0.1
+    external_lib: ${actor_rollout_ref.model.external_lib}
+    use_remove_padding: False
+    fsdp_config:
+      min_num_params: 0
+      param_offload: False
+      fsdp_size: -1
+  # micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
+  # micro_batch_size_per_gpu: 2 # set a number
+  # max_length: null
+  ulysses_sequence_parallel_size: 1 # sp size
+  use_dynamic_bsz: ${critic.use_dynamic_bsz}
+  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
+
+custom_reward_function:
+  path: null
+  name: compute_score
+
+algorithm:
+  gamma: 1.0
+  lam: 1.0
+  adv_estimator: grpo
+  kl_penalty: kl # how to estimate kl divergence
+  kl_ctrl:
+    type: fixed
+    kl_coef: 0.001
+
+trainer:
+  balance_batch: True
+  total_epochs: 10
+  # total_training_steps: null
+  project_name: rft_example_gsm8k
+  experiment_name: cys-qwen2_1.5b_rollout8_grpo_kl0.001_lr1e-5
+  logger: [ 'console', 'wandb' ]
+  val_generations_to_log_to_wandb: 0
+  nnodes: 1
+  n_gpus_per_node: 2
+  save_freq: 100
+  # auto: find the last ckpt to resume; if none is found, start from scratch
+  resume_mode: auto # or disable, or resume_path if resume_from_path is set
+  test_freq: 5
+  critic_warmup: 0
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
+  val_before_train: False
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index 8ff4ef4870..519479848c 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -21,6 +21,7 @@ def __init__(self, meta: DatasetConfig, config: BufferConfig):
         self.config = config
         self.queue = QueueActor.options(
             name=f"queue-{meta.name}",
+            namespace=meta.namespace,
             get_if_exists=True,
         ).remote(meta, config)
diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py
index 9d4f2a83fd..14fd9e5b28 100644
--- a/trinity/buffer/writer/queue_writer.py
+++ b/trinity/buffer/writer/queue_writer.py
@@ -20,6 +20,7 @@ def __init__(self, meta: DatasetConfig, config: BufferConfig):
         self.config = config
         self.queue = QueueActor.options(
             name=f"queue-{meta.name}",
+            namespace=meta.namespace,
             get_if_exists=True,
         ).remote(meta, config)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 6dc3510ddc..1c795d1e5c 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -109,6 +109,7 @@ class DatasetConfig:
     storage_type: StorageType
     algorithm_type: AlgorithmType = AlgorithmType.PPO
     path: Optional[str] = None
+    namespace: str = ""  # automatically generated
     kwargs: Dict[str, Any] = field(default_factory=dict)
@@ -289,7 +290,9 @@ def _check_buffer(self) -> None:
         if self.buffer.train_dataset is None:
             raise ValueError("buffer.train_dataset is required when mode is not 'both'")
         self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type
+        self.buffer.train_dataset.namespace = f"{self.monitor.project}-{self.monitor.name}"
         if self.buffer.sft_warmup_dataset is not None:
+            self.buffer.sft_warmup_dataset.namespace = f"{self.monitor.project}-{self.monitor.name}"
             self.buffer.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT
         self.buffer.read_batch_size = self.data.batch_size * self.explorer.repeat_times
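
Note on the `trinity/buffer` and `trinity/common/config.py` hunks above: the queue buffer is a named Ray actor, and the new `namespace` field (auto-filled as `{monitor.project}-{monitor.name}`) scopes that actor name. This is what lets the explorer's writer and the trainer's reader, running in separate processes, resolve `queue-<dataset name>` to the same actor, while two experiments sharing one Ray cluster cannot collide. Below is a minimal, self-contained sketch of this pattern; the `QueueActor` body here is a simplified stand-in, not Trinity-RFT's actual implementation:

```python
import ray

# Assumes a running Ray cluster (e.g. started with `ray start --head`)
# that both the explorer and the trainer processes connect to.
ray.init(address="auto")

@ray.remote
class QueueActor:
    """Toy FIFO buffer standing in for Trinity-RFT's queue actor."""

    def __init__(self):
        self.items = []

    def put(self, item):
        self.items.append(item)

    def get_batch(self, n):
        batch, self.items = self.items[:n], self.items[n:]
        return batch

# Executed identically by the writer and the reader: `get_if_exists=True`
# attaches to the actor if it already exists instead of raising, and the
# namespace (project-name, here the values from the example configs) keeps
# different experiments from colliding on the same actor name.
queue = QueueActor.options(
    name="queue-gsm8k_buffer",
    namespace="Trinity-RFT-gsm8k-async-qwen2.5-1.5B-gsm8k",
    get_if_exists=True,
).remote()

ray.get(queue.put.remote({"prompt": "1+1=?", "response": "2"}))
print(ray.get(queue.get_batch.remote(1)))
```

`get_if_exists=True` makes the call idempotent: whichever process starts first creates the actor, and the other simply attaches to it, which is why `run.sh` can launch the two processes independently.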