4 changes: 2 additions & 2 deletions benchmark/config/alfworld-template.yaml
@@ -6,10 +6,10 @@ checkpoint_root_dir: placeholder
 algorithm:
   algorithm_type: grpo
   repeat_times: 8
-  loss_agg_mode: "seq-mean-token-sum"
+  loss_agg_mode: "token-mean"
   optimizer:
     lr: 1e-6
-  sample_strategy: warmup
+  sample_strategy: default
   policy_loss_fn: ppo
   advantage_fn: grpo
   kl_penalty_fn: none
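For context on the `loss_agg_mode` values above: with per-token loss $\ell_{i,t}$ for token $t$ of the $i$-th response, response length $T_i$, and $N$ responses per batch, the two modes are commonly defined (in verl-style trainers; stated here as background rather than taken from this diff) as

$$
\mathcal{L}_{\text{token-mean}} \;=\; \frac{\sum_{i=1}^{N}\sum_{t=1}^{T_i} \ell_{i,t}}{\sum_{i=1}^{N} T_i},
\qquad
\mathcal{L}_{\text{seq-mean-token-sum}} \;=\; \frac{1}{N}\sum_{i=1}^{N}\sum_{t=1}^{T_i} \ell_{i,t}.
$$

Under `token-mean` every token is weighted equally across the batch, while `seq-mean-token-sum` gives each sequence equal weight after summing its token losses.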
2 changes: 1 addition & 1 deletion benchmark/config/gsm8k-template.yaml
@@ -10,7 +10,7 @@ algorithm:
     lr: 1e-5
     lr_warmup_steps_ratio: 0.0
     warmup_style: constant
-  sample_strategy: warmup
+  sample_strategy: default
   policy_loss_fn: ppo
   advantage_fn: grpo
   kl_penalty_fn: none
2 changes: 1 addition & 1 deletion benchmark/reports/alfworld.md
@@ -18,7 +18,7 @@ We evaluate the performance of the following methods in Trinity-RFT framework wi
 Since rLLM does not support the ALFWorld environment yet, we implement this task in rLLM for comparison.

 In both Trinity-RFT and rLLM, we evaluate performance on this task using the GRPO algorithm.
-We fine-tune a `Qwen2.5-3B-Instruct` model, which has been trained on an SFT dataset, on the training tasks with GRPO and other methods. For all methods, we fix key parameters to `batch_size=32`, `repeat_times=8`, `lr=1e-6`, and `kl_coef=0.001`.
+We fine-tune a `Qwen2.5-3B-Instruct` model, which has been trained on an SFT dataset (to be released soon), on the training tasks with GRPO and other methods. For all methods, we fix key parameters to `batch_size=32`, `repeat_times=8`, `lr=1e-6`, and `kl_coef=0.001`.

 For better efficiency, we use 64 rollout workers in rLLM, and set `explorer.engine_num` to 4 and `explorer.runner_per_model` to 8 in Trinity-RFT.

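For orientation, the key parameters fixed in the report above map onto Trinity-RFT config fields roughly as follows. This is a minimal sketch assembled from the keys visible in this PR; the field holding `kl_coef` is not part of the diff and is therefore only noted in a comment.

```yaml
# Sketch only: shared hyperparameters from the report, expressed with config
# keys that appear in this PR. Not a complete or runnable configuration.
algorithm:
  algorithm_type: grpo
  repeat_times: 8        # 8 rollouts per task (the GRPO group size)
  optimizer:
    lr: 1e-6
  # kl_coef=0.001 is set on the KL term; its key is not shown in this diff
buffer:
  train_batch_size: 32   # batch_size=32
explorer:
  engine_num: 4          # number of rollout engines
  runner_per_model: 8    # concurrent runners per rollout model
```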
22 changes: 22 additions & 0 deletions examples/grpo_alfworld/alfworld.yaml
@@ -56,3 +56,25 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+# If needed, uncomment the following lines to enable SFT warmup before RFT
+# stages:
+#   - stage_name: sft_warmup
+#     mode: train
+#     algorithm:
+#       algorithm_type: sft
+#       optimizer:
+#         lr: 5e-6
+#         lr_warmup_steps_ratio: 0.0
+#         warmup_style: constant
+#     buffer:
+#       total_epochs: 1
+#       train_batch_size: 32
+#       trainer_input:
+#         experience_buffer:
+#           name: sft_warmup_dataset
+#           storage_type: file
+#           path: ${oc.env:TRINITY_SFT_DATASET_PATH}
+#           format:
+#             prompt_type: messages
+#             messages_key: 'messages'
+#   - stage_name: rft
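To illustrate the `prompt_type: messages` / `messages_key: 'messages'` format referenced above, a single record of the SFT warmup dataset could look like the sketch below. The field contents are hypothetical; only the `messages` list structure is implied by the config keys.

```yaml
# Hypothetical example of one SFT warmup record, shown in YAML for readability;
# the dataset file itself may be stored in any format that storage_type: file
# accepts. All strings below are invented for illustration.
messages:
  - role: system
    content: "You are an agent in a household environment. Complete the task step by step."
  - role: user
    content: "Your task is to: put a clean mug on the desk."
  - role: assistant
    content: "THOUGHT: I should find a mug first. ACTION: go to countertop 1"
```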
4 changes: 4 additions & 0 deletions examples/grpo_gsm8k/gsm8k.yaml
@@ -75,5 +75,9 @@ trainer:
 #       trainer_input:
 #         experience_buffer:
 #           name: sft_warmup_dataset
+#           storage_type: file
+#           path: ${oc.env:TRINITY_SFT_DATASET_PATH}
+#           format:
+#             prompt_type: messages
 #             messages_key: 'messages'
 #   - stage_name: rft # leave empty to use the original configs for RFT