diff --git a/benchmark/config/alfworld-template.yaml b/benchmark/config/alfworld-template.yaml
index edd0626624..be62a1bc0b 100644
--- a/benchmark/config/alfworld-template.yaml
+++ b/benchmark/config/alfworld-template.yaml
@@ -6,10 +6,10 @@ checkpoint_root_dir: placeholder
 algorithm:
   algorithm_type: grpo
   repeat_times: 8
-  loss_agg_mode: "seq-mean-token-sum"
+  loss_agg_mode: "token-mean"
   optimizer:
     lr: 1e-6
-  sample_strategy: warmup
+  sample_strategy: default
   policy_loss_fn: ppo
   advantage_fn: grpo
   kl_penalty_fn: none
diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml
index a967589fe9..35db1831e4 100644
--- a/benchmark/config/gsm8k-template.yaml
+++ b/benchmark/config/gsm8k-template.yaml
@@ -10,7 +10,7 @@ algorithm:
     lr: 1e-5
     lr_warmup_steps_ratio: 0.0
     warmup_style: constant
-  sample_strategy: warmup
+  sample_strategy: default
   policy_loss_fn: ppo
   advantage_fn: grpo
   kl_penalty_fn: none
diff --git a/benchmark/reports/alfworld.md b/benchmark/reports/alfworld.md
index e6663478f7..e9328662a6 100644
--- a/benchmark/reports/alfworld.md
+++ b/benchmark/reports/alfworld.md
@@ -18,7 +18,7 @@ We evaluate the performance of the following methods in Trinity-RFT framework wi
 Since rLLM does not support ALFWorld environment yet, we implement this task in rLLM for comparison.
 In Trinity-RFT and rLLM, we respectively evaluate the performance using GRPO algorithm on this task.
 
-We fine-tune a `Qwen2.5-3B-Instruct` model, which has been trained on a SFT dataset, on the training tasks with GRPO and other methods. For all methods, we fix key parameters to `batch_size=32`, `repeat_times=8`, `lr=1e-6`, and `kl_coef=0.001`.
+We fine-tune a `Qwen2.5-3B-Instruct` model, which has been trained on a SFT dataset (will be released soon), on the training tasks with GRPO and other methods. For all methods, we fix key parameters to `batch_size=32`, `repeat_times=8`, `lr=1e-6`, and `kl_coef=0.001`.
 
 For better efficiency, we use 64 rollout workers in rLLM and set the `explorer.engine_num` to 4 and `explorer.runner_per_model` to 8 in Trinity-RFT.
 
diff --git a/examples/grpo_alfworld/alfworld.yaml b/examples/grpo_alfworld/alfworld.yaml
index 77ba65d555..32a491e1d7 100644
--- a/examples/grpo_alfworld/alfworld.yaml
+++ b/examples/grpo_alfworld/alfworld.yaml
@@ -56,3 +56,25 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+# If needed, uncomment the following lines to enable SFT warmup before RFT
+# stages:
+#   - stage_name: sft_warmup
+#     mode: train
+#     algorithm:
+#       algorithm_type: sft
+#       optimizer:
+#         lr: 5e-6
+#         lr_warmup_steps_ratio: 0.0
+#         warmup_style: constant
+#     buffer:
+#       total_epochs: 1
+#       train_batch_size: 32
+#       trainer_input:
+#         experience_buffer:
+#           name: sft_warmup_dataset
+#           storage_type: file
+#           path: ${oc.env:TRINITY_SFT_DATASET_PATH}
+#           format:
+#             prompt_type: messages
+#             messages_key: 'messages'
+#   - stage_name: rft
diff --git a/examples/grpo_gsm8k/gsm8k.yaml b/examples/grpo_gsm8k/gsm8k.yaml
index b0640f089c..150cc68497 100644
--- a/examples/grpo_gsm8k/gsm8k.yaml
+++ b/examples/grpo_gsm8k/gsm8k.yaml
@@ -75,5 +75,9 @@ trainer:
 #       trainer_input:
 #         experience_buffer:
 #           name: sft_warmup_dataset
+#           storage_type: file
 #           path: ${oc.env:TRINITY_SFT_DATASET_PATH}
+#           format:
+#             prompt_type: messages
+#             messages_key: 'messages'
 #   - stage_name: rft  # leave empty to use the original configs for RFT