diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md
index ed0fc87596..a1b09d873f 100644
--- a/docs/sphinx_doc/source/tutorial/example_async_mode.md
+++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -7,17 +7,15 @@ Trinity-RFT supports an asynchronous mode by running the trainer and explorer in
 For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`. The main difference between them is that in `trainer.yaml` we set `mode=train`, while in `explorer.yaml` we set `mode=explore`. In addition, we need to configure the following parameters in both files.
 
-The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks.
+The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks.
 
 ```yaml
-global_config:
-  batch_size:
-# The same checkpoint path
-model:
-  checkpoint_path: /PATH/TO/CHECKPOINT
+project: tutorial
+name: async_mode_example
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 
-# The same data_base path
 buffer:
+  batch_size:
   trainer_input:
     experience_buffer:
       name: gsm8k_buffer
@@ -26,7 +24,7 @@ buffer:
 
 synchronizer:
   sync_method: 'checkpoint'
-  sync_iteration_interval:
+  sync_interval:
 ```
 
 You may run this example with the following command:
diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md
index 8bbf3d9199..953af399e5 100644
--- a/docs/sphinx_doc/source/tutorial/example_dpo.md
+++ b/docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -45,6 +45,8 @@ We run the experiment in a train mode, as there is no Explorer. To enable this m
 ```yaml
 # In dpo.yaml
 mode: train
+algorithm:
+  algorithm_type: dpo
 synchronizer:
   sync_method: 'checkpoint'
 buffer:
@@ -56,14 +58,9 @@ buffer:
       prompt_key:
       chosen_key:
       rejected_key:
-global_config:
-  algorithm_type: dpo
-
-# In train_dpo.yaml
-actor_rollout_ref:
-  actor:
-    use_kl_loss: True
-    kl_loss_coef: 0.1 # value of beta in DPO
+trainer:
+  actor_use_kl_loss: True
+  actor_kl_loss_coef: 0.1 # value of beta in DPO
 ```
 
 ### Run the Experiment
diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
index de785d18fe..1cda68fc50 100644
--- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
+++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -53,16 +53,13 @@ We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinit
 ```yaml
 # In gsm8k.yaml
-explorer:
+algorithm:
+  algorithm_type: grpo / ppo
   repeat_times: {number of rollouts for each task}
 
-# In train_gsm8k.yaml
-actor_rollout_ref:
-  actor:
-    use_kl_loss: True (fro GRPO) / False (for PPO)
-    kl_loss_coef: 0.001
-algorithm:
-  adv_estimator: grpo (fro GRPO) / gae (for PPO)
+trainer:
+  actor_use_kl_loss: True (for GRPO) / False (for PPO)
+  actor_kl_loss_coef: 0.001
 ```
 
 ### Run the Experiment
@@ -76,20 +73,20 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml
 
 ## Optional: RFT with SFT Warmup
 
-Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_steps > 0` and prepare the SFT data to `buffer.train_dataset.path=$DATASET_PATH/{sft_data}`.
+Before RFT, we may use SFT as a warmup step. We need to set `buffer.trainer_input.sft_warmup_steps > 0` and prepare the SFT data to `buffer.trainer_input.sft_warmup_dataset.path=$DATASET_PATH/{sft_data}`.
 
 ```yaml
 # Properly set the following configs in gsm8k.yaml
 buffer:
-  sft_warmup_dataset:
-    storage_type: file
-    path: <$DATASET_PATH/{sft_data}>
-    format:
-      prompt_type: # messages/plaintext/chatpair
-      prompt_key:
-      response_key:
-trainer:
-  sft_warmup_steps: 10
+  trainer_input:
+    sft_warmup_dataset:
+      storage_type: file
+      path: <$DATASET_PATH/{sft_data}>
+      format:
+        prompt_type: # messages/plaintext/chatpair
+        prompt_key:
+        response_key:
+    sft_warmup_steps: 10
 ```
 
 The following command runs SFT and RFT in sequence:
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index ec9372cc34..255bb58018 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -5,33 +5,36 @@ The following is the main config file for Trinity-RFT. Take `countdown.yaml` as
 
 ## Global Config
 
 ```yaml
+project: Trinity-RFT
+name: example
 mode: both
-global_config:
-  algorithm_type: ppo
-  total_epochs: 1
-  batch_size: 96
-  eval_interval: 1000
-  eval_on_latest_ckp: true
+checkpoint_root_dir: /PATH/TO/CHECKPOINT
 ```
 
+- `project`: The name of the project.
+- `name`: The name of the experiment.
 - `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
-- `global_config.algorithm_type`: The type of the algorithm, Support `ppo`, `grpo`, `opmd` and `dpo`.
-- `global_config.total_epochs`: The total number of epochs. It should be checked manually.
-- `global_config.batch_size`: The batch size used for training. It should be checked manually.
-- `global_config.eval_interval`: The interval steps between two evaluations. Default is `1000`.
-- `global_config.eval_on_latest_ckp`: Whether to evaluate on only the latest checkpoint or all the checkpoints in the path. Only valid in `bench` mode. Default is `true`.
+- `checkpoint_root_dir`: The root directory to save the checkpoints. Specifically, the generated checkpoints will be saved in `<checkpoint_root_dir>/<project>/<name>/`.
 
+## Algorithm
+
+```yaml
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 1
+```
+
+- `algorithm.algorithm_type`: The type of the algorithm. Supports `ppo`, `grpo`, `opmd` and `dpo`.
+- `algorithm.repeat_times`: The number of times to repeat each task. Used for GRPO-like algorithms. Default is `1`.
 
 ## Monitor
 
 ```yaml
 monitor:
-  project: "Trinity-RFT-countdown"
-  name: "qwen2.5-1.5B-countdown"
+  monitor_type: MonitorType.WANDB
 ```
 
-- `monitor.project`: The project name. It must be set manually.
-- `monitor.name`: The name of the experiment. It must be set manually.
+- `monitor.monitor_type`: The type of the monitor. For now, `MonitorType.WANDB` and `MonitorType.TENSORBOARD` are supported.
 
 ## Data Processing
 
@@ -69,16 +72,11 @@ The `model` configuration specifies the model used for training. It includes the
 model:
   model_path: '/PATH/TO/MODEL/CHECKPOINT/'
   critic_model_path: ''
-  max_prompt_tokens: 256
-  max_response_tokens: 1024
-  checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown'
 ```
 
 - `model.model_path`: The path to the model checkpoint. It must be set manually.
 - `model.critic_model_path`: The path to the critic model checkpoint. If not set, the `model.critic_model_path` will be set to `model.model_path`.
-- `model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
-- `model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
-- `model.checkpoint_path`: The path to the checkpoint of the model. It must be set manually.
+
 
 ## Cluster
 
@@ -108,7 +106,7 @@ buffer:
         prompt_key: 'question'
         response_key: 'answer'
       rollout_args:
-        repeat_times: 1
+        n: 1
         temperature: 1.0
         logprobs: 0
     eval_tasksets: []
@@ -129,7 +127,7 @@ buffer:
 - `buffer.explorer_input.taskset.path`: The path to the taskset.
 - `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
 - `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
-- `buffer.explorer_input.taskset.rollout_args.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `1`.
+- `buffer.explorer_input.taskset.rollout_args.n`: The number of times to repeat each task. This field is automatically set to `algorithm.repeat_times`.
 - `buffer.explorer_input.taskset.rollout_args.temperature`: The temperature used in vLLM. Default is `1.0`.
 - `buffer.explorer_input.taskset.rollout_args.logprobs`: The logprobs used in vLLM. Default is `0`.
 - `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets. It is a list of tasksets which will be used for evaluation. And it is empty by default.
@@ -143,22 +141,19 @@ buffer:
 
 ## Explorer
 
-The `explorer` configuration specifies the explorer configuration. It includes the type of the engine, the number of engines, the number of workflow runners, the tensor parallel size, whether to enable prefix caching, whether to enforce eager mode, the data type, the `temperature`, the `top-p`, the `top-k`, the `seed`, the `logprobs`, the number of times to repeat each task, whether to use Ray, the backend, the maximum number of pending requests, and the maximum number of waitingsteps.
+The `explorer` configuration specifies the explorer settings. It includes the type of the engine, the number of engines, the number of workflow runners, the tensor parallel size, whether to enable prefix caching, whether to enforce eager mode, the data type, the `temperature`, the `top-p`, the `top-k`, the `seed`, the `logprobs`, the number of times to repeat each task, the maximum number of pending requests, and the maximum number of waiting steps.
 
 ```yaml
 explorer:
-  engine_type: vllm_async
-  engine_num: 2
   runner_num: 32
-  tensor_parallel_size: 1
-  enable_prefix_caching: false
-  enforce_eager: true
-  dtype: bfloat16
-  seed: 42
-  use_ray: false
-  backend: 'nccl'
-  max_pending_requests: 32
-  max_waiting_steps: 4
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 2
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
 ```
 
 - `explorer.engine_type`: The type of the engine, Support `vllm_async` and `vllm_sync`. Default is `vllm_async`.
@@ -169,10 +164,8 @@ explorer:
 - `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`.
 - `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`.
 - `explorer.seed`: The seed used in vLLM. Default is `42`.
-- `explorer.use_ray`: Whether to use Ray. Default is `False`.
-- `explorer.backend`: The backend used in vLLM. Default is `nccl`.
-- `explorer.max_pending_requests`: The maximum number of pending requests. Default is `32`.
-- `explorer.max_waiting_steps`: The maximum number of waiting steps. Default is `4`.
+- `explorer.rollout_model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually. +- `explorer.rollout_model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually. ## Synchronizer @@ -195,15 +188,11 @@ Support `nccl` and `checkpoint`, `nccl` represents that model weights in `explor trainer: trainer_type: 'verl' trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' - sft_warmup_steps: 0 - eval_interval: 1000 save_interval: 100 ``` - `trainer.trainer_type`: The backend of the trainer, Only `verl` is supported. - `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually. -- `trainer.sft_warmup_steps`: The number of steps to warm up the model. Default is `0`. -- `trainer.eval_interval`: The interval steps between two evaluations. Default is `1000`. - `trainer.save_interval`: The interval steps between two checkpoints. Default is `100`. ### veRL Trainer Configuration diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index dc8e7e676f..7b7cc1dc76 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -116,7 +116,7 @@ class ExampleWorkflow(Workflow): "content": f"Question:\n{self.question}", } ], - n=self.task.rollout_args.repeat_times, + n=self.task.rollout_args.n, temperature=self.task.rollout_args.temperature, ) reward: float = self.calculate_reward(response.response_text, self.answer) diff --git a/examples/async_gsm8k/explorer.yaml b/examples/async_gsm8k/explorer.yaml index 8402e4ced2..a05b2ebfcf 100644 --- a/examples/async_gsm8k/explorer.yaml +++ b/examples/async_gsm8k/explorer.yaml @@ -1,18 +1,20 @@ +project: "Trinity-RFT-gsm8k" +name: "async-qwen2.5-1.5B-gsm8k" mode: explore -global_config: - total_epochs: 20 - batch_size: 96 - eval_interval: 10 +checkpoint_root_dir: '/PATH/TO/CHECKPOINT/' +algorithm: algorithm_type: grpo + repeat_times: 8 model: model_path: /PATH/TO/MODEL/ max_prompt_tokens: 256 max_response_tokens: 1024 - checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k' cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 20 + batch_size: 96 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -25,7 +27,6 @@ buffer: prompt_key: 'question' response_key: 'answer' rollout_args: - repeat_times: 8 temperature: 1.0 logprobs: 0 default_workflow_type: 'math_workflow' @@ -35,26 +36,19 @@ buffer: storage_type: queue path: 'sqlite:///gsm8k.db' explorer: - engine_type: vllm_async - engine_num: 2 + eval_interval: 10 runner_num: 32 - tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - use_ray: false - backend: 'nccl' - max_pending_requests: 32 - max_waiting_steps: 4 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 synchronizer: sync_method: 'checkpoint' - sync_iteration_interval: 10 + sync_interval: 10 trainer: trainer_type: 'verl' trainer_config_path: examples/async_gsm8k/verl_config.yaml - sft_warmup_steps: 0 # Set to integer to enable sft warmup -monitor: - cache_root_dir: "" - project: "Trinity-RFT-gsm8k" - name: "async-qwen2.5-1.5B-gsm8k" diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml index 79a50337d2..d259cb7ca0 100644 --- 
a/examples/async_gsm8k/trainer.yaml +++ b/examples/async_gsm8k/trainer.yaml @@ -1,18 +1,20 @@ +project: "Trinity-RFT-gsm8k" +name: "async-qwen2.5-1.5B-gsm8k" mode: train -global_config: - total_epochs: 20 - batch_size: 96 - eval_interval: 10 +checkpoint_root_dir: /PATH/TO/CHECKPOINT +algorithm: algorithm_type: grpo + repeat_times: 8 model: - model_path: /PATH/TO/MODEL/ + model_path: /PATH/TO/MODEL max_prompt_tokens: 256 max_response_tokens: 1024 - checkpoint_path: "" cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 20 + batch_size: 96 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -24,7 +26,7 @@ buffer: prompt_key: 'question' response_key: 'answer' rollout_args: - repeat_times: 8 + n: 8 temperature: 1.0 logprobs: 0 default_workflow_type: 'math_workflow' @@ -34,26 +36,19 @@ buffer: storage_type: queue path: 'sqlite:///gsm8k.db' explorer: - engine_type: vllm_async - engine_num: 2 + eval_interval: 10 runner_num: 32 - tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - use_ray: false - backend: 'nccl' - max_pending_requests: 32 - max_waiting_steps: 4 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 synchronizer: sync_method: 'checkpoint' - sync_iteration_interval: 10 + sync_interval: 10 trainer: trainer_type: 'verl' trainer_config_path: examples/async_gsm8k/verl_config.yaml - sft_warmup_steps: 0 # Set to integer to enable sft warmup -monitor: - cache_root_dir: "" - project: "Trinity-RFT-gsm8k" - name: "async-qwen2.5-1.5B-gsm8k" diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml index 825788a792..4648d6e493 100644 --- a/examples/dpo_humanlike/dpo.yaml +++ b/examples/dpo_humanlike/dpo.yaml @@ -1,17 +1,19 @@ +project: "dpo_example" +name: "trinity_dpo" mode: train -global_config: - total_epochs: 20 - batch_size: 32 # NOTE +algorithm: algorithm_type: dpo +checkpoint_root_dir: /PATH/TO/CHECKPOINT model: - model_path: '/PATH/TO/MODEL/CHECKPOINT/' # NOTE + model_path: '/PATH/TO/MODEL' max_prompt_tokens: 1792 max_response_tokens: 256 - checkpoint_path: 'checkpoints/trinity_dpo' cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 20 + batch_size: 32 max_retry_times: 3 max_retry_interval: 1 trainer_input: @@ -32,7 +34,5 @@ trainer: trainer_type: 'verl' trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml' save_interval: 30 -monitor: - cache_root_dir: "" - project: "dpo_example" - name: "trinity_dpo" + actor_use_kl_loss: True + actor_kl_loss_coef: 0.1 diff --git a/examples/dpo_humanlike/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml index 65d373b4fa..09327877f9 100644 --- a/examples/dpo_humanlike/train_dpo.yaml +++ b/examples/dpo_humanlike/train_dpo.yaml @@ -32,7 +32,7 @@ actor_rollout_ref: grad_clip: 1.0 clip_ratio: 0.2 entropy_coeff: 0.001 - use_kl_loss: True # NOTE + use_kl_loss: True kl_loss_coef: 0.1 # NOTE: beta for DPO kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml index 08a9fd8e42..8323ef8591 100644 --- a/examples/grpo_alfworld/alfworld.yaml +++ b/examples/grpo_alfworld/alfworld.yaml @@ -1,16 +1,19 @@ -global_config: - total_epochs: 20 - batch_size: 4 +project: "ALFWORLD" +name: "ALFWORLD_RFT" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ALFWORLD_RFT/ +algorithm: algorithm_type: grpo + repeat_times: 8 model: - model_path: '/PATH/TO/MODEL/CHECKPOINT/' + 
model_path: /PATH/TO/MODEL/ max_prompt_tokens: 4096 max_response_tokens: 16384 - checkpoint_path: 'checkpoints/ALFWORLD_RFT' cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 20 + batch_size: 4 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -21,7 +24,6 @@ buffer: format: prompt_key: 'game_file' rollout_args: - repeat_times: 8 temperature: 1.0 logprobs: 0 default_workflow_type: 'alfworld_workflow' @@ -31,20 +33,17 @@ buffer: storage_type: queue path: 'sqlite:///alfworld.db' explorer: - engine_type: vllm_async - engine_num: 2 runner_num: 32 - tensor_parallel_size: 2 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - use_ray: false - backend: 'nccl' - max_pending_requests: 32 - max_waiting_steps: 4 - gpu_memory_utilization: 0.7 - enable_chunked_prefill: true + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 2 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + gpu_memory_utilization: 0.7 + enable_chunked_prefill: true synchronizer: sync_method: 'nccl' sync_interval: 8 @@ -53,7 +52,3 @@ trainer: trainer_type: 'verl' trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml' save_interval: 10 -monitor: - cache_root_dir: "" - project: "ALFWORLD" - name: "ALFWORLD_RFT" diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index 71630adaac..de2c1d2d9e 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -1,3 +1,9 @@ +project: "Trinity-RFT-gsm8k" +name: "qwen2.5-1.5B-gsm8k" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: grpo + repeat_times: 8 data_processor: # basic info source_data_path: 'openai/gsm8k' @@ -11,20 +17,17 @@ data_processor: clean_strategy: 'iterative' # db related db_url: '' -global_config: - total_epochs: 1 - batch_size: 96 - eval_interval: 50 - algorithm_type: grpo + model: model_path: '/PATH/TO/MODEL/' max_prompt_tokens: 256 max_response_tokens: 1024 - checkpoint_path: "" cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 1 + batch_size: 96 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -38,7 +41,7 @@ buffer: prompt_key: 'question' response_key: 'answer' rollout_args: - repeat_times: 8 + n: 8 temperature: 1.0 logprobs: 0 eval_tasksets: @@ -56,25 +59,22 @@ buffer: name: gsm8k_buffer storage_type: queue path: 'sqlite:///gsm8k.db' + # sft_warmup_steps: 0 # sft_warmup_dataset: # Uncomment these to enable sft warmup # name: warmup_data # storage_type: file # path: '/PATH/TO/WARMUP_DATA/' - # kwargs: - # prompt_type: plaintext explorer: - engine_type: vllm_async - engine_num: 2 + eval_interval: 50 runner_num: 32 - tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - use_ray: false - backend: 'nccl' - max_pending_requests: 32 - max_waiting_steps: 4 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 synchronizer: sync_method: 'nccl' sync_interval: 2 @@ -82,10 +82,4 @@ synchronizer: trainer: trainer_type: 'verl' trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml' - sft_warmup_steps: 0 # Set to integer to enable sft warmup save_interval: 100 - # get_exp_strategy: 'LFU' -monitor: - cache_root_dir: "" - project: "Trinity-RFT-gsm8k" - name: "qwen2.5-1.5B-gsm8k" diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml index a1527d9255..daa1fd5fb7 100644 --- 
a/examples/grpo_math/math.yaml +++ b/examples/grpo_math/math.yaml @@ -1,17 +1,17 @@ -global_config: - total_epochs: 20 - batch_size: 288 - eval_interval: 10 - algorithm_type: grpo +project: grpo_math +name: grpo_math_example +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ model: model_path: /PATH/TO/MODEL/ - max_prompt_tokens: 1024 - max_response_tokens: 3072 - checkpoint_path: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: grpo + repeat_times: 8 cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 20 + batch_size: 288 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -23,7 +23,7 @@ buffer: prompt_key: 'question' response_key: 'gt_answer' rollout_args: - repeat_times: 8 + n: 8 temperature: 1.0 logprobs: 0 default_workflow_type: 'math_workflow' @@ -33,18 +33,18 @@ buffer: storage_type: queue path: 'sqlite:///math.db' explorer: - engine_type: vllm_async - engine_num: 2 + eval_interval: 10 runner_num: 32 - tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - use_ray: false - backend: 'nccl' - max_pending_requests: 32 - max_waiting_steps: 4 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + max_prompt_tokens: 1024 + max_response_tokens: 3072 + seed: 42 synchronizer: sync_method: 'nccl' sync_interval: 2 @@ -52,9 +52,4 @@ synchronizer: trainer: trainer_type: 'verl' trainer_config_path: 'examples/grpo_math/train_math.yaml' - sft_warmup_steps: 0 # Set to integer to enable sft warmup save_interval: 100 -monitor: - cache_root_dir: "" - project: grpo_math - name: grpo_math_example diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml index 43f7a13af3..799b2df800 100644 --- a/examples/grpo_sciworld/sciworld.yaml +++ b/examples/grpo_sciworld/sciworld.yaml @@ -1,16 +1,19 @@ -global_config: - total_epochs: 20 - batch_size: 4 +project: "sciworld" +name: "sciworld_RFT" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: algorithm_type: grpo + repeat_times: 8 model: - model_path: '/PATH/TO/MODEL/CHECKPOINT/' + model_path: /PATH/TO/MODEL/ max_prompt_tokens: 4096 max_response_tokens: 16384 - checkpoint_path: 'checkpoints/sciworld_RFT' cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 20 + batch_size: 4 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -21,7 +24,6 @@ buffer: format: prompt_key: 'task_desc' rollout_args: - repeat_times: 8 temperature: 1.0 logprobs: 0 default_workflow_type: 'sciworld_workflow' @@ -31,20 +33,17 @@ buffer: storage_type: queue path: 'sqlite:///sciworld.db' explorer: - engine_type: vllm_async - engine_num: 2 runner_num: 32 - tensor_parallel_size: 2 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - use_ray: false - backend: 'nccl' - max_pending_requests: 32 - max_waiting_steps: 4 - gpu_memory_utilization: 0.7 - enable_chunked_prefill: true + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 2 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + gpu_memory_utilization: 0.7 + enable_chunked_prefill: true synchronizer: sync_method: 'nccl' sync_interval: 8 @@ -53,7 +52,3 @@ trainer: trainer_type: 'verl' trainer_config_path: 'examples/grpo_sciworld/train_sciworld.yaml' save_interval: 10 -monitor: - cache_root_dir: "" - project: "sciworld" - name: "sciworld_RFT" diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml index 
0ae7563db2..a5ea2a310e 100644 --- a/examples/grpo_webshop/webshop.yaml +++ b/examples/grpo_webshop/webshop.yaml @@ -1,16 +1,19 @@ -global_config: - total_epochs: 20 - batch_size: 4 +project: "WEBSHOP" +name: "WEBSHOP_RFT" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: algorithm_type: grpo + repeat_times: 8 model: - model_path: '/PATH/TO/MODEL/CHECKPOINT/' + model_path: /PATH/TO/MODEL/ max_prompt_tokens: 4096 max_response_tokens: 16384 - checkpoint_path: 'checkpoints/WEBSHOP_RFT' cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 20 + batch_size: 4 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -21,7 +24,6 @@ buffer: format: prompt_key: 'task_id' rollout_args: - repeat_times: 8 temperature: 1.0 logprobs: 0 default_workflow_type: 'webshop_workflow' @@ -31,20 +33,17 @@ buffer: storage_type: queue path: 'sqlite:///webshop.db' explorer: - engine_type: vllm_async - engine_num: 2 runner_num: 8 - tensor_parallel_size: 2 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - use_ray: false - backend: 'nccl' - max_pending_requests: 32 - max_waiting_steps: 4 - gpu_memory_utilization: 0.7 - enable_chunked_prefill: true + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 2 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + gpu_memory_utilization: 0.7 + enable_chunked_prefill: true synchronizer: sync_method: 'nccl' sync_interval: 8 @@ -53,7 +52,3 @@ trainer: trainer_type: 'verl' trainer_config_path: 'examples/grpo_webshop/train_webshop.yaml' save_interval: 10 -monitor: - cache_root_dir: "" - project: "WEBSHOP" - name: "WEBSHOP_RFT" diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml index 7cc502eff2..388c823dd5 100644 --- a/examples/opmd_gsm8k/opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml @@ -1,16 +1,19 @@ -global_config: - total_epochs: 1 - batch_size: 96 +project: "Trinity-RFT-gsm8k-test-opmd" +name: "opmd_test" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: algorithm_type: opmd + repeat_times: 8 model: - model_path: '{path to models}/Qwen2.5-1.5B-Inst' - max_prompt_tokens: 256 - max_response_tokens: 1024 - checkpoint_path: '{path to checkpoints}/test-opmd-gsm8k/qwen2.5-1.5B-gsm8k-opmd-kl_0.001-entropy_0-tau_4-beta1_0.0-beta2_0.95-lr_2e-6-sync10' + model_path: /PATH/TO/MODEL/ + max_prompt_tokens: 4096 + max_response_tokens: 16384 cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 1 + batch_size: 96 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -22,7 +25,7 @@ buffer: prompt_key: 'question' response_key: 'answer' rollout_args: - repeat_times: 8 + n: 8 temperature: 1.0 logprobs: 0 default_workflow_type: 'math_workflow' @@ -32,18 +35,15 @@ buffer: storage_type: queue path: 'sqlite:///gsm8k_opmd.db' explorer: - engine_type: vllm_async - engine_num: 2 runner_num: 32 - tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - use_ray: false - backend: 'nccl' - max_pending_requests: 32 - max_waiting_steps: 4 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 synchronizer: sync_method: 'nccl' sync_interval: 10 @@ -51,9 +51,4 @@ synchronizer: trainer: trainer_type: 'verl' trainer_config_path: 'examples/opmd_gsm8k/train_opmd_gsm8k.yaml' - sft_warmup_steps: 0 save_interval: 100 -monitor: - cache_root_dir: "" - project: "Trinity-RFT-gsm8k-test-opmd" - 
name: "qwen2.5-1.5B-gsm8k-opmd-kl_0.001-entropy_0-tau_4-beta1_0.0-beta2_0.95-lr_2e-6-sync10" diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml index c428a167b4..47b33fe136 100644 --- a/examples/ppo_countdown/countdown.yaml +++ b/examples/ppo_countdown/countdown.yaml @@ -1,17 +1,19 @@ -global_config: - total_epochs: 20 - batch_size: 96 - eval_interval: 1000 +project: "Trinity-RFT-countdown" +name: "qwen2.5-1.5B-countdown" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: algorithm_type: ppo + repeat_times: 5 model: model_path: '/PATH/TO/MODEL/CHECKPOINT/' max_prompt_tokens: 256 max_response_tokens: 1024 - checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown' cluster: node_num: 1 gpu_per_node: 8 buffer: + total_epochs: 20 + batch_size: 96 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -23,7 +25,6 @@ buffer: prompt_key: 'question' response_key: 'answer' rollout_args: - repeat_times: 5 temperature: 1.0 logprobs: 0 default_workflow_type: 'math_workflow' @@ -34,18 +35,16 @@ buffer: storage_type: queue path: 'sqlite:///countdown.db' explorer: - engine_type: vllm_async - engine_num: 2 + eval_interval: 100 runner_num: 32 - tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - use_ray: false - backend: 'nccl' - max_pending_requests: 32 - max_waiting_steps: 4 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 synchronizer: sync_method: 'nccl' sync_interval: 10 @@ -53,9 +52,4 @@ synchronizer: trainer: trainer_type: 'verl' trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' - sft_warmup_steps: 0 save_interval: 100 -monitor: - cache_root_dir: "" - project: "Trinity-RFT-countdown" - name: "qwen2.5-1.5B-countdown" diff --git a/tests/common/config_test.py b/tests/common/config_test.py index 001a592c42..35b9a4f9c7 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -4,19 +4,33 @@ import unittest from tests.tools import get_template_config -from trinity.common.config import load_config +from trinity.common.config import InferenceModelConfig, load_config class TestConfig(unittest.TestCase): def test_load_default_config(self): config = get_template_config() + config.buffer.batch_size = 8 + config.algorithm.repeat_times = 10 + config.model.model_path = "Qwen/Qwen3-1.7B" + config.cluster.gpu_per_node = 8 + config.cluster.node_num = 2 + config.explorer.rollout_model.engine_num = 2 + config.explorer.rollout_model.tensor_parallel_size = 2 + config.explorer.auxiliary_models.append( + InferenceModelConfig(model_path="Qwen/Qwen3-32B", tensor_parallel_size=4, engine_num=1), + ) config.check_and_update() self.assertIsNotNone(config.trainer.trainer_config) - self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 2) + self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 8) self.assertEqual(config.trainer.trainer_config.trainer.nnodes, 1) - self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.monitor.project) - self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.monitor.name) - self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.monitor.project) + self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.project) + self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.name) + self.assertEqual( + 
config.buffer.explorer_input.taskset.rollout_args.n, config.algorithm.repeat_times + ) + self.assertEqual(config.model.model_path, config.model.critic_model_path) + self.assertEqual(config.model.model_path, config.explorer.rollout_model.model_path) self.assertEqual( config.trainer.trainer_config.trainer.save_freq, config.synchronizer.sync_interval, diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index f33dc06903..0146eb075c 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -85,9 +85,9 @@ def get_model_path() -> str: class BaseTestModelWrapper: def test_generate(self): prompts = ["Hello, world!", "Hello, my name is"] - repeat_times = self.config.buffer.explorer_input.taskset.rollout_args.repeat_times - results = self.model_wrapper.generate(prompts, n=repeat_times, temperature=1.0) - self.assertEqual(len(results), len(prompts) * repeat_times) + n = self.config.algorithm.repeat_times + results = self.model_wrapper.generate(prompts, n=n, temperature=1.0) + self.assertEqual(len(results), len(prompts) * n) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What's the weather like today?"}, @@ -97,8 +97,8 @@ def test_generate(self): }, {"role": "user", "content": "OK, thanks!"}, ] - results = self.model_wrapper.chat(messages, n=repeat_times, temperature=1.0) - self.assertEqual(len(results), repeat_times) + results = self.model_wrapper.chat(messages, n=n, temperature=1.0) + self.assertEqual(len(results), n) for result in results: input_logprobs = result.logprobs[: result.prompt_length] output_logprobs = result.logprobs[result.prompt_length :] @@ -133,13 +133,15 @@ def test_generate(self): class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase): def setUp(self): self.config = get_template_config() + self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.engine_type = "vllm" - self.config.explorer.tensor_parallel_size = 1 - self.config.explorer.engine_num = 2 - self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2 - self.config.explorer.use_v1 = False - self.config.explorer.chat_template = CHAT_TEMPLATE + self.config.explorer.rollout_model.engine_type = "vllm" + self.config.explorer.rollout_model.tensor_parallel_size = 1 + self.config.explorer.rollout_model.engine_num = 2 + self.config.explorer.rollout_model.use_v1 = False + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE + self.config.algorithm.repeat_times = 2 + self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm") @@ -147,13 +149,15 @@ def setUp(self): class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase): def setUp(self): self.config = get_template_config() + self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.engine_type = "vllm_async" - self.config.explorer.engine_num = 2 - self.config.explorer.tensor_parallel_size = 1 - self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2 - self.config.explorer.use_v1 = False - self.config.explorer.chat_template = CHAT_TEMPLATE + self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.explorer.rollout_model.engine_num = 2 + self.config.explorer.rollout_model.tensor_parallel_size = 1 + self.config.explorer.rollout_model.use_v1 = False + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE + 
self.config.algorithm.repeat_times = 2 + self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") @@ -161,12 +165,14 @@ def setUp(self): class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase): def setUp(self): self.config = get_template_config() + self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.engine_type = "vllm_async" - self.config.explorer.engine_num = 2 - self.config.explorer.tensor_parallel_size = 2 - self.config.explorer.use_v1 = False - self.config.explorer.chat_template = CHAT_TEMPLATE + self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.explorer.rollout_model.engine_num = 2 + self.config.explorer.rollout_model.tensor_parallel_size = 2 + self.config.explorer.rollout_model.use_v1 = False + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE + self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") @@ -174,13 +180,15 @@ def setUp(self): class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase): def setUp(self): self.config = get_template_config() + self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.engine_type = "vllm_async" - self.config.explorer.engine_num = 2 - self.config.explorer.tensor_parallel_size = 2 - self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2 - self.config.explorer.use_v1 = True - self.config.explorer.chat_template = CHAT_TEMPLATE + self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.explorer.rollout_model.engine_num = 2 + self.config.explorer.rollout_model.tensor_parallel_size = 2 + self.config.explorer.rollout_model.use_v1 = True + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE + self.config.algorithm.repeat_times = 2 + self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") @@ -188,12 +196,14 @@ def setUp(self): class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase): def setUp(self): self.config = get_template_config() + self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.engine_type = "vllm_async" - self.config.explorer.engine_num = 2 - self.config.explorer.tensor_parallel_size = 1 - self.config.explorer.use_v1 = True - self.config.explorer.chat_template = CHAT_TEMPLATE + self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.explorer.rollout_model.engine_num = 2 + self.config.explorer.rollout_model.tensor_parallel_size = 1 + self.config.explorer.rollout_model.use_v1 = True + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE + self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") @@ -201,13 +211,15 @@ def setUp(self): class TestAPIServer(RayUnittestBase): def setUp(self): self.config = get_template_config() + self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.engine_type = "vllm_async" - self.config.explorer.engine_num = 1 - self.config.explorer.tensor_parallel_size = 1 - 
self.config.explorer.use_v1 = True - self.config.explorer.chat_template = CHAT_TEMPLATE - self.config.explorer.enable_openai_api = True + self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.explorer.rollout_model.engine_num = 1 + self.config.explorer.rollout_model.tensor_parallel_size = 1 + self.config.explorer.rollout_model.use_v1 = True + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE + self.config.explorer.rollout_model.enable_openai_api = True + self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index f4570571a1..74b5d400e5 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -18,16 +18,16 @@ class BaseExplorerCase(RayUnittestBase): def setUp(self): self.config = get_template_config() - self.config.global_config.total_epochs = 2 - self.config.global_config.batch_size = 4 + self.config.buffer.total_epochs = 2 + self.config.buffer.batch_size = 4 self.config.model.model_path = get_model_path() - self.config.explorer.engine_type = "vllm_async" - self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2 + self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.algorithm.repeat_times = 2 self.config.monitor.monitor_type = MonitorType.TENSORBOARD - self.config.monitor.project = "Trinity-unittest" - self.config.model.checkpoint_path = get_checkpoint_path() + self.config.project = "Trinity-unittest" + self.config.checkpoint_root_dir = get_checkpoint_path() self.config.synchronizer.sync_interval = 2 - self.config.global_config.eval_interval = 4 + self.config.explorer.eval_interval = 4 @abstractmethod def test_explorer(self): @@ -40,11 +40,11 @@ def test_explorer(self): self.config.buffer.explorer_input.eval_tasksets.append( get_unittest_dataset_config("countdown", "test") ) - self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}" - self.config.explorer.use_v1 = True + self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}" + self.config.explorer.rollout_model.use_v1 = True self.config.check_and_update() explore(self.config) - parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard")) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") self.assertTrue(len(rollout_metrics) > 0) eval_metrics = parser.metric_list("eval") @@ -56,11 +56,11 @@ def test_explorer(self): class TestExplorerCountdownNoEval(BaseExplorerCase): def test_explorer(self): self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") - self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}" - self.config.explorer.use_v1 = False + self.config.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}" + self.config.explorer.rollout_model.use_v1 = False self.config.check_and_update() explore(self.config) - parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard")) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") self.assertTrue(len(rollout_metrics) > 0) eval_metrics = parser.metric_list("eval") diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 
3c00733e54..8cce2f9e85 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -113,6 +113,7 @@ def test_gsm8k_workflow(self) -> None: MockResponse(" balabalabala 99 \n 36 "), MockResponse(" 36.0 "), MockResponse("Kim's total points are 6 + 30 = 36 "), + MockResponse(" balalaba 35.00 "), ] taskset_config = get_unittest_dataset_config("countdown") task = Task( @@ -131,3 +132,21 @@ def test_gsm8k_workflow(self) -> None: self.assertEqual(experiences[0].reward, 1.1) self.assertEqual(experiences[1].reward, 0.9) self.assertEqual(experiences[2].reward, 0.9) + self.assertEqual(experiences[3].reward, 0.1) + task_new = Task( + workflow=MathWorkflow, + format_args=taskset_config.format, + rollout_args=taskset_config.rollout_args, + is_eval=False, + raw_task={ + taskset_config.format.prompt_key: "", + taskset_config.format.response_key: r"35", + }, + ) + workflow.reset(task_new) + workflow_new = task_new.to_workflow(model=model) + experiences = workflow_new.run() + self.assertEqual(experiences[0].reward, 0.1) + self.assertEqual(experiences[1].reward, -0.1) + self.assertEqual(experiences[2].reward, -0.1) + self.assertEqual(experiences[3].reward, 1.1) diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 0cf3ba3cf9..09b6f9ca0d 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -1,17 +1,20 @@ +project: unittest +name: test mode: both -global_config: - total_epochs: 1 - batch_size: 4 - eval_interval: 1000 +checkpoint_root_dir: '' +algorithm: + algorithm_type: ppo + repeat_times: 1 model: model_path: '' max_prompt_tokens: 2048 max_response_tokens: 2048 - checkpoint_path: '' cluster: # 2 for explorer, 2 for trainer node_num: 1 gpu_per_node: 4 buffer: + total_epochs: 1 + batch_size: 4 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -23,25 +26,21 @@ buffer: default_workflow_type: '' default_reward_fn_type: '' explorer: - engine_type: vllm_async - engine_num: 2 + eval_interval: 100 runner_num: 4 - tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 - backend: nccl - use_ray: false - use_v1: true + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + use_v1: true trainer: trainer_type: verl trainer_config_path: tests/template/verl_config.yaml - sft_warmup_steps: 0 save_interval: 100 -monitor: - project: unittest - name: test synchronizer: sync_method: checkpoint sync_interval: 10 diff --git a/tests/test_data/template.yaml b/tests/test_data/template.yaml index 21fe407842..018bc6ccea 100644 --- a/tests/test_data/template.yaml +++ b/tests/test_data/template.yaml @@ -1,13 +1,8 @@ -global_config: - batch_size: 32 -model: - max_prompt_tokens: 2048 - max_response_tokens: 2048 - checkpoint_path: '' cluster: node_num: 1 gpu_per_node: 8 buffer: + batch_size: 32 max_retry_times: 3 max_retry_interval: 1 explorer_input: @@ -18,11 +13,12 @@ buffer: default_workflow_type: '' default_reward_fn_type: '' explorer: - engine_type: vllm - engine_num: 2 runner_num: 8 - tensor_parallel_size: 2 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 + rollout_model: + engine_type: vllm + engine_num: 2 + tensor_parallel_size: 2 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 diff --git a/tests/tools.py b/tests/tools.py index a650638a0d..2e34438d66 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -52,7 +52,7 @@ def 
get_unittest_dataset_config( response_key="answer", ), rollout_args=GenerationConfig( - repeat_times=1, + n=1, temperature=1.0, logprobs=0, ), diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index b0ee7cc089..ac73e46c8d 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -22,20 +22,19 @@ class BaseTrainerCase(RayUnittestBase): def setUp(self): ray.init(ignore_reinit_error=True) self.config = get_template_config() - self.config.global_config.total_epochs = 2 - self.config.global_config.batch_size = 4 + self.config.buffer.total_epochs = 2 + self.config.buffer.batch_size = 4 self.config.model.model_path = get_model_path() - self.config.explorer.engine_type = "vllm_async" - self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 3 - self.config.explorer.use_v1 = False - self.config.monitor.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}" + self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.algorithm.repeat_times = 3 + self.config.explorer.rollout_model.use_v1 = False + self.config.project = "Trainer-unittest" + self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}" self.config.monitor.monitor_type = MonitorType.TENSORBOARD - self.config.model.checkpoint_path = os.path.join( - get_checkpoint_path(), f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}" - ) + self.config.checkpoint_root_dir = get_checkpoint_path() self.config.synchronizer.sync_interval = 2 self.config.synchronizer.sync_method = SyncMethod.NCCL - self.config.global_config.eval_interval = 4 + self.config.explorer.eval_interval = 4 @abstractmethod def test_trainer(self): @@ -58,7 +57,7 @@ def test_trainer(self): self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 2 both(self.config) - parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard")) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) rollout_metrics = parser.metric_list("rollout") self.assertTrue(len(rollout_metrics) > 0) self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) @@ -76,12 +75,12 @@ def test_trainer(self): from trinity.common.models.utils import get_checkpoint_dir_with_step_num checkpoint_step_4 = get_checkpoint_dir_with_step_num( - checkpoint_root_path=self.config.model.checkpoint_path, + checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, step_num=4, ) checkpoint_step_8 = get_checkpoint_dir_with_step_num( - checkpoint_root_path=self.config.model.checkpoint_path, + checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, step_num=8, ) @@ -92,10 +91,10 @@ def test_trainer(self): # test bench mode self.config.mode = "bench" self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT - self.config.global_config.eval_on_latest_ckp = False + self.config.explorer.eval_on_latest_checkpoint = False self.config.check_and_update() bench(self.config) - parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard")) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) countdown_metrics = parser.metric_list("eval/countdown") copy_countdown_metrics = parser.metric_list("eval/copy_countdown") self.assertTrue(len(countdown_metrics) > 0) @@ -109,4 +108,4 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed - 
shutil.rmtree(self.config.model.checkpoint_path) + shutil.rmtree(self.config.checkpoint_job_dir) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index ce36299e84..9dfe4df8ee 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -1,6 +1,9 @@ """Launch the trainer""" import argparse +import os import sys +from pathlib import Path +from pprint import pprint import ray @@ -46,17 +49,18 @@ def train(config: Config) -> None: trainer = Trainer.remote(config) ray.get(trainer.prepare.remote()) - if config.trainer.sft_warmup_steps > 0: + if config.buffer.trainer_input.sft_warmup_steps > 0: while True: train_continue, train_step_num = ray.get( trainer.train_one_period.remote(AlgorithmType.SFT) ) - logger.info(f"SFT warmup step {train_step_num} finished.") + if train_step_num <= config.buffer.trainer_input.sft_warmup_steps: + logger.info(f"SFT warmup step {train_step_num} finished.") if not train_continue: logger.info("SFT warmup finished.") break - algo_type = config.global_config.algorithm_type + algo_type = config.algorithm.algorithm_type try: ray.get(trainer.train.remote(algo_type)) logger.info("Train finished.") @@ -89,18 +93,19 @@ def both(config: Config) -> None: # sync weight before training start ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - if config.trainer.sft_warmup_steps > 0: + if config.buffer.trainer_input.sft_warmup_steps > 0: while True: train_continue, train_step_num = ray.get( trainer.train_one_period.remote(AlgorithmType.SFT) ) - logger.info(f"SFT warmup step {train_step_num} finished.") + if train_step_num <= config.buffer.trainer_input.sft_warmup_steps: + logger.info(f"SFT warmup step {train_step_num} finished.") if not train_continue: logger.info("SFT warmup finished.") break ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - algo_type = config.global_config.algorithm_type + algo_type = config.algorithm.algorithm_type while True: try: ref_explore = explorer.explore_one_period.remote() @@ -123,7 +128,7 @@ def both(config: Config) -> None: logger.error(e) logger.error("Training stopped due to exception.") raise e - if explore_step_num % config.global_config.eval_interval == 0: + if explore_step_num % config.explorer.eval_interval == 0: try: ray.get(explorer.eval.remote()) logger.info("Evaluation finished.") @@ -155,18 +160,23 @@ def activate_data_module(data_workflow_url: str, config_path: str): def run(config_path: str, dlc: bool = False): config = load_config(config_path) config.check_and_update() + pprint(config) # try to activate data module data_processor_config = config.data_processor if data_processor_config.data_workflow_url and ( data_processor_config.dj_config_path or data_processor_config.dj_process_desc ): activate_data_module(data_processor_config.data_workflow_url, config_path) - ray_namespace = f"{config.monitor.project}-{config.monitor.name}" + ray_namespace = f"{config.project}-{config.name}" if dlc: from trinity.utils.dlc_utils import setup_ray_cluster setup_ray_cluster(namespace=ray_namespace) else: + from trinity.utils.dlc_utils import is_running + + if not is_running: + raise RuntimeError("Ray is not running, please start it by `ray start --head`.") ray.init(namespace=ray_namespace, ignore_reinit_error=True) if config.mode == "explore": explore(config) @@ -181,10 +191,13 @@ def run(config_path: str, dlc: bool = False): def studio(port: int = 8501): from streamlit.web import cli as stcli + current_dir = Path(__file__).resolve().parent.parent + config_manager_path = 
os.path.join(current_dir, "manager", "config_manager.py") + sys.argv = [ "streamlit", "run", - "trinity/manager/config_manager.py", + config_manager_path, "--server.port", str(port), "--server.fileWatcherType", diff --git a/trinity/common/config.py b/trinity/common/config.py index c168bb48b1..b2703d4d2d 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -10,6 +10,7 @@ AlgorithmType, MonitorType, PromptType, + ReadStrategy, StorageType, SyncMethod, TaskType, @@ -52,13 +53,15 @@ class FormatConfig: @dataclass class GenerationConfig: - # repeat each task for `repeat_times` times (for GPRO-like algorithms) - repeat_times: int = 1 - temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 logprobs: int = 0 # vLLM return `logprobs + 1` elements + # repeat each task for `n` times (for GPRO-like algorithms) + # this field will be automatically set to `algorithm.repeat_times` in + # `buffer.explorer_input.taskset.rollout_args` + # ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args` + n: int = 1 @dataclass @@ -67,7 +70,6 @@ class StorageConfig: name: str = "" storage_type: StorageType = StorageType.FILE - algorithm_type: Optional[AlgorithmType] = None # automatically set path: Optional[str] = None # used for StorageType.FILE @@ -76,13 +78,20 @@ class StorageConfig: format: FormatConfig = field(default_factory=FormatConfig) index: int = 0 - # used for algorithm_type is None - task_type: TaskType = TaskType.EXPLORE # automatically set + # used for rollout tasks default_workflow_type: Optional[str] = None default_reward_fn_type: Optional[str] = None - total_epochs: int = 1 # automatically set rollout_args: GenerationConfig = field(default_factory=GenerationConfig) + # ! DO NOT SET, automatically set from algorithm.algorithm_type + algorithm_type: Optional[AlgorithmType] = None + + # ! DO NOT SET, automatically set from buffer.total_epochs + total_epochs: int = 1 # automatically set + + # ! DO NOT SET, automatically set corresponding to train/eval + task_type: TaskType = TaskType.EXPLORE + @dataclass class DataProcessorConfig: @@ -113,47 +122,60 @@ class DataProcessorConfig: max_retry_interval: int = 1 -@dataclass -class GlobalConfig: - # downstream loading related - total_epochs: int = 1 - batch_size: int = 1 - eval_interval: int = 100 - eval_on_latest_ckp: bool = True - algorithm_type: AlgorithmType = AlgorithmType.PPO - - @dataclass class ModelConfig: - # TODO: add more # source model path model_path: str = "" critic_model_path: str = "" - max_prompt_tokens: int = 2048 - max_response_tokens: int = 2048 - # The checkpoint directory, contains a latest dir link and multiple checkpoint dirs. - checkpoint_path: str = "" - # for models support both thinking and non-thinking mode, e.g., Qwen3 - enable_thinking: bool = False + max_prompt_tokens: Optional[int] = None + max_response_tokens: Optional[int] = None @dataclass class InferenceModelConfig: - # TODO: support setting engine_num + # ! 
DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path model_path: str = "" + + # support `vllm` or `vllm_async`, + engine_type: str = "vllm_async" + engine_num: int = 1 tensor_parallel_size: int = 1 use_v1: bool = True - max_prompt_tokens: int = 2048 - max_response_tokens: int = 2048 - enable_thinking: bool = False enforce_eager: bool = True enable_prefix_caching: bool = False enable_chunked_prefill: bool = False gpu_memory_utilization: float = 0.9 dtype: str = "bfloat16" seed: int = 42 + + # if not set, use `model.max_prompt_tokens` + max_prompt_tokens: Optional[int] = None + # if not set, use `model.max_response_tokens` + max_response_tokens: Optional[int] = None + + # override chat template in model chat_template: Optional[str] = None - bundle_indices: str = "" # DO NOT SET this field + + # For Qwen3 + enable_thinking: bool = False + + # For OpenAI API + enable_openai_api: bool = False + + # ! DO NOT SET + bundle_indices: str = "" + + +@dataclass +class AlgorithmConfig: + """Config for algorithm.""" + + algorithm_type: AlgorithmType = AlgorithmType.PPO + # for GRPO-like algorithms, repeat each task for `repeat_times` times + repeat_times: int = 1 + gamma: float = 1.0 + lam: float = 1.0 + # TODO: add more algorithm params here @dataclass @@ -183,65 +205,56 @@ class TrainerInput: experience_buffer: Optional[StorageConfig] = None sft_warmup_dataset: Optional[StorageConfig] = None + read_experience_strategy: Optional[ReadStrategy] = None + sft_warmup_steps: int = 0 @dataclass class BufferConfig: - """Config for experience buffer.""" + """Config for buffer.""" - read_batch_size: int = 32 - max_retry_times: int = 3 - max_retry_interval: int = 1 - tokenizer_path: Optional[str] = None # automatically set - pad_token_id: Optional[int] = None # automatically set + batch_size: int = 1 + total_epochs: int = 1 + # for explorer explorer_input: ExplorerInput = field(default_factory=ExplorerInput) explorer_output: Optional[StorageConfig] = None # currently do not set + + # for trainer trainer_input: TrainerInput = field(default_factory=TrainerInput) + # for storage connection + max_retry_times: int = 3 + max_retry_interval: int = 1 + + # ! DO NOT SET FOLLOWING FIELDS + read_batch_size: int = 1 # automatically set + tokenizer_path: Optional[str] = None # automatically set + pad_token_id: Optional[int] = None # automatically set + @dataclass class ExplorerConfig: """Config for explorer.""" - # rollout engine type, `vllm` or `vllm_async` - engine_type: str = "vllm_async" - - # number of rollout engines - engine_num: int = 1 - + # for workflow runner # number of workflow runners. # For sync engine (vllm), it should be equal to `engine_num`. # For async engine (vllm_async), it can be larger than `engine_num`, e.g. 
16 * `engine_num` runner_num: int = 1 - - # for rollout tokneize - chat_template: Optional[str] = None - - # TODO: move vllm rollout model related args into - # `explorer.rollout_model: InferenceModelConfig` - tensor_parallel_size: int = 1 - enable_prefix_caching: bool = False - enforce_eager: bool = True - dtype: str = "bfloat16" - seed: int = 42 - backend: str = "nccl" - use_ray: bool = False - gpu_memory_utilization: float = 0.9 - enable_chunked_prefill: bool = False - use_v1: bool = True - enable_openai_api: bool = False - bundle_indices: str = "" # DO NOT SET this field - - # for workflow runner - max_pending_requests: int = 5 - max_waiting_steps: int = 1 max_timeout: int = 900 # wait each task for 15 minutes max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout + # for inference models + # for rollout model + rollout_model: InferenceModelConfig = field(default_factory=InferenceModelConfig) # for other models used in the custom workflows auxiliary_models: List[InferenceModelConfig] = field(default_factory=list) + # for evaluation + eval_interval: int = 100 + eval_on_latest_checkpoint: bool = False + @dataclass class TrainerConfig: @@ -249,28 +262,24 @@ class TrainerConfig: trainer_config_path: str = "" save_interval: int = 0 enable_preview: bool = True # enable rollout preview in wandb - trainer_config: Any = field(default_factory=dict) - # train algorithm - get_exp_strategy: Optional[str] = None + # trainer configs + actor_use_kl_loss: bool = False + actor_kl_loss_coef: float = 0.001 + actor_entropy_coef: float = 0.001 + actor_grad_clip: float = 1.0 + actor_clip_ratio: float = 0.2 + # TODO: extract more train-related params from underlying trainer engine - # warmup config - sft_warmup_steps: int = 0 - sft_warmup_iteration: Optional[int] = None # deprecated + trainer_config: Any = field(default_factory=dict) @dataclass class MonitorConfig: - # TODO: add more - project: str = "trinity" - name: str = "rft" + # TODO: support multiple monitors (List[MonitorType]) monitor_type: MonitorType = MonitorType.WANDB - - # ! DO NOT SET - # the root directory for cache and meta files, automatically generated - cache_root_dir: Optional[str] = None - # directory path for current job, automatically generated - job_dir: Optional[str] = None + # ! DO NOT SET, automatically generated as checkpoint_job_dir/monitor + cache_dir: str = "" @dataclass @@ -281,15 +290,13 @@ class SynchronizerConfig: sync_method: SyncMethod = SyncMethod.NCCL # sync weights every `sync_interval` steps sync_interval: int = 1 - # `sync_iteration_interval` is deprecated, use `sync_interval` instead - sync_iteration_interval: Optional[int] = None + # waiting for `sync_timeout` seconds before timeout in `nccl` method sync_timeout: int = 1200 # wait for the lastest checkpoint to be ready wait_for_checkpoint: bool = False - master_address: Optional[str] = None - master_port: Optional[int] = None + + # ! DO NOT SET, automatically calculated explorer_world_size: Optional[int] = None - backend: str = "nccl" @dataclass @@ -297,8 +304,15 @@ class Config: """Global Configuration""" mode: str = "both" # `explore`, `train`, `both` or `bench` + project: str = "Trinity-RFT" + name: str = "rft" + # the root dir for checkpoints + checkpoint_root_dir: str = "" + # ! 
DO NOT SET, automatically generated as `checkpoint_root_dir/project/name`
+    checkpoint_job_dir: str = ""
+
+    algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
     data_processor: DataProcessorConfig = field(default_factory=DataProcessorConfig)
-    global_config: GlobalConfig = field(default_factory=GlobalConfig)
     model: ModelConfig = field(default_factory=ModelConfig)
     cluster: ClusterConfig = field(default_factory=ClusterConfig)
     buffer: BufferConfig = field(default_factory=BufferConfig)
@@ -313,19 +327,7 @@ def save(self, config_path: str) -> None:
             OmegaConf.save(self, f)

     def _check_deprecated(self) -> None:
-        if self.synchronizer.sync_iteration_interval is not None:
-            logger.warning(
-                f"`synchronizer.sync_iteration_interval` is deprecated, please use `synchronizer.sync_interval` instead. "
-                f"And `synchronizer.sync_interval` will set to {self.synchronizer.sync_iteration_interval} instead."
-            )
-            self.synchronizer.sync_interval = self.synchronizer.sync_iteration_interval
-
-        if self.trainer.sft_warmup_iteration is not None:
-            logger.warning(
-                f"`trainer.sft_warmup_iteration` is deprecated, please use `trainer.sft_warmup_steps` instead. "
-                f"And `trainer.sft_warmup_steps` will be set to {self.trainer.sft_warmup_iteration} instead."
-            )
-            self.trainer.sft_warmup_steps = self.trainer.sft_warmup_iteration
+        pass

     def _check_interval(self) -> None:
         assert self.synchronizer.sync_interval > 0
@@ -333,25 +335,25 @@ def _check_interval(self) -> None:
         # check eval_interval
         if (
             self.mode != "bench"
-            and self.global_config.algorithm_type != AlgorithmType.DPO
-            and self.global_config.eval_interval % self.synchronizer.sync_interval != 0
+            and self.algorithm.algorithm_type != AlgorithmType.DPO
+            and self.explorer.eval_interval % self.synchronizer.sync_interval != 0
         ):
-            self.global_config.eval_interval = (
-                max(self.global_config.eval_interval // self.synchronizer.sync_interval, 1)
+            self.explorer.eval_interval = (
+                max(self.explorer.eval_interval // self.synchronizer.sync_interval, 1)
             ) * self.synchronizer.sync_interval
             logger.warning(
-                f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.global_config.eval_interval}."
+                f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}."
             )
         # check save_interval
         if (
             self.mode != "bench"
-            and self.global_config.algorithm_type != AlgorithmType.DPO
+            and self.algorithm.algorithm_type != AlgorithmType.DPO
             and self.synchronizer.sync_method == SyncMethod.CHECKPOINT
         ):
             if self.trainer.save_interval != self.synchronizer.sync_interval:
                 logger.warning(
-                    f"When `global_config.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
+                    f"When `algorithm.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
                     f"`trainer.save_interval` will be set to "
                     f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`."
                 )
@@ -365,8 +367,13 @@ def _check_buffer(self) -> None:  # noqa: C901
             )
         if not self.buffer.explorer_input.taskset.name:
             self.buffer.explorer_input.taskset.name = "taskset"
+        self.buffer.explorer_input.taskset.rollout_args.n = self.algorithm.repeat_times
+        logger.info(
+            "`buffer.explorer_input.taskset.rollout_args.n` is set to `algorithm.repeat_times`"
+            f" (={self.algorithm.repeat_times})."
+ ) self.buffer.explorer_input.taskset.task_type = TaskType.EXPLORE - self.buffer.explorer_input.taskset.total_epochs = self.global_config.total_epochs + self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs if self.buffer.explorer_input.taskset.default_workflow_type is None: self.buffer.explorer_input.taskset.default_workflow_type = ( self.buffer.explorer_input.default_workflow_type @@ -414,41 +421,38 @@ def _check_buffer(self) -> None: # noqa: C901 f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}" ) elif self.mode == "train": # TODO: to be check - if self.global_config.algorithm_type.is_dpo(): + if self.algorithm.algorithm_type.is_dpo(): if ( self.buffer.trainer_input.experience_buffer is None or not self.buffer.trainer_input.experience_buffer.path ): raise ValueError( - "`buffer.trainer_input.experience_buffer.path` is required when `global_config.algorithm_type == AlgorithmType.DPO`" + "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == AlgorithmType.DPO`" ) if self.buffer.trainer_input.experience_buffer is not None: self.buffer.trainer_input.experience_buffer.algorithm_type = ( - self.global_config.algorithm_type + self.algorithm.algorithm_type ) # set buffer.explorer_output if self.buffer.explorer_output is None: self.buffer.explorer_output = self.buffer.trainer_input.experience_buffer else: - self.buffer.explorer_output.algorithm_type = self.global_config.algorithm_type + self.buffer.explorer_output.algorithm_type = self.algorithm.algorithm_type # check trainer_input.sft_warmup_dataset if ( - self.trainer.sft_warmup_steps > 0 + self.buffer.trainer_input.sft_warmup_steps > 0 and self.buffer.trainer_input.sft_warmup_dataset is None ): raise ValueError( - "buffer.trainer_input.sft_warmup_dataset is required when trainer.sft_warmup_steps > 0" + "buffer.trainer_input.sft_warmup_dataset is required when buffer.trainer_input.sft_warmup_steps > 0" ) if self.buffer.trainer_input.sft_warmup_dataset is not None: self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT # set read_batch_size / pad_token_id / tokenizer_path - self.buffer.read_batch_size = ( - self.global_config.batch_size - * self.buffer.explorer_input.taskset.rollout_args.repeat_times - ) + self.buffer.read_batch_size = self.buffer.batch_size * self.algorithm.repeat_times if self.buffer.pad_token_id is None: from transformers import AutoTokenizer @@ -468,24 +472,38 @@ def check_and_update(self) -> None: # noqa: C901 # check mode if self.mode not in ["explore", "train", "both", "bench"]: raise ValueError(f"Invalid mode: {self.mode}") - if self.global_config.algorithm_type == AlgorithmType.DPO and self.mode == "both": + if self.algorithm.algorithm_type == AlgorithmType.DPO and self.mode == "both": raise ValueError("DPO does not support `both` mode") - # check model path - if not os.path.isabs(self.model.checkpoint_path): - self.model.checkpoint_path = os.path.join(os.getcwd(), self.model.checkpoint_path) + # prepare for the checkpoint directory + if not os.path.isabs(self.checkpoint_root_dir): + self.checkpoint_root_dir = os.path.join(os.getcwd(), self.checkpoint_root_dir) + # create a job dir at checkpoint_root_dir/project/name + self.checkpoint_job_dir = os.path.join(self.checkpoint_root_dir, self.project, self.name) + os.makedirs(self.checkpoint_job_dir, exist_ok=True) + + # check and update model path + if self.explorer is not None: + self.explorer.rollout_model.model_path = 
self.model.model_path if not self.model.critic_model_path: self.model.critic_model_path = self.model.model_path # check explorer - if self.explorer.engine_type != "vllm_asyc" and self.explorer.enable_openai_api: + if ( + self.explorer.rollout_model.engine_type != "vllm_async" + and self.explorer.rollout_model.enable_openai_api + ): raise ValueError("OpenAI API server only support `vllm_async` engine.") + if self.explorer.rollout_model.max_prompt_tokens is None: + self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens + if self.explorer.rollout_model.max_response_tokens is None: + self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens # check synchronizer self.synchronizer.explorer_world_size = ( - self.explorer.engine_num * self.explorer.tensor_parallel_size + self.explorer.rollout_model.engine_num + * self.explorer.rollout_model.tensor_parallel_size ) - self.synchronizer.backend = self.explorer.backend if ( self.mode in ["train", "explore", "bench"] and self.synchronizer.sync_method != SyncMethod.CHECKPOINT @@ -495,30 +513,27 @@ def check_and_update(self) -> None: # noqa: C901 f"`{self.mode}` mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." ) if ( - self.global_config.algorithm_type == AlgorithmType.DPO + self.algorithm.algorithm_type == AlgorithmType.DPO and self.synchronizer.sync_method != SyncMethod.CHECKPOINT ): self.synchronizer.sync_method = SyncMethod.CHECKPOINT logger.warning( "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." ) + if self.algorithm.algorithm_type == AlgorithmType.DPO and self.algorithm.repeat_times != 2: + self.algorithm.repeat_times = 2 + logger.warning("DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2.") self._check_interval() - # check monitor - if not self.monitor.cache_root_dir: - # create a cache dir in /.cache - self.monitor.cache_root_dir = os.path.join(self.model.checkpoint_path, ".cache") - # create a job dir in /.cache// - self.monitor.job_dir = os.path.join( - self.monitor.cache_root_dir, self.monitor.project, self.monitor.name - ) + # create a job dir in /monitor + self.monitor.cache_dir = os.path.join(self.checkpoint_job_dir, "monitor") try: - os.makedirs(self.monitor.job_dir, exist_ok=True) + os.makedirs(self.monitor.cache_dir, exist_ok=True) except Exception: logger.warning( - "Failed to create cache dir, please check " - f"your checkpoint path: {self.model.checkpoint_path}" + f"Failed to create monitor dir {self.monitor.cache_dir}, please check " + f"your checkpoint directory: {self.checkpoint_root_dir}" ) # check buffer diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 00faa71165..7324ff2a47 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -1,7 +1,7 @@ from collections import defaultdict from typing import List, Tuple -from trinity.common.config import Config, InferenceModelConfig +from trinity.common.config import Config from trinity.common.models.model import InferenceModel from trinity.utils.log import get_logger @@ -46,27 +46,34 @@ def create_inference_models( from trinity.common.models.vllm_async_model import vLLMAysncRolloutModel from trinity.common.models.vllm_model import vLLMRolloutModel - engine_num = config.explorer.engine_num - tensor_parallel_size = config.explorer.tensor_parallel_size - is_multi_process = config.explorer.tensor_parallel_size > 1 + engine_num = 
config.explorer.rollout_model.engine_num + tensor_parallel_size = config.explorer.rollout_model.tensor_parallel_size - if config.explorer.enable_openai_api and config.explorer.engine_type != "vllm_async": + if ( + config.explorer.rollout_model.enable_openai_api + and config.explorer.rollout_model.engine_type != "vllm_async" + ): raise ValueError("OpenAI API is only supported for vllm_async engine") rollout_engines = [] - if config.explorer.engine_type == "vllm": + if config.explorer.rollout_model.engine_type == "vllm": engine_cls = vLLMRolloutModel - elif config.explorer.engine_type == "vllm_async": + elif config.explorer.rollout_model.engine_type == "vllm_async": engine_cls = vLLMAysncRolloutModel else: - raise ValueError(f"Unknown engine type: {config.explorer.engine_type}") + raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}") main_bundles = [{"GPU": 1, "CPU": 1} for _ in range(engine_num * tensor_parallel_size)] auxiliary_bundles = [ {"GPU": 1, "CPU": 1} for _ in range( - sum([model.tensor_parallel_size for model in config.explorer.auxiliary_models]) + sum( + [ + model.engine_num * model.tensor_parallel_size + for model in config.explorer.auxiliary_models + ] + ) ) ] pg = placement_group(main_bundles + auxiliary_bundles, strategy="PACK") @@ -84,57 +91,47 @@ def create_inference_models( allocator = _BundleAllocator(node_bundle_map) # create rollout models - for _ in range(config.explorer.engine_num): - bundles_for_engine = allocator.allocate(config.explorer.tensor_parallel_size) - model_config = InferenceModelConfig( - model_path=config.model.model_path, - tensor_parallel_size=config.explorer.tensor_parallel_size, - use_v1=config.explorer.use_v1, - max_prompt_tokens=config.model.max_prompt_tokens, - max_response_tokens=config.model.max_response_tokens, - enforce_eager=config.explorer.enforce_eager, - enable_prefix_caching=config.explorer.enable_prefix_caching, - enable_chunked_prefill=config.explorer.enable_chunked_prefill, - enable_thinking=config.model.enable_thinking, - gpu_memory_utilization=config.explorer.gpu_memory_utilization, - dtype=config.explorer.dtype, - seed=config.explorer.seed, - chat_template=config.explorer.chat_template, - bundle_indices=",".join([str(bid) for bid in bundles_for_engine]), + for _ in range(config.explorer.rollout_model.engine_num): + bundles_for_engine = allocator.allocate(config.explorer.rollout_model.tensor_parallel_size) + config.explorer.rollout_model.bundle_indices = ",".join( + [str(bid) for bid in bundles_for_engine] ) rollout_engines.append( ray.remote(engine_cls) .options( num_cpus=0, - num_gpus=0 if is_multi_process else 1, + num_gpus=0 if config.explorer.rollout_model.tensor_parallel_size > 1 else 1, scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, placement_group_bundle_index=bundles_for_engine[0], ), ) .remote( - config=model_config, + config=config.explorer.rollout_model, ) ) - if config.explorer.enable_openai_api: + if config.explorer.rollout_model.enable_openai_api: for engine in rollout_engines: engine.run_api_server.remote() # create auxiliary models for model_config in config.explorer.auxiliary_models: - bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size) - auxiliary_engines.append( - ray.remote(vLLMAysncRolloutModel) - .options( - num_cpus=0, - num_gpus=0 if is_multi_process else 1, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=bundles_for_engine[0], - ), + for _ in 
range(model_config.engine_num): + bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size) + model_config.enable_openai_api = True + model_config.engine_type = "vllm_async" + auxiliary_engines.append( + ray.remote(vLLMAysncRolloutModel) + .options( + num_cpus=0, + num_gpus=0 if model_config.tensor_parallel_size > 1 else 1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundles_for_engine[0], + ), + ) + .remote(config=model_config) ) - .remote(config=model_config) - ) # all auxiliary engines run api server for engine in auxiliary_engines: engine.run_api_server.remote() diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index ae5c4db9c1..02ea52ec58 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -64,6 +64,9 @@ def __init__( ) self.enable_thinking = config.enable_thinking self.request_id = 0 + max_model_len = None + if config.max_prompt_tokens is not None and config.max_response_tokens is not None: + max_model_len = config.max_prompt_tokens + config.max_response_tokens engine_args = vllm.AsyncEngineArgs( model=config.model_path, enforce_eager=config.enforce_eager, @@ -71,7 +74,7 @@ def __init__( tensor_parallel_size=config.tensor_parallel_size, seed=config.seed, distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"), - max_model_len=config.max_prompt_tokens + config.max_response_tokens, + max_model_len=max_model_len, enable_prefix_caching=config.enable_prefix_caching, dtype=config.dtype, trust_remote_code=True, @@ -314,7 +317,7 @@ async def run_api_server(self): ) async def has_api_server(self) -> bool: - return self.api_server_host is not None and self.api_server_port is not None + return self.config.enable_openai_api async def api_server_ready(self) -> Optional[str]: """Check if the OpenAI API server is ready. 
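Both vLLM wrappers (the async model above and the sync model below) now pin `max_model_len` only when both token limits are known, after `check_and_update` has filled any unset rollout limits from `model.max_prompt_tokens` / `model.max_response_tokens`. The following is a minimal standalone sketch of that resolution logic, not part of the patch; `TokenLimits` and `resolve_max_model_len` are illustrative names only:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TokenLimits:
    max_prompt_tokens: Optional[int] = None
    max_response_tokens: Optional[int] = None


def resolve_max_model_len(model: TokenLimits, rollout: TokenLimits) -> Optional[int]:
    """Sketch of the assumed fallback chain: rollout-level limits default to the
    global model limits, and max_model_len is only fixed when both are known."""
    prompt = (
        rollout.max_prompt_tokens
        if rollout.max_prompt_tokens is not None
        else model.max_prompt_tokens
    )
    response = (
        rollout.max_response_tokens
        if rollout.max_response_tokens is not None
        else model.max_response_tokens
    )
    if prompt is not None and response is not None:
        return prompt + response
    return None  # leave max_model_len unset so vLLM falls back to its own default


# Only the global model limits are set; the rollout model inherits them.
assert resolve_max_model_len(TokenLimits(256, 1024), TokenLimits()) == 1280
# Nothing is set anywhere, so no explicit max_model_len is passed to vLLM.
assert resolve_max_model_len(TokenLimits(), TokenLimits()) is None
```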
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 32ab98fe8a..9459cd7511 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -51,6 +51,9 @@ def __init__(self, config: InferenceModelConfig): include_stop_str_in_output=False, logprobs=0, ) + max_model_len = None + if config.max_prompt_tokens is not None and config.max_response_tokens is not None: + max_model_len = config.max_prompt_tokens + config.max_response_tokens self.llm = LLM( # TODO: check checkpoint path model=config.model_path, @@ -59,7 +62,7 @@ def __init__(self, config: InferenceModelConfig): tensor_parallel_size=config.tensor_parallel_size, seed=config.seed, distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"), - max_model_len=config.max_prompt_tokens + config.max_response_tokens, + max_model_len=max_model_len, enable_prefix_caching=config.enable_prefix_caching, dtype=config.dtype, trust_remote_code=True, @@ -149,7 +152,7 @@ def generate(self, prompts: List[str], **kwargs) -> List: Example: - >>> # config.buffer.explorer_input.taskset.rollout_args.repeat_times == 2 or kwargs["repeat_times"] == 2 + >>> # config.algorithm.repeat_times == 2 or kwargs["n"] == 2 >>> >>> prompts = [ >>> "Hello, world!", diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index b988161723..dd896a23f1 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -270,9 +270,19 @@ class veRLConfig: def synchronize_config(self, config: Config) -> None: """Synchronize config.""" - rollout_gpu_num = config.explorer.tensor_parallel_size * config.explorer.engine_num + sum( - [model.tensor_parallel_size for model in config.explorer.auxiliary_models] - ) + if config.mode != "train": + rollout_gpu_num = ( + config.explorer.rollout_model.tensor_parallel_size + * config.explorer.rollout_model.engine_num + + sum( + [ + model.tensor_parallel_size * model.engine_num + for model in config.explorer.auxiliary_models + ] + ) + ) + else: + rollout_gpu_num = 0 rollout_node_num = rollout_gpu_num // config.cluster.gpu_per_node self.trainer.nnodes = config.cluster.node_num - rollout_node_num self.actor_rollout_ref.model.path = config.model.model_path @@ -291,35 +301,42 @@ def synchronize_config(self, config: Config) -> None: self.actor_rollout_ref.synchronizer = config.synchronizer self.buffer = config.buffer world_size = self.trainer.nnodes * self.trainer.n_gpus_per_node - if config.global_config.batch_size % world_size != 0: + if config.buffer.batch_size % world_size != 0: raise ValueError( - f"batch_size ({config.global_config.batch_size}) must be divisible by ({world_size})" + f"batch_size ({config.buffer.batch_size}) must be divisible by ({world_size})" ) # TODO: use dynamic read_batch_size to support multi-round scenarios # Get the experiences of one explore step - self.trainer.project_name = config.monitor.project - self.trainer.experiment_name = config.monitor.name - self.data.train_batch_size = config.global_config.batch_size - self.trainer.default_local_dir = config.model.checkpoint_path - self.trainer.sft_warmup_steps = config.trainer.sft_warmup_steps - self.actor_rollout_ref.actor.ppo_mini_batch_size = config.global_config.batch_size + self.trainer.project_name = config.project + self.trainer.experiment_name = config.name + self.data.train_batch_size = config.buffer.batch_size + self.trainer.default_local_dir = config.checkpoint_job_dir + self.trainer.sft_warmup_steps = 
config.buffer.trainer_input.sft_warmup_steps + self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.batch_size self.actor_rollout_ref.rollout.temperature = ( config.buffer.explorer_input.taskset.rollout_args.temperature ) - self.actor_rollout_ref.rollout.n = ( - config.buffer.explorer_input.taskset.rollout_args.repeat_times - ) - self.critic.ppo_mini_batch_size = config.global_config.batch_size + self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times + self.critic.ppo_mini_batch_size = config.buffer.batch_size self.critic.rollout_n = self.actor_rollout_ref.rollout.n - self.actor_rollout_ref.actor.algorithm_type = config.global_config.algorithm_type - if config.global_config.algorithm_type == AlgorithmType.PPO: + self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type + if config.algorithm.algorithm_type == AlgorithmType.PPO: logger.info("Using GAE `adv_estimator` for PPO") self.algorithm.adv_estimator = AdvantageEstimator.GAE.value - elif config.global_config.algorithm_type == AlgorithmType.GRPO: + elif config.algorithm.algorithm_type == AlgorithmType.GRPO: logger.info("Using GRPO `adv_estimator` for GRPO") self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value + # copy trainer related config from global config + self.algorithm.gamma = config.algorithm.gamma + self.algorithm.lam = config.algorithm.lam + self.actor_rollout_ref.actor.use_kl_loss = config.trainer.actor_use_kl_loss + self.actor_rollout_ref.actor.kl_loss_coef = config.trainer.actor_kl_loss_coef + self.actor_rollout_ref.actor.entropy_coeff = config.trainer.actor_entropy_coef + self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip + self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio + if self.actor_rollout_ref.actor.algorithm_type.is_dpo(): # for DPO if not self.actor_rollout_ref.actor.use_kl_loss: self.actor_rollout_ref.actor.use_kl_loss = True diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 9b3b0d79d4..39171d5561 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -107,11 +107,11 @@ def __init__( task=task, ) self.task_desc = task.task_desc or "0" - self.repeat_times = task.rollout_args.repeat_times + self.repeat_times = task.rollout_args.n self.max_env_steps = 30 def get_model_response(self, messages): - responses = self.model.chat(messages, repeat_times=1) + responses = self.model.chat(messages, n=1) return responses def get_model_response_text(self, messages): diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index 60bc6c4d81..12e1e97146 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -70,11 +70,11 @@ def __init__( task=task, ) self.task_desc = task.task_desc or "0" - self.repeat_times = task.rollout_args.repeat_times + self.repeat_times = task.rollout_args.n self.max_env_steps = 30 # should be less than 100 def get_model_response(self, messages): - responses = self.model.chat(messages, repeat_times=1) + responses = self.model.chat(messages, n=1) return responses def get_model_response_text(self, messages): diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 0003217e35..59eb7d7c94 100644 --- 
a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -215,10 +215,10 @@ def resettable(self): def reset(self, task: Task): self.task_desc = task.task_desc or "0" - self.repeat_times = task.rollout_args.repeat_times + self.repeat_times = task.rollout_args.n def get_model_response(self, messages): - responses = self.model.chat(messages, repeat_times=1) + responses = self.model.chat(messages, n=1) return responses def get_model_response_text(self, messages): diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 603bd1ced4..1a0daadb2b 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -51,11 +51,13 @@ def to_workflow( auxiliary_models=auxiliary_models, ) + # Deprecated property, will be removed in the future @property def task_desc(self) -> Union[str, None]: prompt_key = self.format_args.prompt_key return self.raw_task[prompt_key] if prompt_key in self.raw_task else None # type: ignore + # Deprecated property, will be removed in the future @property def truth(self) -> Union[str, None]: response_key = self.format_args.response_key @@ -154,6 +156,7 @@ def __init__( super().__init__( model=model, task=task, + auxiliary_models=auxiliary_models, ) self.reset(task) @@ -177,7 +180,6 @@ def reset(self, task: Task): raise ValueError("`reward_fn` must be a subclass of `RewardFn`") # Rollout args rollout_args = asdict(task.rollout_args) - rollout_args["n"] = rollout_args["repeat_times"] self.rollout_args = rollout_args self.is_eval = task.is_eval @@ -231,7 +233,15 @@ def __init__( reasoning process here answer here . """ - super().__init__( - model=model, - task=task, - ) + super().__init__(model=model, task=task, auxiliary_models=auxiliary_models) + + def reset(self, task: Task): + if task.reward_fn is None: + task.reward_fn = MathRewardFn + if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None: + task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., + reasoning process here + answer here . 
+""" + # call the SimpleWorkflow.reset + super().reset(task) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index f2b2490e1d..f5f3466675 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -48,16 +48,14 @@ def __init__(self, config: Config): self.eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer)) self.runner_pool = self._init_runner_pool() self.monitor = Monitor( - project=self.config.monitor.project, - name=self.config.monitor.name, + project=self.config.project, + name=self.config.name, role="explorer", config=config, ) - self.max_pending_task_num = self.config.explorer.runner_num - self.max_waiting_steps = max(1, int(self.config.explorer.max_waiting_steps)) - self.batch_size = config.global_config.batch_size + self.batch_size = config.buffer.batch_size self.update_interval = ( - self.config.synchronizer.sync_interval * self.config.global_config.batch_size + self.config.synchronizer.sync_interval * self.config.buffer.batch_size ) self.use_checkpoint_weights_update = ( self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT @@ -79,7 +77,9 @@ def setup_weight_sync_group( ): # In checkpoint mode, we use explorer to store the model weights which has no rank base_offset = 0 if self.use_checkpoint_weights_update else 1 - world_size = len(self.models) * self.config.explorer.tensor_parallel_size + base_offset + world_size = ( + len(self.models) * self.config.explorer.rollout_model.tensor_parallel_size + base_offset + ) self.logger.info( f"Initialize process group for weight synchronization, " f"master_address={master_address}, master_port={master_port}, " @@ -90,10 +90,10 @@ def setup_weight_sync_group( model.init_process_group.remote( master_address=master_address, master_port=master_port, - rank_offset=i * self.config.explorer.tensor_parallel_size + base_offset, + rank_offset=i * self.config.explorer.rollout_model.tensor_parallel_size + + base_offset, world_size=world_size, group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, - backend=self.config.explorer.backend, timeout=self.config.synchronizer.sync_timeout, update_with_checkpoint=self.use_checkpoint_weights_update, ) @@ -102,14 +102,14 @@ def setup_weight_sync_group( ray.get(refs) def _init_runner_pool(self) -> RunnerPool: - if self.config.explorer.engine_type != "vllm_async": + if self.config.explorer.rollout_model.engine_type != "vllm_async": # sync model requires the same number of runners as the number of models - self.config.explorer.runner_num = self.config.explorer.engine_num + self.config.explorer.runner_num = self.config.explorer.rollout_model.engine_num self.logger.info( "Sync vLLM model requires the same number of runners as the number of models" ) - if self.config.explorer.runner_num < self.config.explorer.engine_num: - self.config.explorer.runner_num = self.config.explorer.engine_num + if self.config.explorer.runner_num < self.config.explorer.rollout_model.engine_num: + self.config.explorer.runner_num = self.config.explorer.rollout_model.engine_num self.logger.info( f"Number of Runners is less than number of models, set to {self.config.explorer.runner_num}" ) @@ -129,7 +129,7 @@ def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: # TODO: support more checkpoint types try: checkpoint_dir = get_checkpoint_dir_with_step_num( - checkpoint_root_path=self.config.model.checkpoint_path, + checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, step_num=step_num, ) @@ -162,7 +162,7 @@ def 
explore(self) -> None: if not explore_status: break self.sync_weight() - if explore_iter % self.config.global_config.eval_interval == 0: + if explore_iter % self.config.explorer.eval_interval == 0: self.eval() self.logger.info("Evaluation finished.") self.logger.info("Explorer finished.") @@ -177,9 +177,7 @@ def explore_one_period(self) -> Tuple[bool, int]: explore_status: whether there are more tasks to explore. explore_step_num: the number of explore steps """ - task_num_per_period = ( - self.config.synchronizer.sync_interval * self.config.global_config.batch_size - ) + task_num_per_period = self.config.synchronizer.sync_interval * self.config.buffer.batch_size st = time.time() all_metrics = defaultdict(list) @@ -265,7 +263,7 @@ def wait(): def benchmark(self) -> bool: """Benchmark the model checkpoints.""" # benchmark on the latest checkpoint - if self.config.global_config.eval_on_latest_ckp: + if self.config.explorer.eval_on_latest_checkpoint: self._checkpoint_weights_update() self.eval() return True @@ -274,8 +272,8 @@ def benchmark(self) -> bool: all_ckp_steps = sorted( [ int(ckp.split("global_step_")[-1]) - for ckp in os.listdir(self.config.model.checkpoint_path) - if os.path.isdir(os.path.join(self.config.model.checkpoint_path, ckp)) + for ckp in os.listdir(self.config.checkpoint_job_dir) + if os.path.isdir(os.path.join(self.config.checkpoint_job_dir, ckp)) and ckp.startswith("global_step_") ] ) diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py index c7b3f39f1f..e58e87c124 100644 --- a/trinity/explorer/runner_pool.py +++ b/trinity/explorer/runner_pool.py @@ -43,7 +43,7 @@ def __init__(self, config: Config, models: List): self._pending_submits = [] # create new actors - self.engine_status = [0] * config.explorer.engine_num + self.engine_status = [0] * config.explorer.rollout_model.engine_num self._idle_actors = list() self.actor_to_engine_index = {} self._create_actors(config.explorer.runner_num) @@ -202,7 +202,10 @@ def get_next(self) -> Status: # TODO: balance the model self._return_actor( WorkflowRunner.remote( - self.config, self.models[random.randint(0, self.config.explorer.engine_num - 1)] + self.config, + self.models[ + random.randint(0, self.config.explorer.rollout_model.engine_num - 1) + ], ) ) return_status = Status( diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index f5a1c2dc6a..3b36423ff6 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -44,7 +44,7 @@ def __init__( self.model = model self.model_wrapper = ModelWrapper( model, - config.explorer.engine_type, + config.explorer.rollout_model.engine_type, ) self.auxiliary_models = [] if auxiliary_models is not None: diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index f56b8eb363..21b0e57348 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -117,12 +117,8 @@ def _init_default_config(self): "top_k": -1, "seed": 42, "logprobs": 0, - "backend": "nccl", - "use_ray": False, "gpu_memory_utilization": 0.9, "enable_chunked_prefill": False, - "max_pending_requests": 32, - "max_waiting_steps": 4, "max_timeout": 900, "explorer_max_retry_times": 2, # Synchronizer Configs @@ -176,7 +172,7 @@ def _init_default_config(self): "actor_use_uid": False, "actor_grad_clip": 1.0, "actor_clip_ratio": 0.2, - "actor_entropy_coeff": 0.001, + "actor_entropy_coef": 0.001, "_not_dpo_actor_use_kl_loss": True, "actor_use_kl_loss": True, "actor_kl_loss_coef": 0.001, @@ 
-624,33 +620,6 @@ def _check_engine_num_and_tp_size(self): "Please ensure that `engine_num * tensor_parallel_size` can be divided by `gpu_per_node` when `node_num > 1`." ) - def _set_repeat_times(self): # TODO - grouped_adv_algorithms = [ - AlgorithmType.GRPO.value, - AlgorithmType.OPMD.value, # TODO: may add rloo - ] - if st.session_state["algorithm_type"] in grouped_adv_algorithms: - min_repeat_times = 2 - st.session_state["repeat_times"] = st.session_state["_grouped_adv_repeat_times"] - else: - min_repeat_times = 1 - st.session_state["repeat_times"] = st.session_state["_not_grouped_adv_repeat_times"] - - def on_change(): - if st.session_state["algorithm_type"] in grouped_adv_algorithms: - st.session_state["_grouped_adv_repeat_times"] = st.session_state["repeat_times"] - else: - st.session_state["_not_grouped_adv_repeat_times"] = st.session_state["repeat_times"] - - st.number_input( - "Repeat Times", - key="repeat_times", - min_value=min_repeat_times, - help="`repeat_times` is used to set how many experiences each task can generate, " - "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", - on_change=on_change, - ) - def _set_sync_method(self): if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: st.session_state["sync_method"] = SyncMethod.CHECKPOINT.value @@ -693,18 +662,9 @@ def _set_sync_timeout(self): def _set_runner_num(self): st.number_input("Runner Num", key="runner_num", min_value=1) - def _set_max_pending_requests(self): - st.number_input("Max Pending Requests", key="max_pending_requests", min_value=1) - - def _set_max_waiting_steps(self): - st.number_input("Max Waiting Steps", key="max_waiting_steps", min_value=1) - def _set_dtype(self): st.selectbox("Dtype", ["float16", "bfloat16", "float32"], key="dtype") - def _set_backend(self): - st.selectbox("Backend", ["nccl"], key="backend") - def _set_temperature(self): st.number_input("Temperature", key="temperature", min_value=0.0, max_value=2.0) @@ -732,9 +692,6 @@ def _set_enable_prefix_caching(self): def _set_enforce_eager(self): st.checkbox("Enforce Eager", key="enforce_eager") - def _set_use_ray(self): - st.checkbox("Use Ray", key="use_ray") - def _set_gpu_memory_utilization(self): st.number_input( "GPU Memory Utilization", key="gpu_memory_utilization", min_value=0.0, max_value=1.0 @@ -830,6 +787,33 @@ def on_change(): def _set_ppo_epochs(self): st.number_input("PPO Epochs", key="ppo_epochs", min_value=1) + def _set_repeat_times(self): # TODO + grouped_adv_algorithms = [ + AlgorithmType.GRPO.value, + AlgorithmType.OPMD.value, # TODO: may add rloo + ] + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + min_repeat_times = 2 + st.session_state["repeat_times"] = st.session_state["_grouped_adv_repeat_times"] + else: + min_repeat_times = 1 + st.session_state["repeat_times"] = st.session_state["_not_grouped_adv_repeat_times"] + + def on_change(): + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + st.session_state["_grouped_adv_repeat_times"] = st.session_state["repeat_times"] + else: + st.session_state["_not_grouped_adv_repeat_times"] = st.session_state["repeat_times"] + + st.number_input( + "Repeat Times", + key="repeat_times", + min_value=min_repeat_times, + help="`repeat_times` is used to set how many experiences each task can generate, " + "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", + on_change=on_change, + ) + def _set_training_strategy(self): st.selectbox( "Training Strategy", @@ -978,10 +962,10 @@ def 
_set_actor_clip_ratio(self): max_value=1.0, ) - def _set_actor_entropy_coeff(self): + def _set_actor_entropy_coef(self): st.number_input( "Entropy Coeff", - key="actor_entropy_coeff", + key="actor_entropy_coef", min_value=0.0, max_value=1.0, format="%.1e", @@ -1241,18 +1225,16 @@ def _expert_explorer_part(self): ["runner_num", "temperature", "top_p", "top_k", "seed", "logprobs"] ) - self._set_configs_with_st_columns(["dtype", "backend", "gpu_memory_utilization"]) + self._set_configs_with_st_columns(["dtype", "gpu_memory_utilization"]) self._set_configs_with_st_columns( [ - "max_pending_requests", - "max_waiting_steps", "max_timeout", "explorer_max_retry_times", ] ) self._set_configs_with_st_columns( - ["enable_prefix_caching", "enforce_eager", "use_ray", "enable_chunked_prefill"] + ["enable_prefix_caching", "enforce_eager", "enable_chunked_prefill"] ) def _expert_trainer_part(self): @@ -1318,7 +1300,7 @@ def _expert_verl_trainer_part(self): ) self._set_configs_with_st_columns( - ["actor_grad_clip", "actor_clip_ratio", "actor_entropy_coeff"] + ["actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef"] ) self._set_actor_use_kl_loss() @@ -1423,7 +1405,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "ppo_max_token_len_per_gpu": ppo_max_token_len_per_gpu, "grad_clip": st.session_state["actor_grad_clip"], "clip_ratio": st.session_state["actor_clip_ratio"], - "entropy_coeff": st.session_state["actor_entropy_coeff"], + "entropy_coeff": st.session_state["actor_entropy_coef"], "use_kl_loss": st.session_state["actor_use_kl_loss"], "kl_loss_coef": st.session_state["actor_kl_loss_coef"], "kl_loss_type": st.session_state["actor_kl_loss_type"], @@ -1640,23 +1622,25 @@ def generate_config(self): if st.session_state.config_generated: config = { "mode": st.session_state["mode"], - "global_config": { - "total_epochs": st.session_state["total_epochs"], - "batch_size": st.session_state["train_batch_size"], - "eval_interval": st.session_state["eval_interval"], + "project": st.session_state["project"], + "name": st.session_state["name"], + "checkpoint_root_dir": st.session_state["checkpoint_path"], + "algorithm": { "algorithm_type": st.session_state["algorithm_type"], + "repeat_times": st.session_state["repeat_times"], }, "model": { "model_path": st.session_state["model_path"], "max_prompt_tokens": st.session_state["max_prompt_tokens"], "max_response_tokens": st.session_state["max_response_tokens"], - "checkpoint_path": st.session_state["checkpoint_path"], }, "cluster": { "node_num": st.session_state["node_num"], "gpu_per_node": st.session_state["gpu_per_node"], }, "buffer": { + "total_epochs": st.session_state["total_epochs"], + "batch_size": st.session_state["train_batch_size"], "max_retry_times": st.session_state["buffer_max_retry_times"], "max_retry_interval": st.session_state["max_retry_interval"], "explorer_input": { @@ -1671,7 +1655,7 @@ def generate_config(self): "response_key": st.session_state["taskset_response_key"], }, "rollout_args": { - "repeat_times": st.session_state["repeat_times"], + "n": st.session_state["repeat_times"], "temperature": st.session_state["temperature"], "top_p": st.session_state["top_p"], "top_k": st.session_state["top_k"], @@ -1690,9 +1674,11 @@ def generate_config(self): "storage_type": st.session_state["storage_type"], "path": experience_buffer_path, }, + "sft_warmup_steps": st.session_state["sft_warmup_steps"], }, }, "explorer": { + "eval_interval": st.session_state["eval_interval"], "engine_type": st.session_state["engine_type"], 
"engine_num": st.session_state["engine_num"], "runner_num": st.session_state["runner_num"], @@ -1702,13 +1688,9 @@ def generate_config(self): "enforce_eager": st.session_state["enforce_eager"], "dtype": st.session_state["dtype"], "seed": st.session_state["seed"], - "backend": st.session_state["backend"], - "use_ray": st.session_state["use_ray"], "gpu_memory_utilization": st.session_state["gpu_memory_utilization"], "enable_chunked_prefill": st.session_state["enable_chunked_prefill"], "use_v1": True, - "max_pending_requests": st.session_state["max_pending_requests"], - "max_waiting_steps": st.session_state["max_waiting_steps"], "max_timeout": st.session_state["max_timeout"], "max_retry_times": st.session_state["explorer_max_retry_times"], }, @@ -1720,12 +1702,9 @@ def generate_config(self): "trainer": { "trainer_type": st.session_state["trainer_type"], "trainer_config": trainer_config, - "sft_warmup_steps": st.session_state["sft_warmup_steps"], "save_interval": st.session_state["save_interval"], }, "monitor": { - "project": st.session_state["project"], - "name": st.session_state["exp_name"], "monitor_type": st.session_state["monitor_type"], }, } diff --git a/trinity/manager/manager.py b/trinity/manager/manager.py index 8d49c4cb1a..3c148cbe12 100644 --- a/trinity/manager/manager.py +++ b/trinity/manager/manager.py @@ -13,7 +13,7 @@ class CacheManager: """A Manager class for managing the cache dir.""" def __init__(self, config: Config, check_config: bool = False): - self.cache_dir = config.monitor.job_dir # type: ignore + self.cache_dir = config.monitor.cache_dir # type: ignore self.explorer_meta_path = os.path.join(self.cache_dir, "explorer_meta.json") # type: ignore self.trainer_meta_path = os.path.join(self.cache_dir, "trainer_meta.json") # type: ignore if check_config: diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index b98c47e729..36d23e7628 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -14,7 +14,7 @@ from trinity.buffer import get_buffer_reader from trinity.common.config import Config -from trinity.common.constants import AlgorithmType, ReadStrategy, SyncMethod +from trinity.common.constants import AlgorithmType, SyncMethod from trinity.common.experience import Experiences from trinity.utils.log import get_logger @@ -35,7 +35,7 @@ def __init__(self, config: Config) -> None: self.config.buffer.trainer_input.sft_warmup_dataset, # type: ignore self.config.buffer, ) - if self.config.trainer.sft_warmup_steps > 0 + if self.config.buffer.trainer_input.sft_warmup_steps > 0 else None ) self.engine = get_trainer_wrapper(config) @@ -74,8 +74,8 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool bool: Whether to continue training. 
""" self.engine.set_mode(algo_type) - if algo_type.is_rft() and self.config.trainer.get_exp_strategy: - strategy = ReadStrategy(self.config.trainer.get_exp_strategy) + if algo_type.is_rft() and self.config.buffer.trainer_input.read_experience_strategy: + strategy = self.config.buffer.trainer_input.read_experience_strategy else: strategy = None try: @@ -123,7 +123,7 @@ def flush_log(self, step: int) -> None: def shutdown(self) -> None: # if checkpoint not saved, save the last checkpoint step_num = self.engine.global_steps - 1 - path = os.path.join(self.config.model.checkpoint_path, f"global_step_{step_num}") + path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}") if not os.path.isdir(path) or len(os.listdir(path)) == 0: self.engine.save_checkpoint() self.engine.logger.close() diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 3cd1a53e13..26b640e871 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -585,7 +585,6 @@ def setup_weight_sync_group(self): master_address, master_port = self.get_availale_master_addr_port() world_size = self.config.synchronizer.explorer_world_size + 1 - backend = self.config.synchronizer.backend print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).") explorer = ray.get_actor("explorer") group_name = "rollout_weight_sync" @@ -600,7 +599,7 @@ def setup_weight_sync_group(self): timeout = self.config.synchronizer.sync_timeout self._model_update_group = init_process_group( - backend=backend, + backend="nccl", init_method=init_method, timeout=timeout, world_size=world_size, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 1f4f9ddfb8..090a5ff881 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -73,7 +73,6 @@ def __init__( global_config: Config, ): train_config = global_config.trainer - pprint(train_config.trainer_config) config = OmegaConf.structured(train_config.trainer_config) # download the checkpoint from hdfs local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) @@ -305,7 +304,7 @@ def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]: if self.sft_warmup_step_num == self.config.trainer.sft_warmup_steps: self.logger.log( data={"sft_warmup_steps": self.sft_warmup_step_num}, - step=self.global_steps, + step=self.global_steps - 1, ) with _timer("save_checkpoint", timing_raw): self._save_checkpoint() diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index f4c0db6372..3044c6dcc8 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -1,12 +1,13 @@ """Monitor""" import os -from typing import Any, List, Optional, Union +from typing import List, Optional, Union import numpy as np import pandas as pd import wandb from torch.utils.tensorboard import SummaryWriter +from trinity.common.config import Config from trinity.common.constants import MonitorType from trinity.utils.log import get_logger @@ -19,7 +20,7 @@ def __init__( project: str, name: str, role: str, - config: Any = None, + config: Config = None, # pass the global Config for recording ) -> None: if config.monitor.monitor_type == MonitorType.WANDB: self.logger = WandbLogger(project, name, role, config) @@ -59,8 +60,8 @@ def close(self) -> None: class TensorboardLogger: - def __init__(self, project: str, name: str, role: str, config: Any = None) -> None: - self.tensorboard_dir = os.path.join(config.monitor.job_dir, "tensorboard") + def 
__init__(self, project: str, name: str, role: str, config: Config = None) -> None:
+        self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard")
         os.makedirs(self.tensorboard_dir, exist_ok=True)
         self.logger = SummaryWriter(self.tensorboard_dir)
         self.console_logger = get_logger(__name__)
@@ -81,7 +82,7 @@ def __del__(self) -> None:
 class WandbLogger:
-    def __init__(self, project: str, name: str, role: str, config: Any = None) -> None:
+    def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
         self.logger = wandb.init(
             project=project,
             group=name,