4 changes: 2 additions & 2 deletions benchmark/config/alfworld-template.yaml
@@ -6,10 +6,10 @@ checkpoint_root_dir: placeholder
 algorithm:
   algorithm_type: grpo
   repeat_times: 8
-  loss_agg_mode: "seq-mean-token-sum"
+  loss_agg_mode: "token-mean"
   optimizer:
     lr: 1e-6
-  sample_strategy: warmup
+  sample_strategy: default
   policy_loss_fn: ppo
   advantage_fn: grpo
   kl_penalty_fn: none
2 changes: 1 addition & 1 deletion benchmark/config/gsm8k-template.yaml
@@ -10,7 +10,7 @@ algorithm:
     lr: 1e-5
     lr_warmup_steps_ratio: 0.0
     warmup_style: constant
-  sample_strategy: warmup
+  sample_strategy: default
   policy_loss_fn: ppo
   advantage_fn: grpo
   kl_penalty_fn: none
2 changes: 1 addition & 1 deletion benchmark/reports/alfworld.md
@@ -18,7 +18,7 @@ We evaluate the performance of the following methods in Trinity-RFT framework wi
 Since rLLM does not support ALFWorld environment yet, we implement this task in rLLM for comparison.
 
 In Trinity-RFT and rLLM, we respectively evaluate the performance using GRPO algorithm on this task.
-We fine-tune a `Qwen2.5-3B-Instruct` model, which has been trained on a SFT dataset, on the training tasks with GRPO and other methods. For all methods, we fix key parameters to `batch_size=32`, `repeat_times=8`, `lr=1e-6`, and `kl_coef=0.001`.
+We fine-tune a `Qwen2.5-3B-Instruct` model, which has been trained on a SFT dataset (will be released soon), on the training tasks with GRPO and other methods. For all methods, we fix key parameters to `batch_size=32`, `repeat_times=8`, `lr=1e-6`, and `kl_coef=0.001`.
 
 For better efficiency, we use 64 rollout workers in rLLM and set the `explorer.engine_num` to 4 and `explorer.runner_per_model` to 8 in Trinity-RFT.
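
For reference, the parameters named in the report map onto the Trinity-RFT configuration roughly as sketched below. The `algorithm.*` keys follow the templates in this PR and the `explorer.*` keys follow the report; the batch-size key and the placement of `kl_coef=0.001` are not shown in this diff, so they are omitted rather than guessed.

```yaml
# Rough sketch of the reported GRPO setup in Trinity-RFT config terms
# (only keys visible in this PR or named in the report; not a complete config)
algorithm:
  algorithm_type: grpo
  repeat_times: 8        # 8 rollouts per task, as reported
  optimizer:
    lr: 1e-6             # learning rate fixed across all methods
explorer:
  engine_num: 4          # number of rollout engines
  runner_per_model: 8    # rollout runners per engine
```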
22 changes: 22 additions & 0 deletions examples/grpo_alfworld/alfworld.yaml
@@ -56,3 +56,25 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+# If needed, uncomment the following lines to enable SFT warmup before RFT
+# stages:
+#   - stage_name: sft_warmup
+#     mode: train
+#     algorithm:
+#       algorithm_type: sft
+#       optimizer:
+#         lr: 5e-6
+#         lr_warmup_steps_ratio: 0.0
+#         warmup_style: constant
+#     buffer:
+#       total_epochs: 1
+#       train_batch_size: 32
+#       trainer_input:
+#         experience_buffer:
+#           name: sft_warmup_dataset
+#           storage_type: file
+#           path: ${oc.env:TRINITY_SFT_DATASET_PATH}
+#           format:
+#             prompt_type: messages
+#             messages_key: 'messages'
+#   - stage_name: rft
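
The SFT warmup stage reads its dataset path from the `TRINITY_SFT_DATASET_PATH` environment variable (via the `${oc.env:...}` OmegaConf resolver) and expects each record to expose a chat-style message list under the configured `messages_key`. The record below is a hypothetical illustration of that shape, not a sample from the actual (yet-to-be-released) dataset; the field contents are invented.

```yaml
# Hypothetical single record of the SFT warmup dataset, shown in YAML for readability.
# Only the `messages` key and its role/content layout are implied by
# `prompt_type: messages` / `messages_key: 'messages'`; everything else is illustrative.
messages:
  - role: user
    content: "You are in the middle of a room. Your task is to: put a clean mug on the desk."
  - role: assistant
    content: "I will look for a mug, clean it at the sinkbasin, and then place it on the desk."
```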
1 change: 1 addition & 0 deletions examples/grpo_gsm8k/gsm8k.yaml
@@ -75,5 +75,6 @@ trainer:
 #       trainer_input:
 #         experience_buffer:
 #           name: sft_warmup_dataset
+#           storage_type: file
 #           path: ${oc.env:TRINITY_SFT_DATASET_PATH}
 #   - stage_name: rft  # leave empty to use the original configs for RFT