diff --git a/docs/sphinx_doc/source/conf.py b/docs/sphinx_doc/source/conf.py index 605ea1002c..be97e1a131 100644 --- a/docs/sphinx_doc/source/conf.py +++ b/docs/sphinx_doc/source/conf.py @@ -40,7 +40,6 @@ templates_path = ["_templates"] exclude_patterns = ["build"] -autodoc_mock_imports = ["ray"] autodoc_default_options = { "members": True, diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md index 26451ab982..ab28f9802a 100644 --- a/docs/sphinx_doc/source/tutorial/example_dpo.md +++ b/docs/sphinx_doc/source/tutorial/example_dpo.md @@ -40,13 +40,13 @@ Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pa We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following: -We run the experiment in a train mode, as there is no Explorer. To enable this mode, we config `mode` to `train` and set `sync_method` to `offline`. The value of `sync_iteration_interval` can be set as same of the value of `save_freq`. +We run the experiment in train mode, as there is no Explorer. To enable this mode, we set `mode` to `train` and `sync_method` to `checkpoint`. The value of `sync_iteration_interval` can be set to the same value as `save_interval`. ```yaml # In dpo.yaml mode: train synchronizer: - sync_method: 'offline' + sync_method: 'checkpoint' buffer: train_dataset: storage_type: file @@ -63,7 +63,6 @@ trainer: # In train_dpo.yaml actor_rollout_ref: actor: - alg_type: dpo use_kl_loss: True kl_loss_coef: 0.1 # value of beta in DPO ``` diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md index 884efd6d21..ba0b013cdd 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md @@ -42,7 +42,7 @@ We run the experiment in a synchronous mode where the Explorer and Trainer opera ```yaml mode: both synchronizer: - sync_method: 'online' + sync_method: 'nccl' sync_iteration_interval: 2 ``` diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 0be1153630..de6f26252d 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -15,17 +15,6 @@ monitor: - `monitor.name`: The name of the experiment. It must be set manually. -## Monitor - -```yaml -monitor: - project: "Trinity-RFT-countdown" - name: "qwen2.5-1.5B-countdown" -``` - -- `monitor.project`: The project name. It must be set manually. -- `monitor.name`: The name of the experiment. It must be set manually. - ## Data @@ -131,8 +120,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 repeat_times: 5 @@ -150,8 +137,6 @@ explorer: - `explorer.enforce_eager`: Whether to enforce eager mode. Default is `True`. - `explorer.dtype`: The data type used in vLLM. Default is `bfloat16`. - `explorer.temperature`: The temperature used in vLLM. Default is `1.0`. -- `explorer.top_p`: The top-p used in vLLM. Default is `1.0`. -- `explorer.top_k`: The top-k used in vLLM. Default is `-1`. - `explorer.seed`: The seed used in vLLM. Default is `42`. - `explorer.logprobs`: The logprobs used in vLLM.
Default is `0`. - `explorer.repeat_times`: The number of times to repeat each task, used for GRPO-like algorithms. Default is `5`. @@ -164,12 +149,16 @@ explorer: ```yaml synchronizer: - sync_method: 'online' + sync_method: 'nccl' sync_iteration_interval: 10 + sync_timeout: 1200 ``` -- `synchronizer.sync_method`: The synchronization method, Support `online` and `offline`. Default is `online`. +- `synchronizer.sync_method`: The synchronization method between `trainer` and `explorer`. +Supported values are `nccl` and `checkpoint`: with `nccl`, the model weights in `explorer` are synchronized from `trainer` through NCCL; +with `checkpoint`, `explorer` loads the latest checkpoint saved by `trainer` and then updates its model weights from it. Default is `nccl`. - `synchronizer.sync_iteration_interval`: The interval between two synchronizations. Default is `10`. It should be set manually. +- `synchronizer.sync_timeout`: The timeout for model weight synchronization. Default is `1200`. ## Trainer @@ -180,6 +169,7 @@ trainer: trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' sft_warmup_iteration: 0 eval_interval: 1000 + save_interval: 100 ``` - `trainer.trainer_type`: The backend of the trainer, Only `verl` is supported. @@ -187,6 +177,7 @@ trainer: - `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually. - `trainer.sft_warmup_iteration`: The number of iterations to warm up the model. Default is `0`. - `trainer.eval_interval`: The interval between two evaluations. Default is `1000`. +- `trainer.save_interval`: The interval between saving two checkpoints. Default is `100`. ### veRL Trainer Configuration @@ -249,7 +240,6 @@ actor_rollout_ref: optimizer_offload: False fsdp_size: -1 # --- below: opmd --- - alg_type: ppo # ppo / opmd / pairwise_opmd tau: 0.000 # strength of regularization w.r.t. old / ref policy opmd_baseline: mean # mean / logavgexp, applicable to opmd use_uid: False # True / False, applicable to pairwise_opmd @@ -403,7 +393,6 @@ trainer: - `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. - `actor_rollout_ref.actor.kl_loss_type`: How to compute kl loss, optional value is `kl`, `abs`, `mse` or `low_var_kl`. - `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size. -- `actor_rollout_ref.actor.alg_type`: Used for OPMD, optional value is `ppo`, `opmd` or `pairwise_opmd`. - `actor_rollout_ref.actor.tau`: strength of regularization w.r.t. old / ref policy. - `actor_rollout_ref.actor.opmd_baseline`: mean / logavgexp, applicable to opmd. - `actor_rollout_ref.actor.use_uid`: True / False, applicable to pairwise_opmd. @@ -427,7 +416,6 @@ trainer: - `algorithm`: Training algorithm settings. - `trainer.balance_batch`: Whether to balance batch size between GPUs during training. -- `trainer.save_freq`: Frequency of saving checkpoints. - `trainer.resume_mode`: Resume mode for training. Support `disable`, `auto` and `resume_path`. - `trainer.resume_from_path`: Path to resume from. - `trainer.critic_warmup`: The number of iteration to train the critic model before actual policy learning.
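To make the two synchronization setups documented above concrete, here is a minimal sketch of the relevant fields. Field names follow the schema introduced in this PR; the values are illustrative and mirror the example configs below.

```yaml
# Synchronous RFT: Explorer and Trainer run in one job and sync weights over NCCL.
mode: both
synchronizer:
  sync_method: 'nccl'
  sync_iteration_interval: 10
  sync_timeout: 1200
trainer:
  save_interval: 100   # checkpoint saving is independent of weight sync in nccl mode
```

```yaml
# Trainer-only run (e.g. DPO, no Explorer): nccl is only allowed in `both` mode,
# so weights are shared via checkpoints instead.
mode: train
synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: 30
trainer:
  save_interval: 30    # kept equal to sync_iteration_interval in checkpoint mode
```

Note that, as of this patch, `check_and_update()` overwrites `trainer.save_interval` with `synchronizer.sync_iteration_interval` whenever `sync_method` is `checkpoint`, so the two values should be kept consistent.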
diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml index 29c49bb1a0..5d03d7130c 100644 --- a/examples/dpo_humanlike/dpo.yaml +++ b/examples/dpo_humanlike/dpo.yaml @@ -37,8 +37,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 repeat_times: 1 # NOTE @@ -47,12 +45,14 @@ explorer: max_pending_requests: 32 max_waiting_steps: 4 synchronizer: - sync_method: 'offline' + sync_method: 'checkpoint' sync_iteration_interval: 30 + sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: dpo trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml' + save_interval: 30 monitor: cache_root_dir: "" project: "dpo_example" diff --git a/examples/dpo_humanlike/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml index 4da9a7ddb5..65d373b4fa 100644 --- a/examples/dpo_humanlike/train_dpo.yaml +++ b/examples/dpo_humanlike/train_dpo.yaml @@ -23,7 +23,6 @@ actor_rollout_ref: enable_gradient_checkpointing: True use_remove_padding: False actor: - alg_type: dpo strategy: fsdp # This is for backward-compatibility ppo_mini_batch_size: 32 # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu @@ -170,7 +169,6 @@ trainer: val_generations_to_log_to_wandb: 0 nnodes: 1 n_gpus_per_node: 2 - save_freq: 30 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if test_freq: 5 diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml index 0258bfdbf0..1881f78d36 100644 --- a/examples/grpo_alfworld/alfworld.yaml +++ b/examples/grpo_alfworld/alfworld.yaml @@ -31,8 +31,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 repeat_times: 8 @@ -43,12 +41,14 @@ explorer: gpu_memory_utilization: 0.7 enable_chunked_prefil: true synchronizer: - sync_method: 'online' + sync_method: 'nccl' sync_iteration_interval: 8 + sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: ppo trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml' + save_interval: 10 monitor: cache_root_dir: "" project: "ALFWORLD" diff --git a/examples/grpo_alfworld/train_alfworld.yaml b/examples/grpo_alfworld/train_alfworld.yaml index 88f151bdcb..a210c39916 100644 --- a/examples/grpo_alfworld/train_alfworld.yaml +++ b/examples/grpo_alfworld/train_alfworld.yaml @@ -169,7 +169,6 @@ trainer: val_generations_to_log_to_wandb: 0 nnodes: 1 n_gpus_per_node: 2 - save_freq: 1 # auto: find the last ckpt to resume. 
If can't find, start from scratch resume_mode: auto # or auto or resume_path if test_freq: 100 diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml index fd6a9b5c44..9dd620c0d7 100644 --- a/examples/grpo_gsm8k/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -51,8 +51,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 repeat_times: 8 @@ -61,14 +59,16 @@ explorer: max_pending_requests: 32 max_waiting_steps: 4 synchronizer: - sync_method: 'online' + sync_method: 'nccl' sync_iteration_interval: 2 + sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: ppo trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml' sft_warmup_iteration: 0 # Set to integer to enable sft warmup eval_interval: 50 + save_interval: 100 # get_exp_strategy: 'LFU' monitor: cache_root_dir: "" diff --git a/examples/grpo_gsm8k/train_gsm8k.yaml b/examples/grpo_gsm8k/train_gsm8k.yaml index 2e8365c6cb..13b195f557 100644 --- a/examples/grpo_gsm8k/train_gsm8k.yaml +++ b/examples/grpo_gsm8k/train_gsm8k.yaml @@ -52,7 +52,6 @@ actor_rollout_ref: optimizer_offload: False fsdp_size: -1 # --- below: opmd --- - alg_type: ppo # ppo / opmd / pairwise_opmd tau: 0.000 # strength of regularization w.r.t. old / ref policy opmd_baseline: mean # mean / logavgexp, applicable to opmd use_uid: False # True / False, applicable to pairwise_opmd @@ -174,7 +173,6 @@ trainer: val_generations_to_log_to_wandb: 0 nnodes: 1 n_gpus_per_node: 2 - save_freq: 100 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if test_freq: 5 diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml index 07ea448548..db6a347bc9 100644 --- a/examples/grpo_math/math.yaml +++ b/examples/grpo_math/math.yaml @@ -37,8 +37,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 repeat_times: 8 @@ -47,14 +45,16 @@ explorer: max_pending_requests: 32 max_waiting_steps: 4 synchronizer: - sync_method: 'online' + sync_method: 'nccl' sync_iteration_interval: 2 + sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: ppo trainer_config_path: 'examples/grpo_math/train_math.yaml' sft_warmup_iteration: 0 # Set to integer to enable sft warmup eval_interval: 10 + save_interval: 100 monitor: cache_root_dir: "" project: grpo_math diff --git a/examples/grpo_math/train_math.yaml b/examples/grpo_math/train_math.yaml index 937c19657e..2482ccc785 100644 --- a/examples/grpo_math/train_math.yaml +++ b/examples/grpo_math/train_math.yaml @@ -51,7 +51,6 @@ actor_rollout_ref: optimizer_offload: False fsdp_size: -1 # --- below: opmd --- - alg_type: ppo # ppo / opmd / pairwise_opmd tau: 0.000 # strength of regularization w.r.t. old / ref policy opmd_baseline: mean # mean / logavgexp, applicable to opmd use_uid: False # True / False, applicable to pairwise_opmd @@ -166,7 +165,6 @@ trainer: val_generations_to_log_to_wandb: 0 nnodes: 1 n_gpus_per_node: 2 - save_freq: 100 # auto: find the last ckpt to resume. 
If can't find, start from scratch resume_mode: auto # or auto or resume_path if test_freq: 5 diff --git a/examples/grpo_sciworld/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml index 25b5dfa073..53dbdea801 100644 --- a/examples/grpo_sciworld/sciworld.yaml +++ b/examples/grpo_sciworld/sciworld.yaml @@ -31,8 +31,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 repeat_times: 8 @@ -43,12 +41,14 @@ explorer: gpu_memory_utilization: 0.7 enable_chunked_prefil: true synchronizer: - sync_method: 'online' + sync_method: 'nccl' sync_iteration_interval: 8 + sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: ppo trainer_config_path: 'examples/grpo_sciworld/train_sciworld.yaml' + save_interval: 10 monitor: cache_root_dir: "" project: "sciworld" diff --git a/examples/grpo_sciworld/train_sciworld.yaml b/examples/grpo_sciworld/train_sciworld.yaml index 330b659afb..833441142c 100644 --- a/examples/grpo_sciworld/train_sciworld.yaml +++ b/examples/grpo_sciworld/train_sciworld.yaml @@ -164,7 +164,6 @@ trainer: val_generations_to_log_to_wandb: 0 nnodes: 1 n_gpus_per_node: 2 - save_freq: 1 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if test_freq: 100 diff --git a/examples/grpo_webshop/train_webshop.yaml b/examples/grpo_webshop/train_webshop.yaml index 0ae8675f50..ac502fec3f 100644 --- a/examples/grpo_webshop/train_webshop.yaml +++ b/examples/grpo_webshop/train_webshop.yaml @@ -169,7 +169,6 @@ trainer: val_generations_to_log_to_wandb: 0 nnodes: 1 n_gpus_per_node: 2 - save_freq: 1 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if test_freq: 100 diff --git a/examples/grpo_webshop/webshop.yaml b/examples/grpo_webshop/webshop.yaml index eb9916018d..a301140c07 100644 --- a/examples/grpo_webshop/webshop.yaml +++ b/examples/grpo_webshop/webshop.yaml @@ -31,8 +31,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 repeat_times: 8 @@ -43,12 +41,14 @@ explorer: gpu_memory_utilization: 0.7 enable_chunked_prefil: true synchronizer: - sync_method: 'online' + sync_method: 'nccl' sync_iteration_interval: 8 + sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: ppo trainer_config_path: 'examples/grpo_webshop/train_webshop.yaml' + save_interval: 10 monitor: cache_root_dir: "" project: "WEBSHOP" diff --git a/examples/opmd_gsm8k/opmd_gsm8k.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml index 6cde601158..d5d60f7126 100644 --- a/examples/opmd_gsm8k/opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml @@ -30,8 +30,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 repeat_times: 8 @@ -40,13 +38,15 @@ explorer: max_pending_requests: 32 max_waiting_steps: 4 synchronizer: - sync_method: 'online' + sync_method: 'nccl' sync_iteration_interval: 10 + sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: opmd trainer_config_path: 'examples/opmd_gsm8k/train_opmd_gsm8k.yaml' sft_warmup_iteration: 0 + save_interval: 100 monitor: cache_root_dir: "" project: "Trinity-RFT-gsm8k-test-opmd" diff --git a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml index 97384e57c3..033405f5c8 100644 --- a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml @@ -1,7 +1,6 @@ # Configs of particular interest for min-opmd and 
off-policy scenarios: # # parameters specific to min-opmd -# alg_type: opmd # tau: 1.0 # strength of regularization w.r.t. ref policy # opmd_baseline: mean # must be "mean" for min-opmd # use_uid: False # applicable to pairwise-opmd, not min-opmd @@ -20,7 +19,7 @@ # lr: set smaller to account for beta1 = 0.0 # # misc: -# adv_estimator: grpo # merely to disable critic model, doesn't affect adv compute when alg_type is opmd +# adv_estimator: grpo # merely to disable critic model, doesn't affect adv compute when algorithm_type is opmd data: @@ -79,7 +78,6 @@ actor_rollout_ref: optimizer_offload: False fsdp_size: -1 # --- below: opmd --- - alg_type: opmd # ppo / opmd / pairwise_opmd tau: 4.0 # strength of regularization w.r.t. old / ref policy opmd_baseline: mean # mean / logavgexp, applicable to opmd use_uid: False # True / False, applicable to pairwise_opmd @@ -201,7 +199,6 @@ trainer: val_generations_to_log_to_wandb: 0 nnodes: 1 n_gpus_per_node: 2 - save_freq: 100 # auto: find the last ckpt to resume. If can't find, start from scratch resume_mode: auto # or auto or resume_path if test_freq: 100 diff --git a/examples/ppo_countdown/countdown.yaml b/examples/ppo_countdown/countdown.yaml index 9282a7d1a0..acc9c7950e 100644 --- a/examples/ppo_countdown/countdown.yaml +++ b/examples/ppo_countdown/countdown.yaml @@ -33,8 +33,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 repeat_times: 5 @@ -43,14 +41,16 @@ explorer: max_pending_requests: 32 max_waiting_steps: 4 synchronizer: - sync_method: 'online' + sync_method: 'nccl' sync_iteration_interval: 10 + sync_timeout: 1200 trainer: trainer_type: 'verl' algorithm_type: ppo trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' sft_warmup_iteration: 0 eval_interval: 1000 + save_interval: 100 monitor: cache_root_dir: "" project: "Trinity-RFT-countdown" diff --git a/examples/ppo_countdown/train_countdown.yaml b/examples/ppo_countdown/train_countdown.yaml index 70872fd7f1..291afe452f 100644 --- a/examples/ppo_countdown/train_countdown.yaml +++ b/examples/ppo_countdown/train_countdown.yaml @@ -54,7 +54,6 @@ actor_rollout_ref: optimizer_offload: False fsdp_size: -1 # --- below: opmd --- - alg_type: ppo # ppo / opmd / pairwise_opmd tau: 0.000 # strength of regularization w.r.t. old / ref policy opmd_baseline: mean # mean / logavgexp, applicable to opmd use_uid: False # True / False, applicable to pairwise_opmd @@ -176,7 +175,6 @@ trainer: val_generations_to_log_to_wandb: 0 nnodes: 1 n_gpus_per_node: 2 - save_freq: 100 # auto: find the last ckpt to resume. 
If can't find, start from scratch resume_mode: auto # or auto or resume_path if test_freq: 100 diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 2ee03fbc2a..7513d0c4d3 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -53,7 +53,8 @@ def init_process_group( world_size: int, group_name: str, backend: str = "nccl", - offline_update: bool = True, + timeout: int = 1200, + update_with_checkpoint: bool = True, ) -> None: pass diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 0eb84f4fb7..81f679f1fd 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -28,8 +28,6 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 backend: nccl @@ -37,10 +35,14 @@ explorer: trainer: trainer_type: verl trainer_config_path: tests/template/verl_config.yaml + sft_warmup_iteration: 0 + eval_interval: 1000 + save_interval: 100 monitor: project: unittest name: test synchronizer: - sync_method: offline + sync_method: checkpoint sync_iteration_interval: 10 + sync_timeout: 1200 wait_for_checkpoint: false diff --git a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml index c902b0d98e..d1e84cb455 100644 --- a/tests/template/verl_config.yaml +++ b/tests/template/verl_config.yaml @@ -37,7 +37,6 @@ actor_rollout_ref: optimizer_offload: False fsdp_size: -1 # --- below: opmd --- - alg_type: ppo # ppo / opmd / pairwise_opmd tau: 0.000 # strength of regularization w.r.t. old / ref policy opmd_baseline: mean # mean / logavgexp, applicable to opmd use_uid: False # True / False, applicable to pairwise_opmd diff --git a/tests/test_data/template.yaml b/tests/test_data/template.yaml index 4a66d9e911..56f6220d75 100644 --- a/tests/test_data/template.yaml +++ b/tests/test_data/template.yaml @@ -9,7 +9,6 @@ cluster: node_num: 1 gpu_per_node: 8 buffer: - read_batch_size: 32 max_retry_times: 3 max_retry_interval: 1 explorer: @@ -21,7 +20,5 @@ explorer: enforce_eager: true dtype: bfloat16 temperature: 1.0 - top_p: 1.0 - top_k: -1 seed: 42 logprobs: 0 diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index a62d8fe5d1..4bb1f88685 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -15,7 +15,7 @@ get_unittest_dataset_config, ) from trinity.cli.launcher import both -from trinity.common.constants import MonitorType +from trinity.common.constants import MonitorType, SyncMethod class BaseTrainerCase(RayUnittestBase): @@ -30,7 +30,7 @@ def setUp(self): get_checkpoint_path(), f"train-{datetime.now().strftime('%Y%m%d%H%M%S')}" ) self.config.synchronizer.sync_iteration_interval = 2 - self.config.synchronizer.sync_method = "online" + self.config.synchronizer.sync_method = SyncMethod.NCCL self.config.explorer.eval_interval = 4 self.config.trainer.eval_interval = 4 diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index dbb9565b30..21eba09d04 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -76,6 +76,10 @@ def read(self) -> List: for prompt_messages, response_messages in zip( batch_data[self.prompt_key], batch_data[self.response_key] ): + if not isinstance(prompt_messages, list): + prompt_messages = [prompt_messages] + if not isinstance(response_messages, list): + response_messages = [response_messages] full_messages = prompt_messages + response_messages tokens = self.tokenizer.apply_chat_template( diff --git 
a/trinity/common/config.py b/trinity/common/config.py index 89693f11d3..6dc3510ddc 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -6,7 +6,13 @@ from omegaconf import OmegaConf -from trinity.common.constants import AlgorithmType, MonitorType, PromptType, StorageType +from trinity.common.constants import ( + AlgorithmType, + MonitorType, + PromptType, + StorageType, + SyncMethod, +) from trinity.utils.log import get_logger logger = get_logger(__name__) @@ -110,13 +116,12 @@ class DatasetConfig: class BufferConfig: """Config for experience buffer.""" - db_url: Optional[str] = None + db_url: Optional[str] = None # Is deprecated, please set `buffer.train_dataset.path` instead. read_batch_size: int = 32 max_retry_times: int = 3 max_retry_interval: int = 1 tokenizer_path: Optional[str] = None pad_token_id: Optional[int] = None - reset_consumed: Optional[bool] = False train_dataset: Optional[DatasetConfig] = None sft_warmup_dataset: Optional[DatasetConfig] = None @@ -180,6 +185,7 @@ class TrainerConfig: trainer_type: str = "verl" trainer_config_path: str = "" eval_interval: int = 100 + save_interval: int = 0 enable_preview: bool = True # enable rollout preview in wandb trainer_config: Any = field(default_factory=dict) @@ -209,11 +215,11 @@ class MonitorConfig: class SynchronizerConfig: """Configs for model weight synchronization""" - # only support `offline` for now # TODO: rename to "checkpoint", "nccl", "ipc" - sync_method: str = "offline" + sync_method: SyncMethod = SyncMethod.NCCL # sync weights every `sync_iteration_interval` iterations sync_iteration_interval: int = 1 + sync_timeout: int = 1200 # wait for the lastest checkpoint to be ready wait_for_checkpoint: bool = False master_address: Optional[str] = None @@ -246,6 +252,10 @@ def _check_buffer(self) -> None: raise ValueError( "buffer.sft_warmup_dataset is required when trainer.sft_warmup_iteration > 0" ) + if self.buffer.db_url: + raise ValueError( + "`buffer.db_url` is deprecated, please set `buffer.train_dataset.path` instead." 
+ ) if self.buffer.pad_token_id is None: from transformers import AutoTokenizer @@ -265,10 +275,17 @@ def _check_buffer(self) -> None: name="experience_buffer", storage_type=StorageType.QUEUE, algorithm_type=self.trainer.algorithm_type, - path=self.buffer.db_url, ) - logger.info(f"Auto set buffer.train_dataset to {self.buffer.train_dataset}") - else: + logger.info(f"Auto set `buffer.train_dataset` to {self.buffer.train_dataset}") + else: # TODO: to be check + if self.mode == "train" and self.trainer.algorithm_type == AlgorithmType.DPO: + if self.buffer.train_dataset is None and self.data.dataset_path.strip(): + self.buffer.train_dataset = DatasetConfig( + name="dpo_train_dataset", + storage_type=StorageType.FILE, + algorithm_type=self.trainer.algorithm_type, + ) + logger.info(f"Auto set `buffer.train_dataset` to {self.buffer.train_dataset}") if self.buffer.train_dataset is None: raise ValueError("buffer.train_dataset is required when mode is not 'both'") self.buffer.train_dataset.algorithm_type = self.trainer.algorithm_type @@ -276,30 +293,13 @@ def _check_buffer(self) -> None: self.buffer.sft_warmup_dataset.algorithm_type = AlgorithmType.SFT self.buffer.read_batch_size = self.data.batch_size * self.explorer.repeat_times - def check_and_update(self) -> None: + def check_and_update(self) -> None: # noqa: C901 """Check and update the config.""" - if self.trainer.trainer_type == "verl": - if self.trainer.trainer_config: - from trinity.common.verl_config import veRLConfig - - trainer_config_schema = OmegaConf.structured(veRLConfig) - trainer_config = OmegaConf.merge(trainer_config_schema, self.trainer.trainer_config) - self.trainer.trainer_config = OmegaConf.to_object(trainer_config) - else: - if os.path.isfile(self.trainer.trainer_config_path): - from trinity.common.verl_config import load_config - - self.trainer.trainer_config = load_config(self.trainer.trainer_config_path) - else: - raise ValueError( - f"Invalid trainer config path: {self.trainer.trainer_config_path}" - ) - else: - raise ValueError(f"Invalid trainer type: {self.trainer_type}") - # check mode if self.mode not in ["explore", "train", "both"]: raise ValueError(f"Invalid mode: {self.mode}") + if self.trainer.algorithm_type == AlgorithmType.DPO and self.mode == "both": + raise ValueError("DPO does not support `both` mode") # check model path if not os.path.isabs(self.model.model_path): @@ -315,8 +315,8 @@ def check_and_update(self) -> None: self.explorer.engine_num * self.explorer.tensor_parallel_size ) self.synchronizer.backend = self.explorer.backend - if self.synchronizer.sync_method == "online" and self.mode != "both": - raise ValueError("Online synchronization is only supported in both mode") + if self.synchronizer.sync_method == SyncMethod.NCCL and self.mode != "both": + raise ValueError("`nccl` synchronization is only supported in both mode.") # check eval_interval if self.trainer.eval_interval % self.synchronizer.sync_iteration_interval != 0: @@ -332,6 +332,12 @@ def check_and_update(self) -> None: f"Warning: explorer.eval_interval is not equal to trainer.eval_interval; adjusted to the same value={self.trainer.eval_interval}." 
) + # check save_interval + if self.synchronizer.sync_method == SyncMethod.CHECKPOINT: + self.trainer.save_interval = ( + self.synchronizer.sync_iteration_interval + ) # TODO: not proper for DPO + # check monitor if not self.monitor.cache_root_dir: # create a cache dir in /.cache @@ -352,6 +358,26 @@ def check_and_update(self) -> None: self._check_buffer() # check and update trainer if self.mode != "explore": + if self.trainer.trainer_type == "verl": + if self.trainer.trainer_config: + from trinity.common.verl_config import veRLConfig + + trainer_config_schema = OmegaConf.structured(veRLConfig) + trainer_config = OmegaConf.merge( + trainer_config_schema, self.trainer.trainer_config + ) + self.trainer.trainer_config = OmegaConf.to_object(trainer_config) + else: + if os.path.isfile(self.trainer.trainer_config_path): + from trinity.common.verl_config import load_config + + self.trainer.trainer_config = load_config(self.trainer.trainer_config_path) + else: + raise ValueError( + f"Invalid trainer config path: {self.trainer.trainer_config_path}" + ) + else: + raise ValueError(f"Invalid trainer type: {self.trainer_type}") self.trainer.trainer_config.synchronize_config(self) else: self.trainer.trainer_config = None diff --git a/trinity/common/constants.py b/trinity/common/constants.py index bb0d967e6b..38d9a9c162 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -66,6 +66,7 @@ class AlgorithmType(CaseInsensitiveEnum): PPO = "ppo" GRPO = "grpo" OPMD = "opmd" + PAIRWISE_OPMD = "pairwise_opmd" DPO = "dpo" def is_rft(self) -> bool: @@ -74,6 +75,7 @@ def is_rft(self) -> bool: AlgorithmType.PPO, AlgorithmType.GRPO, AlgorithmType.OPMD, + AlgorithmType.PAIRWISE_OPMD, ] def is_sft(self) -> bool: @@ -94,3 +96,25 @@ class MonitorType(CaseInsensitiveEnum): WANDB = "wandb" TENSORBOARD = "tensorboard" + + +class SyncMethodEnumMeta(CaseInsensitiveEnumMeta): + def __call__(cls, value, *args, **kwargs): + if value == "online": + print("SyncMethod `online` is deprecated, use `nccl` instead.") + value = "nccl" + elif value == "offline": + print("SyncMethod `offline` is deprecated, use `checkpoint` instead.") + value = "checkpoint" + try: + return super().__call__(value, *args, **kwargs) + except Exception as e: + print("Error parsing SyncMethod:", e) + raise ValueError(f"Invalid SyncMethod: {value}") + + +class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta): + """Sync Method.""" + + NCCL = "nccl" + CHECKPOINT = "checkpoint" diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index dfabe64ec3..5cc4d41cca 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -63,7 +63,8 @@ def init_process_group( world_size: int, group_name: str, backend: str = "nccl", - offline_update: bool = True, + timeout: int = 1200, + update_with_checkpoint: bool = True, ) -> None: """Init the process group for model weights sync.""" diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index babd84ee5d..4ad86d86ae 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -47,8 +47,6 @@ def __init__( self.default_sampling_params = vllm.SamplingParams( n=config.explorer.repeat_times, temperature=config.explorer.temperature, - top_p=config.explorer.top_p, - top_k=config.explorer.top_k, max_tokens=config.model.max_response_tokens, min_tokens=1, truncate_prompt_tokens=config.model.max_prompt_tokens, @@ -260,7 +258,8 @@ def init_process_group( world_size: int, 
group_name: str, backend: str = "nccl", - offline_update: bool = True, + timeout: int = 1200, + update_with_checkpoint: bool = True, ): return self.async_llm.engine.model_executor.collective_rpc( "init_process_group", @@ -271,7 +270,8 @@ def init_process_group( world_size, group_name, backend, - offline_update, + timeout, + update_with_checkpoint, ), ) diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index ca205d7617..95aa9e805c 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -37,8 +37,6 @@ def __init__(self, config: Config, **kwargs): self.default_sampling_params = SamplingParams( n=config.explorer.repeat_times, temperature=config.explorer.temperature, - top_p=config.explorer.top_p, - top_k=config.explorer.top_k, max_tokens=config.model.max_response_tokens, min_tokens=1, truncate_prompt_tokens=config.model.max_prompt_tokens, @@ -89,7 +87,8 @@ def init_process_group( world_size: int, group_name: str, backend: str = "nccl", - offline_update: bool = True, + timeout: int = 1200, + update_with_checkpoint: bool = True, ): return self.llm.collective_rpc( "init_process_group", @@ -100,7 +99,8 @@ def init_process_group( world_size, group_name, backend, - offline_update, + timeout, + update_with_checkpoint, ), ) diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index f42d843c3f..52e0ade475 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -26,20 +26,21 @@ def init_process_group( world_size: int, group_name: str, backend: str = "nccl", - offline_update: bool = True, + timeout: int = 1200, + update_with_checkpoint: bool = True, ): """Init torch process group for model weights update""" assert torch.distributed.is_initialized(), "default torch process group must be initialized" assert group_name != "", "group name must not be empty" - self._offline_update = offline_update - if self._offline_update: + self._update_with_checkpoint = update_with_checkpoint + if self._update_with_checkpoint: logger.info( - f"init_process_group (offline): address={master_address}:{master_port}, rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}" + f"init_process_group (checkpoint): address={master_address}:{master_port}, rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}" ) self._weight_update_rank = torch.distributed.get_rank() + rank_offset else: logger.info( - f"init_process_group (online): rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}" + f"init_process_group (nccl): rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}" ) self._weight_update_rank = torch.distributed.get_rank() + rank_offset @@ -52,6 +53,7 @@ def init_process_group( self._model_update_group = init_process_group( backend=backend, init_method=init_method, + timeout=timeout, world_size=world_size, rank=self._weight_update_rank, group_name=group_name, diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index f7598bbea0..c46fefb849 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -4,6 +4,7 @@ from omegaconf import OmegaConf from trinity.common.config import BufferConfig, Config, SynchronizerConfig +from trinity.common.constants import AlgorithmType @dataclass @@ -86,7 +87,7 @@ class Actor: checkpoint: Checkpoint = field(default_factory=Checkpoint) optim: Optim = field(default_factory=Optim) 
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) - alg_type: str = "ppo" # ppo / opmd / pairwise_opmd + algorithm_type: AlgorithmType = AlgorithmType.PPO tau: float = 0.001 # strength of regularization w.r.t. old / ref policy opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd use_uid: bool = False # True / False, applicable to pairwise_opmd @@ -243,7 +244,6 @@ class Trainer: val_before_train: bool = False training_rollout_mode: str = "parallel" enable_exp_buffer: bool = True - steps_per_epoch: int = 1280 sync_freq: int = 0 sft_warmup_iteration: int = 0 max_actor_ckpt_to_keep: Optional[int] = None @@ -279,10 +279,7 @@ def synchronize_config(self, config: Config) -> None: # for multi-node scenarios, some nodes for rollout, others for training self.trainer.n_gpus_per_node = config.cluster.gpu_per_node self.trainer.sync_freq = config.synchronizer.sync_iteration_interval - if config.synchronizer.sync_method == "offline": - self.trainer.save_freq = ( - config.synchronizer.sync_iteration_interval - ) # TODO: not proper for DPO + self.trainer.save_freq = config.trainer.save_interval self.synchronizer = config.synchronizer self.actor_rollout_ref.synchronizer = config.synchronizer self.buffer = config.buffer @@ -296,37 +293,19 @@ def synchronize_config(self, config: Config) -> None: self.buffer.pad_token_id = config.buffer.pad_token_id self.trainer.project_name = config.monitor.project self.trainer.experiment_name = config.monitor.name - self.data.train_batch_size = self.buffer.read_batch_size + self.data.train_batch_size = config.data.batch_size self.trainer.default_local_dir = config.model.checkpoint_path self.trainer.sft_warmup_iteration = config.trainer.sft_warmup_iteration self.actor_rollout_ref.actor.ppo_mini_batch_size = config.data.batch_size self.actor_rollout_ref.rollout.temperature = config.explorer.temperature self.actor_rollout_ref.rollout.n = config.explorer.repeat_times - batch_size_per_gpu = self.buffer.read_batch_size // world_size - self.actor_rollout_ref.actor.alg_type = ( - config.trainer.algorithm_type.value - ) # TODO: refactor `alg_type` - # print(f"using algorithm type: {self.actor_rollout_ref.actor.alg_type}") + self.actor_rollout_ref.actor.algorithm_type = config.trainer.algorithm_type - if self.actor_rollout_ref.actor.alg_type == "dpo": # for DPO + if self.actor_rollout_ref.actor.algorithm_type.is_dpo(): # for DPO print("Warning: DPO micro batch size is doubled for computing loss.") + self.actor_rollout_ref.actor.ppo_mini_batch_size *= 2 self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2 # type: ignore self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2 # type: ignore - if batch_size_per_gpu % self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu != 0: # type: ignore - raise ValueError( - f"batch_size_per_gpu ({batch_size_per_gpu}) must be divisible by " - f"actor.ppo_micro_batch_size_per_gpu ({self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu})" - ) - if batch_size_per_gpu % self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu != 0: # type: ignore - raise ValueError( - f"batch_size_per_gpu ({batch_size_per_gpu}) must be divisible by " - f"ref.log_prob_micro_batch_size_per_gpu ({self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu})" - ) - if batch_size_per_gpu % self.critic.ppo_micro_batch_size_per_gpu != 0: # type: ignore - raise ValueError( - f"batch_size_per_gpu ({batch_size_per_gpu}) must be divisible by " - f"critic.ppo_micro_batch_size_per_gpu 
({self.critic.ppo_micro_batch_size_per_gpu})" - ) # TODO: check other fields self.enable_preview = config.trainer.enable_preview diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index fc3a34ca55..a52dc2b14b 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -10,7 +10,11 @@ from trinity.buffer import get_buffer_writer from trinity.common.config import Config -from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, TaskType +from trinity.common.constants import ( + ROLLOUT_WEIGHT_SYNC_GROUP_NAME, + SyncMethod, + TaskType, +) from trinity.common.models import create_rollout_models from trinity.common.models.utils import ( get_checkpoint_dir_with_iteration, @@ -59,15 +63,17 @@ def __init__(self, config: Config): self.update_interval = ( self.config.synchronizer.sync_iteration_interval * self.config.data.batch_size ) - self.use_offline_weights_update = self.config.synchronizer.sync_method == "offline" + self.use_checkpoint_weights_update = ( + self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT + ) - # For offline weights update + # For checkpoint weights update # Use explorer to periodically load the latest model weights and # boradcast to all rollout models - if self.use_offline_weights_update: + if self.use_checkpoint_weights_update: self.old_checkpoint = None self.state_dict = {} - else: # online mode + else: # nccl mode self.state_dict_meta = [] self.logger.info("Finished initializing Explorer.") @@ -75,8 +81,8 @@ def __init__(self, config: Config): def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): - # In offline mode, we use explorer to store the model weights which has no rank - base_offset = 0 if self.use_offline_weights_update else 1 + # In checkpoint mode, we use explorer to store the model weights which has no rank + base_offset = 0 if self.use_checkpoint_weights_update else 1 world_size = len(self.models) * self.config.explorer.tensor_parallel_size + base_offset self.logger.info( f"Initialize process group for weight synchronization, " @@ -92,7 +98,8 @@ def setup_weight_sync_group( world_size=world_size, group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, backend=self.config.explorer.backend, - offline_update=self.use_offline_weights_update, + timeout=self.config.synchronizer.sync_timeout, + update_with_checkpoint=self.use_checkpoint_weights_update, ) for i, model in enumerate(self.models) ] @@ -122,7 +129,7 @@ def _update_model_weight(self, state_dict: dict) -> None: ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models]) self.state_dict.clear() - def _offline_weights_update(self, iteration_num: Optional[int] = None) -> None: + def _checkpoint_weights_update(self, iteration_num: Optional[int] = None) -> None: # TODO: support more checkpoint types try: checkpoint_dir = get_checkpoint_dir_with_iteration( @@ -138,18 +145,18 @@ def _offline_weights_update(self, iteration_num: Optional[int] = None) -> None: except Exception as e: self.logger.error(f"Error when loading state_dict: {e}") - def _online_weights_update(self): + def _nccl_weights_update(self): ray.get([model.sync_model.remote(self.state_dict_meta) for model in self.models]) def prepare(self) -> None: """Preparation before running.""" - if self.use_offline_weights_update: + if self.use_checkpoint_weights_update: master_address, master_port = ray.get(self.models[0].get_address.remote()) self.setup_weight_sync_group(master_address, master_port) 
@ray.method(concurrency_group="get_weight") def get_weight(self, name: str) -> torch.Tensor: - """Get the weight of the loaded model (For offline weights update).""" + """Get the weight of the loaded model (For checkpoint weights update).""" return self.state_dict[name] def explore(self) -> None: @@ -257,10 +264,10 @@ def eval(self) -> bool: def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights - if self.use_offline_weights_update: - self._offline_weights_update() - else: # online weights update - self._online_weights_update() + if self.use_checkpoint_weights_update: + self._checkpoint_weights_update() + else: # nccl weights update + self._nccl_weights_update() def flush_log(self, step: int) -> None: """Flush the log of the current step.""" diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 13ade4b011..e1652c60bf 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -5,7 +5,13 @@ import streamlit as st import yaml -from trinity.common.constants import AlgorithmType, MonitorType, StorageType +from trinity.common.constants import ( + AlgorithmType, + MonitorType, + PromptType, + StorageType, + SyncMethod, +) from trinity.common.rewards import REWARD_FUNCTIONS from trinity.common.workflows.workflow import WORKFLOWS from trinity.trainer.verl.ray_trainer import AdvantageEstimator @@ -48,9 +54,10 @@ def _init_default_config(self): "trainer_gpu_num": 6, "max_prompt_tokens": 1024, "max_response_tokens": 1024, - # Data and Buffer Configs + # Data Configs "total_epochs": 20, - "task_num_per_batch": 6, + "_train_batch_size_per_gpu": 16, + "train_batch_size": 96, "dataset_path": "", "subset_name": None, "train_split": "train", @@ -59,22 +66,35 @@ def _init_default_config(self): "response_key": "answer", "default_workflow_type": "math_workflow", "default_reward_fn_type": "math_reward", + # Buffer Configs + "_is_dpo_storage_type": StorageType.FILE.value, + "_not_dpo_storage_type": StorageType.QUEUE.value, "storage_type": StorageType.QUEUE.value, - "db_url": "", + "train_dataset_path": "", "max_retry_times": 3, "max_retry_interval": 1, + "dpo_dataset_train_split": "train", + "dpo_dataset_prompt_type": PromptType.MESSAGES.value, + "dpo_dataset_prompt_key": "prompt", + "dpo_dataset_chosen_key": "chosen", + "dpo_dataset_rejected_key": "rejected", "sft_warmup_dataset_path": "", "sft_warmup_train_split": "train", - "sft_warmup_eval_split": "", - "sft_warmup_prompt_key": "question", - "sft_warmup_response_key": "answer", + "sft_warmup_prompt_type": PromptType.MESSAGES.value, + "sft_warmup_messages_key": "messages", + "sft_warmup_prompt_key": "prompt", + "sft_warmup_response_key": "response", # Explorer and Sync Configs "engine_type": "vllm_async", "engine_num": 2, "tensor_parallel_size": 1, + "_grouped_adv_repeat_times": 2, + "_not_grouped_adv_repeat_times": 1, "repeat_times": 1, - "sync_method": "online", + "_not_dpo_sync_method": SyncMethod.NCCL.value, + "sync_method": SyncMethod.NCCL.value, "sync_iteration_interval": 10, + "sync_timeout": 1200, "runner_num": 32, "max_pending_requests": 32, "max_waiting_steps": 4, @@ -92,6 +112,8 @@ def _init_default_config(self): "algorithm_type": AlgorithmType.PPO.value, "sft_warmup_iteration": 0, "eval_interval": 1000, + "_nccl_save_interval": 100, + "save_interval": 100, # veRL Trainer Configs "training_args": [ "balance_batch", @@ -99,7 +121,6 @@ def _init_default_config(self): "remove_padding", "dynamic_bsz", ], - 
"save_freq": 100, "training_strategy": "fsdp", "param_offload": False, "optimizer_offload": False, @@ -145,7 +166,7 @@ def _init_default_config(self): "critic_cliprange_value": 0.5, "critic_ppo_micro_batch_size_per_gpu": 8, "critic_ulysses_sequence_parallel_size": 1, - "training_mode": "PPO", + "critic_checkpoint": ["model", "optimizer", "extra"], } def reset_session_state(self): @@ -224,29 +245,37 @@ def _set_total_epochs(self): st.number_input("Total Epochs", key="total_epochs", min_value=1) @property - def _str_for_task_num_per_batch(self): + def _str_for_train_batch_size(self): return ( - f"Please ensure that `task_num_per_batch` can be divided by " + f"Please ensure that `train_batch_size` can be divided by " f"`gpu_per_node * node_num - engine_num * tensor_parallel_size` " f"= {st.session_state['trainer_gpu_num']}" ) - def _set_task_num_per_batch(self): + def _set_train_batch_size(self): trainer_gpu_num = st.session_state["trainer_gpu_num"] - if st.session_state["task_num_per_batch"] < trainer_gpu_num: - st.session_state["task_num_per_batch"] = trainer_gpu_num + st.session_state["train_batch_size"] = ( + st.session_state["_train_batch_size_per_gpu"] * st.session_state["trainer_gpu_num"] + ) + + def on_change(): + st.session_state["_train_batch_size_per_gpu"] = max( + st.session_state["train_batch_size"] // st.session_state["trainer_gpu_num"], 1 + ) + st.number_input( - "Task Num Per Batch", - key="task_num_per_batch", + "Train Batch Size", + key="train_batch_size", min_value=trainer_gpu_num, step=trainer_gpu_num, - help=self._str_for_task_num_per_batch, + help=self._str_for_train_batch_size, + on_change=on_change, ) - def _check_task_num_per_batch(self): - if st.session_state["task_num_per_batch"] % st.session_state["trainer_gpu_num"] != 0: - self.unfinished_fields.add("task_num_per_batch") - st.warning(self._str_for_task_num_per_batch) + def _check_train_batch_size(self): + if st.session_state["train_batch_size"] % st.session_state["trainer_gpu_num"] != 0: + self.unfinished_fields.add("train_batch_size") + st.warning(self._str_for_train_batch_size) def _set_dataset_path(self): st.text_input("Dataset Path", key="dataset_path") @@ -291,18 +320,39 @@ def _set_default_reward_fn_type(self): ) def _set_storage_type(self): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["storage_type"] = st.session_state["_is_dpo_storage_type"] + storage_candidates = [StorageType.FILE.value, StorageType.SQL.value] + else: + st.session_state["storage_type"] = st.session_state["_not_dpo_storage_type"] + storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value] + + def on_change(): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["_is_dpo_storage_type"] = st.session_state["storage_type"] + else: + st.session_state["_not_dpo_storage_type"] = st.session_state["storage_type"] + st.selectbox( "Storage Type", - [storage_type.value for storage_type in StorageType], + storage_candidates, key="storage_type", + on_change=on_change, ) - def _set_db_url(self): + def _set_train_dataset_path(self): # TODO st.text_input( - "DB URL", - key="db_url", - help=r"Default to `sqlite:///{os.path.join(checkpoint_path, '.cache', project_name, experiment_name)}/data.db`", + "Train Dataset Path", + key="train_dataset_path", + help=r"This path is used for `trainer`, " + r"if `storage_type == StorageType.QUEUE`, default to `None`, " + r"if `storage_type == StorageType.FILE`, this should be a path to a file, " + r"if `storage_type == 
StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_path, '.cache', project_name, experiment_name)}/data.db`.", ) + if st.session_state["storage_type"] == StorageType.FILE.value: + if not st.session_state["train_dataset_path"].strip(): + self.unfinished_fields.add("train_dataset_path") + st.warning("Please input train dataset path.") def _set_max_retry_times(self): st.number_input("Max Retry Times", key="max_retry_times", min_value=1) @@ -310,6 +360,32 @@ def _set_max_retry_times(self): def _set_max_retry_interval(self): st.number_input("Max Retry Interval", key="max_retry_interval", min_value=1) + def _set_dpo_dataset_kwargs(self): + dpo_dataset_train_split_col, dpo_dataset_prompt_type_col = st.columns(2) + dpo_dataset_train_split_col.text_input( + "DPO Dataset Train Split", key="dpo_dataset_train_split" + ) + dpo_dataset_prompt_type_col.selectbox( + "DPO Dataset Prompt Type", + [prompt_type.value for prompt_type in PromptType], + key="dpo_dataset_prompt_type", + ) + + ( + dpo_dataset_prompt_key_col, + dpo_dataset_chosen_key_col, + dpo_dataset_rejected_key_col, + ) = st.columns(3) + dpo_dataset_prompt_key_col.text_input( + "DPO Dataset Prompt Key", key="dpo_dataset_prompt_key" + ) + dpo_dataset_chosen_key_col.text_input( + "DPO Dataset Chosen Key", key="dpo_dataset_chosen_key" + ) + dpo_dataset_rejected_key_col.text_input( + "DPO Dataset Rejected Key", key="dpo_dataset_rejected_key" + ) + def _check_sft_warmup_dataset_path(self): if st.session_state["sft_warmup_iteration"]: if not st.session_state["sft_warmup_dataset_path"].strip(): @@ -329,12 +405,22 @@ def _set_sft_warmup_dataset_args(self): ): # TODO ( sft_warmup_train_split_col, - sft_warmup_eval_split_col, + sft_warmup_prompt_type_col, + ) = st.columns(2) + sft_warmup_train_split_col.text_input("SFT Train Split", key="sft_warmup_train_split") + sft_warmup_prompt_type_col.selectbox( + "SFT Prompt Type", + [prompt_type.value for prompt_type in PromptType], + key="sft_warmup_prompt_type", + ) + ( + sft_warmup_messages_key_col, sft_warmup_prompt_key_col, sft_warmup_response_key_col, - ) = st.columns(4) - sft_warmup_train_split_col.text_input("SFT Train Split", key="sft_warmup_train_split") - sft_warmup_eval_split_col.text_input("SFT Eval Split", key="sft_warmup_eval_split") + ) = st.columns(3) + sft_warmup_messages_key_col.text_input( + "SFT Messages Key", key="sft_warmup_messages_key" + ) sft_warmup_prompt_key_col.text_input("SFT Prompt Key", key="sft_warmup_prompt_key") sft_warmup_response_key_col.text_input( "SFT Response Key", key="sft_warmup_response_key" @@ -402,34 +488,54 @@ def _check_engine_num_and_tp_size(self): "Please ensure that `engine_num * tensor_parallel_size` can be divided by `gpu_per_node` when `node_num > 1`." 
) - def _set_repeat_times(self): - if st.session_state["algorithm_type"] == AlgorithmType.OPMD.value or st.session_state[ - "adv_estimator" - ] in [ - AdvantageEstimator.GRPO.value, - AdvantageEstimator.RLOO.value, - ]: + def _set_repeat_times(self): # TODO + grouped_adv_algorithms = [ + AlgorithmType.GRPO.value, + AlgorithmType.OPMD.value, # TODO: may add rloo + ] + if st.session_state["algorithm_type"] in grouped_adv_algorithms: min_repeat_times = 2 + st.session_state["repeat_times"] = st.session_state["_grouped_adv_repeat_times"] else: min_repeat_times = 1 - if st.session_state["repeat_times"] < min_repeat_times: - st.session_state["repeat_times"] = min_repeat_times + st.session_state["repeat_times"] = st.session_state["_not_grouped_adv_repeat_times"] + + def on_change(): + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + st.session_state["_grouped_adv_repeat_times"] = st.session_state["repeat_times"] + else: + st.session_state["_not_grouped_adv_repeat_times"] = st.session_state["repeat_times"] + st.number_input( "Repeat Times", key="repeat_times", min_value=min_repeat_times, help="`repeat_times` is used to set how many experiences each task can generate, " "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", + on_change=on_change, ) def _set_sync_method(self): + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["sync_method"] = SyncMethod.CHECKPOINT.value + disabled = True + else: + st.session_state["sync_method"] = st.session_state["_not_dpo_sync_method"] + disabled = False + + def on_change(): + if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + st.session_state["_not_dpo_sync_method"] = st.session_state["sync_method"] + st.selectbox( "Sync Method", - ["online", "offline"], + [sync_method.value for sync_method in SyncMethod], key="sync_method", - help="""`online`: the explorer and trainer sync model weights once every `sync_iteration_interval` steps. + help="""`nccl`: the explorer and trainer sync model weights once every `sync_iteration_interval` steps. 
-`offline`: the trainer saves the model checkpoint, and the explorer loads it at `sync_iteration_interval`.""", +`checkpoint`: the trainer saves the model checkpoint, and the explorer loads it at `sync_iteration_interval`.""", + disabled=disabled, + on_change=on_change, ) def _set_sync_iteration_interval(self): @@ -440,6 +546,14 @@ def _set_sync_iteration_interval(self): help="""The iteration interval at which the `explorer` and `trainer` synchronize model weight.""", ) + def _set_sync_timeout(self): + st.number_input( + "Sync Timeout", + key="sync_timeout", + min_value=1, + help="The timeout value for the synchronization operation.", + ) + def _set_runner_num(self): st.number_input("Runner Num", key="runner_num", min_value=1) @@ -488,8 +602,14 @@ def _set_trainer_type(self): def _set_algorithm_type(self): st.selectbox( "Algorithm Type", - [AlgorithmType.PPO.value, AlgorithmType.DPO.value, AlgorithmType.OPMD.value], + [ + AlgorithmType.PPO.value, + AlgorithmType.GRPO.value, + AlgorithmType.DPO.value, + AlgorithmType.OPMD.value, + ], key="algorithm_type", + on_change=self._set_adv_estimator, ) def _set_sft_warmup_iteration(self): @@ -510,18 +630,25 @@ def _set_training_args(self): key="training_args", ) - def _set_save_freq(self): - if st.session_state["sync_method"] == "online": - freeze_save_freq = False + def _set_save_interval(self): + if st.session_state["sync_method"] == SyncMethod.NCCL.value: + st.session_state["save_interval"] = st.session_state["_nccl_save_interval"] + freeze_save_interval = False else: - st.session_state["save_freq"] = st.session_state["sync_iteration_interval"] - freeze_save_freq = True + st.session_state["save_interval"] = st.session_state["sync_iteration_interval"] + freeze_save_interval = True + + def on_change(): + if st.session_state["sync_method"] == SyncMethod.NCCL.value: + st.session_state["_nccl_save_interval"] = st.session_state["save_interval"] + st.number_input( - "Save Freq", - key="save_freq", + "Save Interval", + key="save_interval", min_value=1, - help="Set to `sync_iteration_interval` when `sync_method` is `offline`", - disabled=freeze_save_freq, + help="Set to `sync_iteration_interval` when `sync_method` is `checkpoint`", + disabled=freeze_save_interval, + on_change=on_change, ) def _set_training_strategy(self): @@ -579,11 +706,16 @@ def _set_lam(self): st.number_input("Lambda", key="lam") def _set_adv_estimator(self): - st.selectbox( - "Advantage Estimator", - [member.value for member in AdvantageEstimator], - key="adv_estimator", - ) + if st.session_state["algorithm_type"] == AlgorithmType.PPO.value: + st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value + elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value: + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value: + st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value + else: # TODO: add more algorithms + pass def _set_norm_adv_by_std_in_grpo(self): st.checkbox("Norm Adv by Std in GRPO", key="norm_adv_by_std_in_grpo") @@ -607,17 +739,27 @@ def _set_target_kl(self): st.number_input("Target KL", key="target_kl", format="%.1e") def _set_actor_ppo_micro_batch_size_per_gpu(self): + st.session_state["actor_ppo_micro_batch_size_per_gpu"] = min( + st.session_state["actor_ppo_micro_batch_size_per_gpu"], + 
st.session_state["_train_batch_size_per_gpu"], + ) st.number_input( "Micro Batch Size Per GPU for Actor", key="actor_ppo_micro_batch_size_per_gpu", min_value=1, + max_value=st.session_state["_train_batch_size_per_gpu"], ) def _set_ref_log_prob_micro_batch_size_per_gpu(self): + st.session_state["ref_log_prob_micro_batch_size_per_gpu"] = min( + st.session_state["ref_log_prob_micro_batch_size_per_gpu"], + st.session_state["_train_batch_size_per_gpu"], + ) st.number_input( "Micro Batch Size Per GPU for Ref", key="ref_log_prob_micro_batch_size_per_gpu", min_value=1, + max_value=st.session_state["_train_batch_size_per_gpu"], ) def _set_actor_ulysses_sequence_parallel_size(self): @@ -712,10 +854,15 @@ def _set_actor_checkpoint(self): ) def _set_critic_ppo_micro_batch_size_per_gpu(self): + st.session_state["critic_ppo_micro_batch_size_per_gpu"] = min( + st.session_state["critic_ppo_micro_batch_size_per_gpu"], + st.session_state["_train_batch_size_per_gpu"], + ) st.number_input( "Micro Batch Size Per GPU for Critic", key="critic_ppo_micro_batch_size_per_gpu", min_value=1, + max_value=st.session_state["_train_batch_size_per_gpu"], ) def _set_critic_ulysses_sequence_parallel_size(self): @@ -766,21 +913,12 @@ def _set_critic_cliprange_value(self): max_value=1.0, ) - def _set_training_mode(self): - st.selectbox("Training Mode", ["PPO", "GRPO", "DPO", "OPMD"], key="training_mode") - - if st.session_state["training_mode"] == "PPO": - st.session_state["algorithm_type"] = AlgorithmType.PPO.value - st.session_state["adv_estimator"] = "gae" - elif st.session_state["training_mode"] == "GRPO": - st.session_state["algorithm_type"] = AlgorithmType.PPO.value - st.session_state["adv_estimator"] = "grpo" - elif st.session_state["training_mode"] == "DPO": - st.session_state["algorithm_type"] = AlgorithmType.DPO.value - st.session_state["adv_estimator"] = "grpo" - elif st.session_state["training_mode"] == "OPMD": - st.session_state["algorithm_type"] = AlgorithmType.OPMD.value - st.session_state["adv_estimator"] = "grpo" + def _set_critic_checkpoint(self): + st.multiselect( + "Checkpoint", + ["model", "hf_model", "optimizer", "extra"], + key="critic_checkpoint", + ) def _set_configs_with_st_columns( self, config_names: List[str], columns_config: List[int] = None @@ -802,7 +940,9 @@ def beginner_mode(self): self._set_dataset_path() - self._set_configs_with_st_columns(["training_mode", "sft_warmup_iteration", "monitor_type"]) + self._set_configs_with_st_columns( + ["algorithm_type", "sft_warmup_iteration", "monitor_type"] + ) if st.session_state["sft_warmup_iteration"] > 0: self._set_sft_warmup_dataset_path() @@ -813,11 +953,14 @@ def beginner_mode(self): self._check_engine_num_and_tp_size() self._set_configs_with_st_columns( - ["total_epochs", "task_num_per_batch", "max_prompt_tokens", "max_response_tokens"] + ["total_epochs", "train_batch_size", "max_prompt_tokens", "max_response_tokens"] ) - self._check_task_num_per_batch() + self._check_train_batch_size() - self._set_dataset_args() + if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + self._set_dataset_args() + else: + self._set_dpo_dataset_kwargs() if st.session_state["sft_warmup_iteration"] > 0: self._set_sft_warmup_dataset_args() @@ -826,7 +969,9 @@ def beginner_mode(self): ["default_workflow_type", "default_reward_fn_type", "repeat_times"] ) - self._set_configs_with_st_columns(["sync_iteration_interval", "eval_interval", "save_freq"]) + self._set_configs_with_st_columns( + ["sync_iteration_interval", "eval_interval", "save_interval"] + ) 
self._set_actor_use_kl_loss() if st.session_state["actor_use_kl_loss"]: @@ -840,7 +985,9 @@ def beginner_mode(self): ] ) - use_critic = st.session_state["adv_estimator"] == "gae" # TODO: may apply to expert mode + use_critic = ( + st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value + ) # TODO: may apply to expert mode if use_critic: self._set_configs_with_st_columns(["critic_ppo_micro_batch_size_per_gpu", "critic_lr"]) @@ -856,12 +1003,15 @@ def _expert_model_part(self): self._set_configs_with_st_columns(["max_prompt_tokens", "max_response_tokens"]) def _expert_buffer_part(self): - self._set_configs_with_st_columns(["total_epochs", "task_num_per_batch"]) - self._check_task_num_per_batch() + self._set_configs_with_st_columns(["total_epochs", "train_batch_size"]) + self._check_train_batch_size() self._set_dataset_path() - self._set_dataset_args() + if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + self._set_dataset_args() + else: + self._set_dpo_dataset_kwargs() self._set_configs_with_st_columns( ["default_workflow_type", "default_reward_fn_type", "storage_type"] @@ -869,8 +1019,6 @@ def _expert_buffer_part(self): self.buffer_advanced_tab = st.expander("Advanced Config") with self.buffer_advanced_tab: - self._set_db_url() - self._set_configs_with_st_columns(["max_retry_times", "max_retry_interval"]) self._set_sft_warmup_dataset_path() @@ -882,22 +1030,22 @@ def _expert_connector_part(self): ) self._check_engine_num_and_tp_size() - self._set_configs_with_st_columns(["sync_method", "sync_iteration_interval"]) + self._set_configs_with_st_columns( + ["sync_method", "sync_iteration_interval", "sync_timeout"] + ) with st.expander("Advanced Config"): self._set_configs_with_st_columns( ["runner_num", "max_pending_requests", "max_waiting_steps", "dtype"] ) - self._set_configs_with_st_columns( - ["backend", "temperature", "top_p", "top_k", "seed", "logprobs"] - ) + self._set_configs_with_st_columns(["backend", "temperature", "seed", "logprobs"]) self._set_configs_with_st_columns(["enable_prefix_caching", "enforce_eager"]) def _expert_trainer_part(self): - self._set_configs_with_st_columns( - ["trainer_type", "algorithm_type", "sft_warmup_iteration", "eval_interval"] + self._set_configs_with_st_columns( # TODO: may add `trainer_type` + ["algorithm_type", "sft_warmup_iteration", "eval_interval", "save_interval"] ) self._check_sft_warmup_dataset_path() @@ -917,7 +1065,7 @@ def _expert_verl_trainer_part(self): st.subheader("RL Training Config") self._set_training_args() - self._set_configs_with_st_columns(["save_freq", "training_strategy", "resume_mode"]) + self._set_configs_with_st_columns(["training_strategy", "resume_mode"]) if st.session_state["training_strategy"] == "fsdp": self._set_configs_with_st_columns(["param_offload", "optimizer_offload"]) @@ -938,7 +1086,7 @@ def _expert_verl_trainer_part(self): with rl_algorithm_tab: st.subheader("RL Algorithm Config") - self._set_configs_with_st_columns(["gamma", "lam", "adv_estimator"]) + self._set_configs_with_st_columns(["gamma", "lam"]) self._set_configs_with_st_columns(["norm_adv_by_std_in_grpo", "use_kl_in_reward"]) self._set_configs_with_st_columns(["kl_penalty", "kl_ctrl_type", "kl_ctrl_coef"]) self._set_configs_with_st_columns(["horizon", "target_kl"]) @@ -983,6 +1131,7 @@ def _expert_verl_trainer_part(self): ) self._set_configs_with_st_columns(["critic_grad_clip", "critic_cliprange_value"]) + self._set_critic_checkpoint() def expert_mode(self): model_tab, buffer_tab, connector_tab, trainer_tab = st.tabs( @@ -1036,7 
+1185,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "prompt_key": "placeholder", "max_prompt_length": st.session_state["max_prompt_tokens"], "max_response_length": st.session_state["max_response_tokens"], - "train_batch_size": st.session_state["task_num_per_batch"] + "train_batch_size": st.session_state["train_batch_size"] * st.session_state["repeat_times"], "val_batch_size": None, "return_raw_input_ids": False, @@ -1057,7 +1206,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node }, "actor": { "strategy": st.session_state["training_strategy"], - "ppo_mini_batch_size": st.session_state["task_num_per_batch"], + "ppo_mini_batch_size": st.session_state["train_batch_size"], "ppo_micro_batch_size_per_gpu": st.session_state[ "actor_ppo_micro_batch_size_per_gpu" ], @@ -1084,7 +1233,6 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node else st.session_state["total_training_steps"], }, "fsdp_config": copy.deepcopy(fsdp_config), - "alg_type": st.session_state["algorithm_type"], "tau": st.session_state["actor_tau"], "opmd_baseline": st.session_state["actor_opmd_baseline"], "use_uid": st.session_state["actor_use_uid"], @@ -1146,7 +1294,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "use_remove_padding": use_remove_padding, "fsdp_config": copy.deepcopy(fsdp_config), }, - "ppo_mini_batch_size": st.session_state["task_num_per_batch"], + "ppo_mini_batch_size": st.session_state["train_batch_size"], "ppo_micro_batch_size_per_gpu": st.session_state[ "critic_ppo_micro_batch_size_per_gpu" ], @@ -1163,6 +1311,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "shuffle": False, "grad_clip": st.session_state["critic_grad_clip"], "cliprange_value": st.session_state["critic_cliprange_value"], + "checkpoint": {"contents": st.session_state["critic_checkpoint"]}, }, "reward_model": { "enable": False, @@ -1203,7 +1352,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node "val_generations_to_log_to_wandb": 0, "nnodes": trainer_nnodes, "n_gpus_per_node": trainer_n_gpus_per_node, - "save_freq": st.session_state["save_freq"], + "save_freq": st.session_state["save_interval"], "resume_mode": st.session_state["resume_mode"], "resume_from_path": st.session_state["resume_from_path"], "test_freq": 100, @@ -1235,11 +1384,17 @@ def generate_config(self): else: trainer_n_gpus_per_node = st.session_state["gpu_per_node"] - db_url = ( - st.session_state["db_url"] - if st.session_state["db_url"].strip() - else f"sqlite:///{os.path.join(st.session_state['checkpoint_path'], '.cache', st.session_state['project'], st.session_state['exp_name'])}/data.db" - ) + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + train_dataset_path = ( + st.session_state["train_dataset_path"].strip() + if st.session_state["train_dataset_path"].strip() + else st.session_state["dataset_path"].strip() + ) + else: # not dpo algorithms + train_dataset_path = st.session_state["train_dataset_path"].strip() + if not train_dataset_path and st.session_state["storage_type"] == StorageType.SQL.value: + train_dataset_path = f"sqlite:///{os.path.join(st.session_state['checkpoint_path'], '.cache', st.session_state['project'], st.session_state['exp_name'])}/data.db" + sft_storage_type = ( StorageType.SQL.value if "://" in st.session_state["sft_warmup_dataset_path"] @@ -1268,7 +1423,7 @@ def generate_config(self): config = { "data": { "total_epochs": 
st.session_state["total_epochs"], - "batch_size": st.session_state["task_num_per_batch"], + "batch_size": st.session_state["train_batch_size"], "dataset_path": st.session_state["dataset_path"], "default_workflow_type": st.session_state["default_workflow_type"], "default_reward_fn_type": st.session_state["default_reward_fn_type"], @@ -1290,22 +1445,26 @@ def generate_config(self): "gpu_per_node": st.session_state["gpu_per_node"], }, "buffer": { - "db_url": db_url, - "read_batch_size": st.session_state["task_num_per_batch"] - * st.session_state["repeat_times"], "max_retry_times": st.session_state["max_retry_times"], "max_retry_interval": st.session_state["max_retry_interval"], "train_dataset": { "name": "experience_buffer", # TODO "storage_type": st.session_state["storage_type"], "algorithm_type": st.session_state["algorithm_type"], - "path": db_url, + "path": train_dataset_path, }, "sft_warmup_dataset": { "name": "sft_warmup_dataset", "storage_type": sft_storage_type, "algorithm_type": AlgorithmType.SFT.value, "path": st.session_state["sft_warmup_dataset_path"], + "kwargs": { + "train_split": st.session_state["sft_warmup_train_split"], + "prompt_type": st.session_state["sft_warmup_prompt_type"], + "messages_key": st.session_state["sft_warmup_messages_key"], + "prompt_key": st.session_state["sft_warmup_prompt_key"], + "response_key": st.session_state["sft_warmup_response_key"], + }, }, }, "explorer": { @@ -1317,8 +1476,6 @@ def generate_config(self): "enforce_eager": st.session_state["enforce_eager"], "dtype": st.session_state["dtype"], "temperature": st.session_state["temperature"], - "top_p": st.session_state["top_p"], - "top_k": st.session_state["top_k"], "seed": st.session_state["seed"], "logprobs": st.session_state["logprobs"], "repeat_times": st.session_state["repeat_times"], @@ -1329,6 +1486,7 @@ def generate_config(self): "synchronizer": { "sync_method": st.session_state["sync_method"], "sync_iteration_interval": st.session_state["sync_iteration_interval"], + "sync_timeout": st.session_state["sync_timeout"], }, "trainer": { "trainer_type": st.session_state["trainer_type"], @@ -1336,6 +1494,7 @@ def generate_config(self): "trainer_config": trainer_config, "sft_warmup_iteration": st.session_state["sft_warmup_iteration"], "eval_interval": st.session_state["eval_interval"], + "save_interval": st.session_state["save_interval"], }, "monitor": { "project": st.session_state["project"], @@ -1343,6 +1502,15 @@ def generate_config(self): "monitor_type": st.session_state["monitor_type"], }, } + + if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + config["buffer"]["train_dataset"]["kwargs"] = { + "dpo_dataset_train_split": st.session_state["dpo_dataset_train_split"], + "dpo_dataset_prompt_type": st.session_state["dpo_dataset_prompt_type"], + "dpo_dataset_prompt_key": st.session_state["dpo_dataset_prompt_key"], + "dpo_dataset_chosen_key": st.session_state["dpo_dataset_chosen_key"], + "dpo_dataset_rejected_key": st.session_state["dpo_dataset_rejected_key"], + } st.header("Generated Config File") st.subheader("Config File") yaml_config = yaml.dump(config, allow_unicode=True, sort_keys=False) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index a0c136faa1..4749a57b5c 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -13,7 +13,7 @@ from trinity.buffer import get_buffer_reader from trinity.common.config import Config -from trinity.common.constants import AlgorithmType, ReadStrategy +from trinity.common.constants import AlgorithmType, 
ReadStrategy, SyncMethod from trinity.common.experience import Experiences from trinity.utils.log import get_logger @@ -110,7 +110,7 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple def sync_weight(self) -> None: """Sync the model weight.""" - if self.config.synchronizer.sync_method == "online": + if self.config.synchronizer.sync_method == SyncMethod.NCCL: self.engine.sync_weight() def flush_log(self, step: int) -> None: diff --git a/trinity/trainer/verl/core_algos.py b/trinity/trainer/verl/core_algos.py index 98aeb62dc2..20cffc9962 100644 --- a/trinity/trainer/verl/core_algos.py +++ b/trinity/trainer/verl/core_algos.py @@ -24,6 +24,8 @@ import torch.nn.functional as F import verl.utils.torch_functional as verl_F +from trinity.common.constants import AlgorithmType + class KLController(ABC): @abstractmethod @@ -353,20 +355,9 @@ def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): def compute_policy_loss(old_log_prob, log_prob, eos_mask, **kwargs): """Compute policy loss for PPO / OPMD / pairwise OPMD""" - alg_type = kwargs.get("alg_type", "ppo") - - if alg_type == "ppo": - advantages = kwargs.get("advantages") - cliprange = kwargs.get("cliprange") - return compute_policy_loss_ppo( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - eos_mask=eos_mask, - cliprange=cliprange, - ) + algorithm_type: AlgorithmType = kwargs.get("algorithm_type", AlgorithmType.PPO) - elif alg_type == "opmd": + if algorithm_type == AlgorithmType.OPMD: advantages = kwargs.get("advantages") tau = kwargs.get("tau") return compute_policy_loss_opmd( @@ -377,7 +368,7 @@ def compute_policy_loss(old_log_prob, log_prob, eos_mask, **kwargs): tau=tau, ) - elif alg_type == "pairwise_opmd": + elif algorithm_type == AlgorithmType.PAIRWISE_OPMD: token_level_scores = kwargs.get("token_level_scores") index = kwargs.get("index") tau = kwargs.get("tau") @@ -390,8 +381,19 @@ def compute_policy_loss(old_log_prob, log_prob, eos_mask, **kwargs): tau=tau, ) + elif algorithm_type.is_rft(): + advantages = kwargs.get("advantages") + cliprange = kwargs.get("cliprange") + return compute_policy_loss_ppo( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + eos_mask=eos_mask, + cliprange=cliprange, + ) + else: - raise NotImplementedError(f"Get invalid alg_type '{alg_type}'.") + raise NotImplementedError(f"Get invalid algorithm_type '{algorithm_type}'.") def compute_policy_loss_dpo( diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 544f4b60f0..246cd1f21c 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -309,7 +309,7 @@ def update_policy(self, data: DataProto): # noqa: C901 "temperature" ] # temperature must be in the data.meta_info to avoid slient error - alg_type = self.config.get("alg_type", "ppo") + algorithm_type: AlgorithmType = self.config.get("algorithm_type", AlgorithmType.PPO) if self.algorithm_type.is_rft(): select_keys = [ "responses", @@ -323,7 +323,7 @@ def update_policy(self, data: DataProto): # noqa: C901 if self.config.use_kl_loss: select_keys.append("ref_log_prob") - if alg_type == "pairwise_opmd": + if algorithm_type == AlgorithmType.PAIRWISE_OPMD: select_keys.append("token_level_scores") elif self.algorithm_type.is_dpo(): select_keys = [ @@ -349,15 +349,15 @@ def update_policy(self, data: DataProto): # noqa: C901 # Split to make minibatch iterator for updating the actor # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 - if has_multi_modal_inputs or ((alg_type == "pairwise_opmd") and use_uid): - # TODO: for now, we treat alg_type == "pairwise_opmd" in the same way that + if has_multi_modal_inputs or ((algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid): + # TODO: for now, we treat algorithm_type == AlgorithmType.PAIRWISE_OPMD in the same way that # has_multi_modal_inputs was treated originally (to handle non_tensor_select_keys); # need to double check if this is the best approach. num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size non_tensor_select_keys = [] if has_multi_modal_inputs: non_tensor_select_keys.append("multi_modal_inputs") - if (alg_type == "pairwise_opmd") and use_uid: + if (algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid: non_tensor_select_keys.append("uid") dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) else: @@ -373,7 +373,9 @@ def update_policy(self, data: DataProto): # noqa: C901 for batch_idx, data in enumerate(dataloader): # split batch into micro_batches mini_batch = data - if has_multi_modal_inputs or ((alg_type == "pairwise_opmd") and use_uid): + if has_multi_modal_inputs or ( + (algorithm_type == AlgorithmType.PAIRWISE_OPMD) and use_uid + ): self.gradient_accumulation = ( self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu ) @@ -456,7 +458,7 @@ def update_policy(self, data: DataProto): # noqa: C901 tau = self.config.get("tau", 1.0) token_level_scores = None index = None - if alg_type == "pairwise_opmd": + if algorithm_type == AlgorithmType.PAIRWISE_OPMD: token_level_scores = data["token_level_scores"] if use_uid: index = data["uid"] @@ -470,7 +472,7 @@ def update_policy(self, data: DataProto): # noqa: C901 old_log_prob=old_log_prob, log_prob=log_prob, eos_mask=response_mask, - alg_type=alg_type, + algorithm_type=algorithm_type, advantages=advantages, cliprange=clip_ratio, # for opmd / pairwise_opmd diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 0273fd0544..57e9849f9d 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -50,7 +50,7 @@ from verl.utils.model import compute_position_id_with_mask from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager -from trinity.common.constants import AlgorithmType +from trinity.common.constants import AlgorithmType, SyncMethod from trinity.utils.distributed import init_process_group, is_ipv6_address logger = logging.getLogger(__file__) @@ -560,7 +560,7 @@ def init_model(self): def setup_weight_sync_group(self): if ( hasattr(self.config, "synchronizer") - and getattr(self.config.synchronizer, "sync_method", None) == "online" + and getattr(self.config.synchronizer, "sync_method", None) == SyncMethod.NCCL ): model = self.actor_module_fsdp self.named_modules = [] @@ -597,10 +597,12 @@ def setup_weight_sync_group(self): init_method = f"tcp://[{master_address}]:{master_port}" else: init_method = f"tcp://{master_address}:{master_port}" + timeout = self.config.synchronizer.sync_timeout self._model_update_group = init_process_group( backend=backend, init_method=init_method, + timeout=timeout, world_size=world_size, rank=0, group_name=group_name, diff --git a/trinity/trainer/verl/ray_trainer.py b/trinity/trainer/verl/ray_trainer.py index fc85c171f1..7073319db0 100644 --- a/trinity/trainer/verl/ray_trainer.py +++ b/trinity/trainer/verl/ray_trainer.py @@ -56,6 +56,7 @@ from verl.utils.torch_functional 
import masked_mean from verl.utils.tracking import ValidationGenerationsLogger +from trinity.common.constants import AlgorithmType from trinity.trainer.verl import core_algos WorkerType = Type[Worker] @@ -208,23 +209,9 @@ def compute_response_mask(data: DataProto): def compute_advantage(data: DataProto, **kwargs): """Extend verl's original compute_advantage with OPMD""" - alg_type = kwargs.get("alg_type", "ppo") + algorithm_type: AlgorithmType = kwargs.get("algorithm_type", AlgorithmType.PPO) - if alg_type == "ppo": - adv_estimator = kwargs.get("adv_estimator", None) - gamma = kwargs.get("gamma", 1.0) - lam = kwargs.get("lam", 1.0) - num_repeat = kwargs.get("num_repeat", 1) - - return compute_advantage_ppo( - data=data, - adv_estimator=adv_estimator, - gamma=gamma, - lam=lam, - num_repeat=num_repeat, - ) - - elif alg_type == "opmd": + if algorithm_type == AlgorithmType.OPMD: tau = kwargs.get("tau", 1.0) opmd_baseline = kwargs.get("opmd_baseline", "mean") @@ -234,13 +221,27 @@ def compute_advantage(data: DataProto, **kwargs): opmd_baseline=opmd_baseline, ) - elif alg_type == "pairwise_opmd": + elif algorithm_type == AlgorithmType.PAIRWISE_OPMD: data.batch["advantages"] = None data.batch["returns"] = None return data + elif algorithm_type.is_rft(): + adv_estimator = kwargs.get("adv_estimator", None) + gamma = kwargs.get("gamma", 1.0) + lam = kwargs.get("lam", 1.0) + num_repeat = kwargs.get("num_repeat", 1) + + return compute_advantage_ppo( + data=data, + adv_estimator=adv_estimator, + gamma=gamma, + lam=lam, + num_repeat=num_repeat, + ) + else: - raise ValueError(f"alg_type must be 'ppo' or 'opmd', get '{alg_type}'.") + raise ValueError(f"Get invalid algorithm_type '{algorithm_type}'.") def compute_advantage_opmd(data: DataProto, tau=1.0, opmd_baseline="mean"): @@ -367,7 +368,10 @@ def __init__( else: self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.0) - if self.config.actor_rollout_ref.actor.get("alg_type", "ppo") != "ppo": + if ( + self.config.actor_rollout_ref.actor.get("algorithm_type", AlgorithmType.PPO) + != AlgorithmType.PPO + ): self.use_critic = False elif self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: self.use_critic = True @@ -1082,14 +1086,16 @@ def fit(self): # noqa: C901 batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] # compute advantages, executed on the driver process - alg_type = self.config.actor_rollout_ref.actor.get("alg_type", "ppo") + algorithm_type = self.config.actor_rollout_ref.actor.get( + "algorithm_type", AlgorithmType.PPO + ) tau = self.config.actor_rollout_ref.actor.get("tau", 1.0) opmd_baseline = self.config.actor_rollout_ref.actor.get( "opmd_baseline", "mean" ) batch = compute_advantage( batch, - alg_type=alg_type, + algorithm_type=algorithm_type, adv_estimator=self.config.algorithm.adv_estimator, gamma=self.config.algorithm.gamma, lam=self.config.algorithm.lam, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index ea7c632b35..ffb1d2be12 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -40,7 +40,6 @@ class _InternalDataLoader: def __init__(self, config): self.config = config - self.length = config.trainer.steps_per_epoch self.dataset = None self.index = 0 self.experience_buffer = None @@ -383,12 +382,14 @@ def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]: # compute advantages, executed on the driver process kwargs = {} - alg_type = self.config.actor_rollout_ref.actor.get("alg_type", "ppo") - if alg_type == "opmd": + 
algorithm_type = self.config.actor_rollout_ref.actor.get(
+            "algorithm_type", AlgorithmType.PPO
+        )
+        if algorithm_type == AlgorithmType.OPMD:
             tau = self.config.actor_rollout_ref.actor.get("tau", 0.0)
             opmd_baseline = self.config.actor_rollout_ref.actor.get("opmd_baseline", "mean")
             kwargs = {
-                "alg_type": alg_type,
+                "algorithm_type": algorithm_type,
                 "tau": tau,
                 "opmd_baseline": opmd_baseline,
             }
@@ -428,7 +429,6 @@ def train_rft_iteration(self, experiences: Experiences) -> Tuple[bool, int]:
         # val_metrics: dict = self._validate()
         # metrics.update(val_metrics)
-        # TODO save_checkpoint too frequently, a method for updating parameters online needs to be added
         if (
             self.config.trainer.save_freq > 0
             and self.global_steps % self.config.trainer.save_freq == 0
diff --git a/trinity/utils/distributed.py b/trinity/utils/distributed.py
index 5111b41449..35933168fa 100644
--- a/trinity/utils/distributed.py
+++ b/trinity/utils/distributed.py
@@ -27,7 +27,7 @@ def is_ipv6_address(ip_str: str) -> bool:
 def init_process_group(
     backend: Union[str, Backend] = None,
     init_method: Optional[str] = None,
-    timeout: Optional[timedelta] = None,
+    timeout: Optional[float] = None,
     world_size: int = -1,
     rank: int = -1,
     store: Optional[Store] = None,
@@ -49,6 +49,8 @@
     if timeout is None:
         timeout = default_pg_timeout
+    else:
+        timeout = timedelta(seconds=timeout)
     # backward compatible API
     if store is None:
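
With this change, `init_process_group` receives the synchronization timeout as a plain number of seconds (the `synchronizer.sync_timeout` value, `1200` by default) and converts it into the `timedelta` that `torch.distributed` expects. The sketch below shows only that conversion, using a hypothetical `resolve_timeout` helper that is not part of the codebase:

```python
from datetime import timedelta
from typing import Optional


def resolve_timeout(
    timeout: Optional[float],
    default: timedelta = timedelta(minutes=30),
) -> timedelta:
    # A plain float of seconds (e.g. `sync_timeout` from the YAML config)
    # becomes a timedelta; None falls back to a default process-group timeout.
    if timeout is None:
        return default
    return timedelta(seconds=timeout)


# The documented default `sync_timeout: 1200` corresponds to 20 minutes.
assert resolve_timeout(1200) == timedelta(minutes=20)
assert resolve_timeout(None) == timedelta(minutes=30)
```

Keeping the configuration value as a float means the YAML files and the Streamlit config generator can pass `sync_timeout` straight through, with the conversion happening once at process-group setup.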
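
The same refactor runs through `core_algos.compute_policy_loss`, `ray_trainer.compute_advantage`, and `dp_actor.update_policy`: the former `alg_type` strings are replaced by the `AlgorithmType` enum, the OPMD variants are matched explicitly, and every other RFT algorithm falls through to the PPO-style path. The sketch below illustrates that dispatch with a simplified stand-in enum; the real `trinity.common.constants.AlgorithmType` may define additional members and helpers such as `is_dpo()`:

```python
from enum import Enum


class AlgorithmType(Enum):
    # Simplified stand-in for trinity.common.constants.AlgorithmType,
    # used only to illustrate the dispatch pattern.
    PPO = "ppo"
    GRPO = "grpo"
    OPMD = "opmd"
    PAIRWISE_OPMD = "pairwise_opmd"
    DPO = "dpo"
    SFT = "sft"

    def is_rft(self) -> bool:
        # RFT-style algorithms share the PPO-style loss/advantage code path.
        return self in {
            AlgorithmType.PPO,
            AlgorithmType.GRPO,
            AlgorithmType.OPMD,
            AlgorithmType.PAIRWISE_OPMD,
        }


def select_policy_loss_branch(algorithm_type: AlgorithmType) -> str:
    # Same ordering as compute_policy_loss: OPMD variants are matched first,
    # every remaining RFT algorithm falls back to the clipped PPO loss,
    # and anything else (e.g. DPO, which has its own loss function) is rejected.
    if algorithm_type == AlgorithmType.OPMD:
        return "opmd"
    if algorithm_type == AlgorithmType.PAIRWISE_OPMD:
        return "pairwise_opmd"
    if algorithm_type.is_rft():
        return "ppo"
    raise NotImplementedError(f"Got invalid algorithm_type '{algorithm_type}'.")


assert select_policy_loss_branch(AlgorithmType.GRPO) == "ppo"
assert select_policy_loss_branch(AlgorithmType.PAIRWISE_OPMD) == "pairwise_opmd"
```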