Skip to content

Commit 1b6dd53

Browse files
committed
[feat] add ppo_train scripts for webshop and alfworld
1 parent 477132f commit 1b6dd53

File tree

2 files changed

+142
-0
lines changed

2 files changed

+142
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
set -x
2+
ENGINE=${1:-vllm}
3+
export VLLM_ATTENTION_BACKEND=XFORMERS
4+
export WANDB_API_KEY=
5+
export WANDB_BASE_URL=https://api.bandw.top
6+
7+
8+
visible_devices="1,2,3,4"
9+
export CUDA_VISIBLE_DEVICES="$visible_devices"
10+
11+
12+
train_data_size=128
13+
val_data_size=128
14+
15+
python3 -m verl.trainer.main_ppo \
16+
algorithm.adv_estimator=gae \
17+
data.train_files=/data1/user/muxin/verl-agent/text/train.parquet \ # TODO: change to the correct path
18+
data.val_files=/data1/user/muxin/verl-agent/text/test.parquet \ # TODO: change to the correct path
19+
data.train_batch_size=$train_data_size \
20+
data.val_batch_size=$val_data_size \
21+
data.max_prompt_length=2048 \
22+
data.max_response_length=512 \
23+
data.filter_overlong_prompts=True \
24+
data.truncation='error' \
25+
data.return_raw_chat=True \
26+
actor_rollout_ref.model.path=/home/user/models/Qwen/Qwen3-4B \
27+
actor_rollout_ref.actor.optim.lr=1e-6 \
28+
actor_rollout_ref.model.use_remove_padding=True \
29+
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
30+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
31+
actor_rollout_ref.actor.use_kl_loss=True \
32+
actor_rollout_ref.actor.kl_loss_coef=0.01 \
33+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
34+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
35+
actor_rollout_ref.actor.fsdp_config.param_offload=False \
36+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
37+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
38+
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
39+
actor_rollout_ref.rollout.name=$ENGINE \
40+
actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \
41+
actor_rollout_ref.rollout.enable_chunked_prefill=False \
42+
actor_rollout_ref.rollout.enforce_eager=False \
43+
actor_rollout_ref.rollout.free_cache_engine=False \
44+
actor_rollout_ref.rollout.val_kwargs.temperature=0.4 \
45+
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
46+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
47+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
48+
actor_rollout_ref.actor.use_invalid_action_penalty=True \
49+
actor_rollout_ref.actor.invalid_action_penalty_coef=0.1 \
50+
critic.optim.lr=1e-5 \
51+
critic.model.use_remove_padding=True \
52+
critic.model.path=/home/user/models/Qwen/Qwen3-4B \
53+
critic.model.enable_gradient_checkpointing=True \
54+
critic.ppo_micro_batch_size_per_gpu=4 \
55+
critic.ppo_mini_batch_size=128 \
56+
critic.model.fsdp_config.param_offload=False \
57+
critic.model.fsdp_config.optimizer_offload=False \
58+
algorithm.use_kl_in_reward=False \
59+
env.env_name=alfworld/AlfredTWEnv \
60+
env.seed=0 \
61+
env.max_steps=50 \
62+
trainer.critic_warmup=0 \
63+
trainer.logger=['console','wandb'] \
64+
trainer.project_name='openmanus-rl_ppo_alfworld' \
65+
trainer.experiment_name='ppo_qwen3_4b' \
66+
trainer.n_gpus_per_node=4 \
67+
trainer.nnodes=1 \
68+
trainer.save_freq=-1 \
69+
trainer.test_freq=5 \
70+
trainer.total_epochs=150 \
71+
trainer.val_before_train=True $@

scripts/ppo_train/train_webshop.sh

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
set -x
2+
ENGINE=${1:-vllm}
3+
export VLLM_ATTENTION_BACKEND=XFORMERS
4+
export WANDB_API_KEY=
5+
export WANDB_BASE_URL=https://api.bandw.top
6+
7+
8+
visible_devices="1,2,3,4"
9+
export CUDA_VISIBLE_DEVICES="$visible_devices"
10+
11+
12+
train_data_size=128
13+
val_data_size=128
14+
15+
python3 -m verl.trainer.main_ppo \
16+
algorithm.adv_estimator=gae \
17+
data.train_files= \ # TODO: change to the correct path
18+
data.val_files= \ # TODO: change to the correct path
19+
data.train_batch_size=$train_data_size \
20+
data.val_batch_size=$val_data_size \
21+
data.max_prompt_length=4096 \
22+
data.max_response_length=512 \
23+
data.filter_overlong_prompts=True \
24+
data.truncation='error' \
25+
data.return_raw_chat=True \
26+
actor_rollout_ref.model.path=Qwen/Qwen2.5-1.5B-Instruct \
27+
actor_rollout_ref.actor.optim.lr=1e-6 \
28+
actor_rollout_ref.model.use_remove_padding=True \
29+
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
30+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
31+
actor_rollout_ref.actor.use_kl_loss=True \
32+
actor_rollout_ref.actor.kl_loss_coef=0.01 \
33+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
34+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
35+
actor_rollout_ref.actor.fsdp_config.param_offload=False \
36+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
37+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
38+
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
39+
actor_rollout_ref.rollout.name=$ENGINE \
40+
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
41+
actor_rollout_ref.rollout.enable_chunked_prefill=False \
42+
actor_rollout_ref.rollout.enforce_eager=False \
43+
actor_rollout_ref.rollout.free_cache_engine=False \
44+
actor_rollout_ref.rollout.val_kwargs.temperature=0.4 \
45+
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
46+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
47+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
48+
actor_rollout_ref.actor.use_invalid_action_penalty=True \
49+
actor_rollout_ref.actor.invalid_action_penalty_coef=0.1 \
50+
critic.optim.lr=1e-5 \
51+
critic.model.use_remove_padding=True \
52+
critic.model.path=Qwen/Qwen2.5-1.5B-Instruct \
53+
critic.model.enable_gradient_checkpointing=True \
54+
critic.ppo_micro_batch_size_per_gpu=4 \
55+
critic.model.fsdp_config.param_offload=False \
56+
critic.model.fsdp_config.optimizer_offload=False \
57+
algorithm.use_kl_in_reward=False \
58+
env.env_name=Webshop \
59+
env.seed=0 \
60+
env.max_steps=15 \
61+
trainer.critic_warmup=0 \
62+
trainer.logger=['console','wandb'] \
63+
trainer.project_name='openmanus-rl_ppo_webshop' \
64+
trainer.experiment_name='ppo_qwen2.5_1.5b' \
65+
trainer.n_gpus_per_node=2 \
66+
trainer.nnodes=1 \
67+
trainer.save_freq=-1 \
68+
trainer.test_freq=5 \
69+
trainer.total_epochs=150 \
70+
trainer.val_before_train=True $@
71+

0 commit comments

Comments
 (0)