
Commit 0a72680

Fix sft warmup yaml (#435)

1 parent f3f0846

5 files changed: +30 -4 lines changed

benchmark/config/alfworld-template.yaml

Lines changed: 2 additions & 2 deletions
```diff
@@ -6,10 +6,10 @@ checkpoint_root_dir: placeholder
 algorithm:
   algorithm_type: grpo
   repeat_times: 8
-  loss_agg_mode: "seq-mean-token-sum"
+  loss_agg_mode: "token-mean"
   optimizer:
     lr: 1e-6
-  sample_strategy: warmup
+  sample_strategy: default
   policy_loss_fn: ppo
   advantage_fn: grpo
   kl_penalty_fn: none
```

benchmark/config/gsm8k-template.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@ algorithm:
     lr: 1e-5
     lr_warmup_steps_ratio: 0.0
     warmup_style: constant
-  sample_strategy: warmup
+  sample_strategy: default
   policy_loss_fn: ppo
   advantage_fn: grpo
   kl_penalty_fn: none
```

benchmark/reports/alfworld.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -18,7 +18,7 @@ We evaluate the performance of the following methods in Trinity-RFT framework wi
 Since rLLM does not support ALFWorld environment yet, we implement this task in rLLM for comparison.
 
 In Trinity-RFT and rLLM, we respectively evaluate the performance using GRPO algorithm on this task.
-We fine-tune a `Qwen2.5-3B-Instruct` model, which has been trained on a SFT dataset, on the training tasks with GRPO and other methods. For all methods, we fix key parameters to `batch_size=32`, `repeat_times=8`, `lr=1e-6`, and `kl_coef=0.001`.
+We fine-tune a `Qwen2.5-3B-Instruct` model, which has been trained on a SFT dataset (will be released soon), on the training tasks with GRPO and other methods. For all methods, we fix key parameters to `batch_size=32`, `repeat_times=8`, `lr=1e-6`, and `kl_coef=0.001`.
 
 For better efficiency, we use 64 rollout workers in rLLM and set the `explorer.engine_num` to 4 and `explorer.runner_per_model` to 8 in Trinity-RFT.
```
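For context, the fixed parameters quoted in the report line up with the `algorithm` section of `benchmark/config/alfworld-template.yaml` changed above. A minimal sketch of that correspondence (the placement of `batch_size` and `kl_coef` is an assumption, since this hunk does not show those keys):

```yaml
algorithm:
  algorithm_type: grpo
  repeat_times: 8   # fixed across methods, per the report
  optimizer:
    lr: 1e-6        # fixed across methods, per the report
  # batch_size: 32 and kl_coef: 0.001 are also fixed per the report,
  # but their exact keys and placement are not visible in this hunk (assumption).
```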

examples/grpo_alfworld/alfworld.yaml

Lines changed: 22 additions & 0 deletions
```diff
@@ -56,3 +56,25 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+# If needed, uncomment the following lines to enable SFT warmup before RFT
+# stages:
+#   - stage_name: sft_warmup
+#     mode: train
+#     algorithm:
+#       algorithm_type: sft
+#       optimizer:
+#         lr: 5e-6
+#         lr_warmup_steps_ratio: 0.0
+#         warmup_style: constant
+#     buffer:
+#       total_epochs: 1
+#       train_batch_size: 32
+#       trainer_input:
+#         experience_buffer:
+#           name: sft_warmup_dataset
+#           storage_type: file
+#           path: ${oc.env:TRINITY_SFT_DATASET_PATH}
+#           format:
+#             prompt_type: messages
+#             messages_key: 'messages'
+#   - stage_name: rft
```
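With the comments stripped, the new block defines a two-stage run: an SFT warmup stage that trains on a messages-format dataset read from the `TRINITY_SFT_DATASET_PATH` environment variable (resolved via OmegaConf's `oc.env` resolver), followed by the RFT stage. A sketch of the enabled form, mirroring the commented lines above:

```yaml
stages:
  - stage_name: sft_warmup
    mode: train
    algorithm:
      algorithm_type: sft
      optimizer:
        lr: 5e-6
        lr_warmup_steps_ratio: 0.0
        warmup_style: constant
    buffer:
      total_epochs: 1
      train_batch_size: 32
      trainer_input:
        experience_buffer:
          name: sft_warmup_dataset
          storage_type: file
          path: ${oc.env:TRINITY_SFT_DATASET_PATH}  # export this env var before launching
          format:
            prompt_type: messages
            messages_key: 'messages'
  - stage_name: rft  # leave empty to use the original configs for RFT
```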

examples/grpo_gsm8k/gsm8k.yaml

Lines changed: 4 additions & 0 deletions
```diff
@@ -75,5 +75,9 @@ trainer:
 #       trainer_input:
 #         experience_buffer:
 #           name: sft_warmup_dataset
+#           storage_type: file
 #           path: ${oc.env:TRINITY_SFT_DATASET_PATH}
+#           format:
+#             prompt_type: messages
+#             messages_key: 'messages'
 #   - stage_name: rft  # leave empty to use the original configs for RFT
```
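The added `format` block tells the experience buffer to read chat-style records from the `messages` field of the file at `TRINITY_SFT_DATASET_PATH`. Assuming the usual role/content chat schema (the exact record layout is not shown in this commit), one record of the SFT warmup dataset would look roughly like:

```yaml
# One hypothetical record, shown as YAML for readability;
# the on-disk dataset is typically a JSON/JSONL file.
messages:
  - role: user
    content: "Natalia sold clips to 48 of her friends in April..."  # illustrative GSM8K-style prompt
  - role: assistant
    content: "She sold 48 clips in April and 24 in May... #### 72"  # illustrative target completion
```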
