diff --git a/README.md b/README.md index 560632fc0e..1f6f5bdd45 100644 --- a/README.md +++ b/README.md @@ -200,7 +200,7 @@ For more details about dataset downloading, please refer to [Huggingface](https: ### Step 3: configurations -You may customize the configurations in `scripts/config/{config_name}.yaml`and `scripts/config/{train_config_name}.yaml`. For example, the model and dataset are specified as: +You may customize the configurations in [`examples`](examples/). For example, the model and dataset are specified as: ```yaml model: @@ -208,12 +208,9 @@ model: data: dataset_path: $DATASET_PATH/{dataset_name} - -trainer: - trainer_config_path: scripts/config/{train_config_name}.yaml ``` -You may use the default configurations located in the directory `scripts/config`. Please refer to `examples` for more details. +Please refer to [`examples`](examples/) for more details. @@ -252,12 +249,12 @@ trinity run --config For example, below is the command for fine-tuning Qwen-2.5-1B-Instruct on GSM8k dataset using GRPO algorithm: ```shell -trinity run --config scripts/config/gsm8k.yaml +trinity run --config examples/grpo_gsm8k/gsm8k.yaml ``` -More example config files can be found in `scripts/config`. +More example config files can be found in `examples`. diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md index 3caa99f6f6..c277e9b116 100644 --- a/docs/sphinx_doc/source/main.md +++ b/docs/sphinx_doc/source/main.md @@ -180,7 +180,7 @@ For more details about dataset downloading, please refer to [Huggingface](https: ### Step 3: configurations -You may customize the configurations in `scripts/config/{config_name}.yaml`and `scripts/config/{train_config_name}.yaml`. For example, the model and dataset are specified as: +You may customize the configurations in [`examples`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/). For example, the model and dataset are specified as: ```yaml model: @@ -188,12 +188,9 @@ model: data: dataset_path: $DATASET_PATH/{dataset_name} - -trainer: - trainer_config_path: scripts/config/{train_config_name}.yaml ``` -You may use the default configurations located in the directory `scripts/config`. Please refer to `examples` for more details. +Please refer to [`examples`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/) for more details. @@ -232,12 +229,12 @@ trinity run --config For example, below is the command for fine-tuning Qwen-2.5-1B-Instruct on GSM8k dataset using GRPO algorithm: ```shell -trinity run --config scripts/config/gsm8k.yaml +trinity run --config examples/grpo_gsm8k/gsm8k.yaml ``` -More example config files can be found in `scripts/config`. +More example config files can be found in `examples`. diff --git a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md index c3067bb2a1..fb9a82873b 100644 --- a/docs/sphinx_doc/source/tutorial/example_data_functionalities.md +++ b/docs/sphinx_doc/source/tutorial/example_data_functionalities.md @@ -133,7 +133,7 @@ And you can set the `clean_strategy` to 'iterative' to get a better dataset. -All config items in the `data` section can be found [here](trinity_configs.md). A prepared config file for this example of GSM-8K can be found in [the config file of gsm8k](../../../../scripts/config/gsm8k.yaml). +All config items in the `data` section can be found [here](trinity_configs.md). 
A prepared config file for this example of GSM-8K can be found in [the config file of gsm8k](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml). diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md index 5ae0ee824c..448cfd67fe 100644 --- a/docs/sphinx_doc/source/tutorial/example_dpo.md +++ b/docs/sphinx_doc/source/tutorial/example_dpo.md @@ -38,12 +38,12 @@ Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pa ### Configuration -We use the configurations in `scripts/config/dpo.yaml`and `scripts/config/train_dpo.yaml` for this experiment. Some important setups are listed in the following: +We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following: We run the experiment in a train mode, as there is no Explorer. To enable this mode, we config `mode` to `train` and set `sync_method` to `offline`. The value of `sync_iteration_interval` can be set as same of the value of `save_freq`. ```yaml -# scripts/config/dpo.yaml +# In dpo.yaml mode: train synchronizer: sync_method: 'offline' @@ -60,7 +60,7 @@ buffer: trainer: algorithm_type: dpo -# scripts/config/train_dpo.yaml +# In train_dpo.yaml actor_rollout_ref: actor: alg_type: dpo @@ -73,5 +73,5 @@ actor_rollout_ref: Run RFT process with the following command: ```shell -trinity run --config scripts/config/dpo.yaml +trinity run --config examples/dpo_humanlike/dpo.yaml ``` diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md index 28a580bc98..9002fffadd 100644 --- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md +++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md @@ -36,15 +36,15 @@ The task is described as an environment instead of a single prompt. ## Step 2: Config preparation and run the experiment -You can refer to `example_reasoning_basic` to setup the config and others. The default config files are `scripts/config/alfworld.yaml` and `scripts/config/webshop.yaml`, respectively. +You can refer to `example_reasoning_basic` to setup the config and others. The default config files are [`alfworld.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_alfworld/alfworld.yaml) and [`webshop.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_webshop/webshop.yaml), respectively. You may revise the configurations properly and run the experiment! 
```bash # For ALFworld env -trinity run --config scripts/config/alfworld.yaml +trinity run --config examples/grpo_alfworld/alfworld.yaml # For WebShop env -trinity run --config scripts/config/webshop.yaml +trinity run --config examples/grpo_webshop/webshop.yaml ``` ## Advance: How to build your own environment diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md index f9638ad94d..d278ed22bc 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md @@ -17,11 +17,11 @@ The algorithm design and analysis can be found in this [technical report](../../ To try out the OPMD algorithm: ```shell -trinity run --config scripts/config/gsm8k_opmd.yaml +trinity run --config examples/opmd_gsm8k/opmd_gsm8k.yaml ``` Note that in this config file, `sync_iteration_interval` is set to 10, i.e., the model weights of explorer and trainer are synchronized only once every 10 training steps, which leads to a challenging off-policy scenario (potentially with abrupt distribution shift during the RFT process). -Other configurations of particular interest are explained at the beginning of `scripts/config/train_gsm8k_opmd.yaml`. +Other configurations of particular interest are explained at the beginning of [`train_opmd_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/opmd_gsm8k/train_opmd_gsm8k.yaml). diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md index e2ef856faf..6893528ad4 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md @@ -48,15 +48,15 @@ synchronizer: ### Use GRPO or PPO Algorithm -We use the configurations in `scripts/config/gsm8k.yaml`and `scripts/config/train_gsm8k.yaml` for this experiment. Some important setups are listed in the following: +We use the configurations in [`gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/gsm8k.yaml) and [`train_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k/train_gsm8k.yaml) for this experiment. Some important setups are listed in the following: ```yaml -# scripts/config/gsm8k.yaml +# In gsm8k.yaml explorer: repeat_times: {number of rollouts for each task} -# scripts/config/train_gsm8k.yaml +# In train_gsm8k.yaml actor_rollout_ref: actor: use_kl_loss: True (fro GRPO) / False (for PPO) @@ -69,7 +69,7 @@ algorithm: Run the RFT process with the following command: ```bash -trinity run --config scripts/config/gsm8k.yaml +trinity run --config examples/grpo_gsm8k/gsm8k.yaml ``` @@ -79,14 +79,14 @@ trinity run --config scripts/config/gsm8k.yaml Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_iteration > 0` and prepare the SFT data to `buffer.train_dataset.path=$DATASET_PATH/{sft_data}`. 
```yaml -# Properly set the following configs in scripts/config/gsm8k.yaml +# Properly set the following configs in gsm8k.yaml buffer: sft_warmup_dataset: storage_type: file algorithm_type: sft path: <$DATASET_PATH/{sft_data}> kwargs: - prompt_type: # messages/plaintext + prompt_type: # messages/plaintext/chatpair prompt_key: response_key: trainer: @@ -95,5 +95,5 @@ trainer: The following command runs SFT and RFT in sequence: ```bash -trinity run --config scripts/config/gsm8k.yaml +trinity run --config examples/grpo_gsm8k/gsm8k.yaml ``` diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 02a24c7c0a..0c84e72b46 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -1,6 +1,6 @@ # Trinity-RFT Configuration -The following is the main config file for Trinity-RFT. Take `scripts/config/countdown.yaml` as an example. +The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example. ## Monitor @@ -165,7 +165,7 @@ synchronizer: trainer: trainer_type: 'verl' algorithm_type: ppo - trainer_config_path: 'scripts/config/train_countdown.yaml' + trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' sft_warmup_iteration: 0 eval_interval: 1000 ``` diff --git a/examples/dpo_humanlike/README.md b/examples/dpo_humanlike/README.md new file mode 100644 index 0000000000..2102ec14d3 --- /dev/null +++ b/examples/dpo_humanlike/README.md @@ -0,0 +1,7 @@ +# DPO on HumanLike Dataset + +This example shows the usage of DPO on the HumanLike dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_dpo.md). + +The config files are located in [`dpo.yaml`](dpo.yaml) and [`train_dpo.yaml`](train_dpo.yaml). diff --git a/scripts/config/dpo.yaml b/examples/dpo_humanlike/dpo.yaml similarity index 95% rename from scripts/config/dpo.yaml rename to examples/dpo_humanlike/dpo.yaml index 6579228bef..991c662138 100644 --- a/scripts/config/dpo.yaml +++ b/examples/dpo_humanlike/dpo.yaml @@ -53,7 +53,7 @@ synchronizer: trainer: trainer_type: 'verl' algorithm_type: dpo - trainer_config_path: 'scripts/config/train_dpo.yaml' + trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml' monitor: cache_root_dir: "" project: "dpo_example" diff --git a/scripts/config/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml similarity index 100% rename from scripts/config/train_dpo.yaml rename to examples/dpo_humanlike/train_dpo.yaml diff --git a/examples/grpo_alfworld/README.md b/examples/grpo_alfworld/README.md new file mode 100644 index 0000000000..8a1fa57ce0 --- /dev/null +++ b/examples/grpo_alfworld/README.md @@ -0,0 +1,7 @@ +# GRPO on ALFWorld Dataset + +This example shows the usage of GRPO on the ALFWorld dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_multi_turn.md). + +The config files are located in [`alfworld.yaml`](alfworld.yaml) and [`train_alfworld.yaml`](train_alfworld.yaml). 
diff --git a/scripts/config/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml similarity index 94% rename from scripts/config/alfworld.yaml rename to examples/grpo_alfworld/alfworld.yaml index 96ca477c72..100420acd8 100644 --- a/scripts/config/alfworld.yaml +++ b/examples/grpo_alfworld/alfworld.yaml @@ -49,7 +49,7 @@ synchronizer: trainer: trainer_type: 'verl' algorithm_type: ppo - trainer_config_path: 'scripts/config/train_alfworld.yaml' + trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml' monitor: cache_root_dir: "" project: "ALFWORLD" diff --git a/scripts/config/train_alfworld.yaml b/examples/grpo_alfworld/train_alfworld.yaml similarity index 100% rename from scripts/config/train_alfworld.yaml rename to examples/grpo_alfworld/train_alfworld.yaml diff --git a/examples/grpo_gsm8k/README.md b/examples/grpo_gsm8k/README.md new file mode 100644 index 0000000000..29caa5319a --- /dev/null +++ b/examples/grpo_gsm8k/README.md @@ -0,0 +1,7 @@ +# GRPO on GSM8K dataset + +This example shows the usage of GRPO on the GSM8K dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md). + +The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml). diff --git a/scripts/config/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml similarity index 96% rename from scripts/config/gsm8k.yaml rename to examples/grpo_gsm8k/gsm8k.yaml index c4bd8e05f2..e2ca38cd63 100644 --- a/scripts/config/gsm8k.yaml +++ b/examples/grpo_gsm8k/gsm8k.yaml @@ -67,7 +67,7 @@ synchronizer: trainer: trainer_type: 'verl' algorithm_type: ppo - trainer_config_path: 'scripts/config/train_gsm8k.yaml' + trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml' sft_warmup_iteration: 0 # Set to integer to enable sft warmup eval_interval: 50 monitor: diff --git a/scripts/config/train_gsm8k.yaml b/examples/grpo_gsm8k/train_gsm8k.yaml similarity index 100% rename from scripts/config/train_gsm8k.yaml rename to examples/grpo_gsm8k/train_gsm8k.yaml diff --git a/examples/grpo_math/README.md b/examples/grpo_math/README.md new file mode 100644 index 0000000000..649cc5272f --- /dev/null +++ b/examples/grpo_math/README.md @@ -0,0 +1,7 @@ +# Example: GRPO on MATH dataset + +This example shows the usage of GRPO on the MATH dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md). + +The config files are located in [`math.yaml`](math.yaml) and [`train_math.yaml`](train_math.yaml).
diff --git a/examples/grpo_math/math.yaml b/examples/grpo_math/math.yaml new file mode 100644 index 0000000000..d7468c1cb7 --- /dev/null +++ b/examples/grpo_math/math.yaml @@ -0,0 +1,63 @@ +data: + # basic info + dataset_path: /PATH/TO/DATASET/ + # dataset_config: + train_split: train + eval_split: test + format_config: + prompt_key: 'question' + response_key: 'gt_answer' + # db related + db_url: '' + # downstream loading related + total_epoch: 20 + batch_size: 288 + default_workflow_type: 'math_workflow' +model: + model_path: /PATH/TO/MODEL/ + max_prompt_tokens: 1024 + max_response_tokens: 3072 + checkpoint_path: /PATH/TO/CHECKPOINT/ + load_checkpoint: true +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + max_retry_times: 3 + max_retry_interval: 1 + train_dataset: + name: math_buffer + storage_type: queue + algorithm_type: ppo + path: 'sqlite:////math.db' +explorer: + engine_type: vllm_async + engine_num: 2 + runner_num: 32 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + temperature: 1.0 + top_p: 1.0 + top_k: -1 + seed: 42 + logprobs: 0 + repeat_times: 8 + use_ray: false + backend: 'nccl' + max_pending_requests: 32 + max_waiting_steps: 4 +synchronizer: + sync_method: 'online' + sync_iteration_interval: 2 +trainer: + trainer_type: 'verl' + algorithm_type: ppo + trainer_config_path: 'examples/grpo_math/train_math.yaml' + sft_warmup_iteration: 0 # Set to integer to enable sft warmup + eval_interval: 10 +monitor: + cache_root_dir: "" + project: grpo_math + name: grpo_math_example diff --git a/examples/grpo_math/train_math.yaml b/examples/grpo_math/train_math.yaml new file mode 100644 index 0000000000..0c457281ee --- /dev/null +++ b/examples/grpo_math/train_math.yaml @@ -0,0 +1,184 @@ +data: + tokenizer: null + train_files: train_example.parquet + val_files: test_example.parquet + prompt_key: prompt + max_prompt_length: 1024 + max_response_length: 2048 + # train_batch_size: 256 + val_batch_size: null + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + shuffle: True + filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You should disable this and set `truncation='left' + truncation: error + image_key: images + +actor_rollout_ref: + hybrid_engine: True + model: + path: /PATH/TO/MODEL/ + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 128 + # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: True # True for GRPO + kl_loss_coef: 0.0001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 5e-7 + lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + # --- below: opmd --- + alg_type: ppo # ppo / opmd / pairwise_opmd + tau: 0.000 # strength of regularization w.r.t. old / ref policy + opmd_baseline: mean # mean / logavgexp, applicable to opmd + use_uid: False # True / False, applicable to pairwise_opmd + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + use_fire_sampling: False # https://arxiv.org/abs/2410.21236 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.4 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + # log_prob_micro_batch_size: 8 # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # could get higher throughput + # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 8 # > 1 for grpo + +critic: + strategy: fsdp + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: /PATH/TO/MODEL/ + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 64 + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + use_remove_padding: False + fsdp_config: + min_num_params: 0 + param_offload: False + fsdp_size: -1 + ulysses_sequence_parallel_size: 1 # sp size + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + +custom_reward_function: + path: null + name: compute_score + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: grpo + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.0001 + +trainer: + balance_batch: True + total_epochs: 20 + project_name: grpo_math + experiment_name: grpo_math_example + logger: [ 'console','wandb' ] + val_generations_to_log_to_wandb: 0 + nnodes: 1 + n_gpus_per_node: 2 + save_freq: 100 + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + resume_from_path: False + test_freq: 5 + critic_warmup: 0 + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + val_before_train: False diff --git a/examples/grpo_sciworld/README.md b/examples/grpo_sciworld/README.md new file mode 100644 index 0000000000..6921411055 --- /dev/null +++ b/examples/grpo_sciworld/README.md @@ -0,0 +1,7 @@ +# Example: GRPO on SciWorld + +This example shows the usage of GRPO on the SciWorld dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_multi_turn.md). + +The config files are located in [`sciworld.yaml`](sciworld.yaml) and [`train_sciworld.yaml`](train_sciworld.yaml). 
diff --git a/scripts/config/sciworld.yaml b/examples/grpo_sciworld/sciworld.yaml similarity index 94% rename from scripts/config/sciworld.yaml rename to examples/grpo_sciworld/sciworld.yaml index 62a9344cbd..d036fc54fc 100644 --- a/scripts/config/sciworld.yaml +++ b/examples/grpo_sciworld/sciworld.yaml @@ -49,7 +49,7 @@ synchronizer: trainer: trainer_type: 'verl' algorithm_type: ppo - trainer_config_path: 'scripts/config/train_sciworld.yaml' + trainer_config_path: 'examples/grpo_sciworld/train_sciworld.yaml' monitor: cache_root_dir: "" project: "sciworld" diff --git a/scripts/config/train_sciworld.yaml b/examples/grpo_sciworld/train_sciworld.yaml similarity index 100% rename from scripts/config/train_sciworld.yaml rename to examples/grpo_sciworld/train_sciworld.yaml diff --git a/examples/grpo_webshop/README.md b/examples/grpo_webshop/README.md new file mode 100644 index 0000000000..54bdfdb5ac --- /dev/null +++ b/examples/grpo_webshop/README.md @@ -0,0 +1,7 @@ +# Example: GRPO on Webshop dataset + +This example shows the usage of GRPO on the Webshop dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_multi_turn.md). + +The config files are located in [`webshop.yaml`](webshop.yaml) and [`train_webshop.yaml`](train_webshop.yaml). diff --git a/scripts/config/train_webshop.yaml b/examples/grpo_webshop/train_webshop.yaml similarity index 100% rename from scripts/config/train_webshop.yaml rename to examples/grpo_webshop/train_webshop.yaml diff --git a/scripts/config/webshop.yaml b/examples/grpo_webshop/webshop.yaml similarity index 94% rename from scripts/config/webshop.yaml rename to examples/grpo_webshop/webshop.yaml index 86e79d8e52..ff5c2c57bf 100644 --- a/scripts/config/webshop.yaml +++ b/examples/grpo_webshop/webshop.yaml @@ -49,7 +49,7 @@ synchronizer: trainer: trainer_type: 'verl' algorithm_type: ppo - trainer_config_path: 'scripts/config/train_webshop.yaml' + trainer_config_path: 'examples/grpo_webshop/train_webshop.yaml' monitor: cache_root_dir: "" project: "WEBSHOP" diff --git a/examples/opmd_gsm8k/README.md b/examples/opmd_gsm8k/README.md new file mode 100644 index 0000000000..4fcd75c2de --- /dev/null +++ b/examples/opmd_gsm8k/README.md @@ -0,0 +1,7 @@ +# Example: OPMD on GSM8K dataset + +This example shows the usage of OPMD on the GSM8K dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md). + +The config files are located in [`gsm8k_opmd.yaml`](gsm8k_opmd.yaml) and [`train_gsm8k_opmd.yaml`](train_gsm8k_opmd.yaml). 
diff --git a/scripts/config/gsm8k_opmd.yaml b/examples/opmd_gsm8k/opmd_gsm8k.yaml similarity index 95% rename from scripts/config/gsm8k_opmd.yaml rename to examples/opmd_gsm8k/opmd_gsm8k.yaml index 6dd05a8c2c..5a458b24e8 100644 --- a/scripts/config/gsm8k_opmd.yaml +++ b/examples/opmd_gsm8k/opmd_gsm8k.yaml @@ -46,7 +46,7 @@ synchronizer: trainer: trainer_type: 'verl' algorithm_type: opmd - trainer_config_path: 'scripts/config/train_gsm8k_opmd.yaml' + trainer_config_path: 'examples/opmd_gsm8k/train_opmd_gsm8k.yaml' sft_warmup_iteration: 0 monitor: cache_root_dir: "" diff --git a/scripts/config/train_gsm8k_opmd.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml similarity index 100% rename from scripts/config/train_gsm8k_opmd.yaml rename to examples/opmd_gsm8k/train_opmd_gsm8k.yaml diff --git a/examples/ppo_countdown/README.md b/examples/ppo_countdown/README.md new file mode 100644 index 0000000000..fa08b375a7 --- /dev/null +++ b/examples/ppo_countdown/README.md @@ -0,0 +1,7 @@ +# Example: PPO on Countdown dataset + +This example shows the usage of PPO on the Countdown dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md). + +The config files are located in [`countdown.yaml`](countdown.yaml) and [`train_countdown.yaml`](train_countdown.yaml). diff --git a/scripts/config/countdown.yaml b/examples/ppo_countdown/countdown.yaml similarity index 94% rename from scripts/config/countdown.yaml rename to examples/ppo_countdown/countdown.yaml index 1dd64b3a6d..a531eebda0 100644 --- a/scripts/config/countdown.yaml +++ b/examples/ppo_countdown/countdown.yaml @@ -49,7 +49,7 @@ synchronizer: trainer: trainer_type: 'verl' algorithm_type: ppo - trainer_config_path: 'scripts/config/train_countdown.yaml' + trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' sft_warmup_iteration: 0 eval_interval: 1000 monitor: diff --git a/scripts/config/train_countdown.yaml b/examples/ppo_countdown/train_countdown.yaml similarity index 100% rename from scripts/config/train_countdown.yaml rename to examples/ppo_countdown/train_countdown.yaml diff --git a/tests/common/config_test.py b/tests/common/config_test.py index f3febee0d7..3e210df993 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -22,8 +22,7 @@ def test_load_default_config(self): config.synchronizer.sync_iteration_interval, ) - def test_all_examples_are_valid(self): - # ../../scripts/config + def test_all_examples_are_valid(self): # TODO: useless example_dir = os.path.join(os.path.dirname(__file__), "..", "..", "scripts", "config") for filename in ["countdown", "gsm8k"]: if filename.endswith(".yaml"): diff --git a/trinity/common/config.py b/trinity/common/config.py index bcc5ed99dc..d0026ed1e6 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -173,7 +173,7 @@ class ExplorerConfig: class TrainerConfig: trainer_type: str = "verl" trainer_data_type: str = "RFT" - trainer_config_path: str = "scripts/config/train_countdown.yaml" + trainer_config_path: str = "examples/ppo_countdown/train_countdown.yaml" eval_interval: int = 100 enable_preview: bool = True # enable rollout preview in wandb trainer_config: Any = None diff --git a/trinity/data/client.py b/trinity/data/client.py index 0808bb41bd..5c5ca9e530 100644 --- a/trinity/data/client.py +++ b/trinity/data/client.py @@ -34,6 +34,6 @@ def request(url, **kwargs): if __name__ == "__main__": res = request( url=LOCAL_SERVER_URL, - configPath="scripts/config/gsm8k.yaml", + 
configPath="examples/grpo_gsm8k/gsm8k.yaml", ) print(res)