Skip to content

Commit 5f9d5dd

Browse files
authored
Create run_ppo_test.sh
Signed-off-by: erfgss <97771661+erfgss@users.noreply.github.com>
1 parent f257544 commit 5f9d5dd

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

.buildkite/scripts/run_ppo_test.sh

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
#
# Set up a Verl + vLLM environment, run the GSM8K Qwen2.5-0.5B PPO example,
# merge the resulting FSDP checkpoint into HF format, and evaluate it on
# GSM8K with lm-eval.
#
# Expects: uv, git, python3, huggingface-cli, and lm_eval on PATH; GPUs
# available per n_gpus_per_node below.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"

# --- Configuration (constants) ---
readonly VERL_REPO="https://github.com/volcengine/verl.git"
readonly VERL_BRANCH="main"
readonly VERL_DIR="${REPO_ROOT}/verl"
readonly TARGET_DIR="${VERL_DIR}/examples/data_preprocess"
readonly MODEL_ID="Qwen/Qwen2.5-0.5B-Instruct"
readonly MODEL_DIR="${VERL_DIR}/models/Qwen2.5-0.5B-Instruct"
readonly train_epochs=2
readonly data_dir="${VERL_DIR}/gsm8k"
readonly n_gpus_per_node=8
readonly nnodes=1

# Log the effective configuration so CI output is self-describing.
echo "VERL_REPO=${VERL_REPO}"
echo "VERL_BRANCH=${VERL_BRANCH}"
echo "VERL_DIR=${VERL_DIR}"
echo "TARGET_DIR=${TARGET_DIR}"
echo "MODEL_ID=${MODEL_ID}"
echo "MODEL_DIR=${MODEL_DIR}"
echo "train_epochs=${train_epochs}"
echo "data_dir=${data_dir}"
echo "n_gpus_per_node=${n_gpus_per_node}"
echo "nnodes=${nnodes}"

echo "===== Setting up Verl environment ====="

# Clone Verl only if a previous run hasn't already done so.
if [ -d "${VERL_DIR}" ]; then
  echo "Verl exists, skip clone"
else
  git clone --branch "${VERL_BRANCH}" --single-branch "${VERL_REPO}" "${VERL_DIR}"
fi

echo "Entering ${VERL_DIR} ..."
cd "${VERL_DIR}"
# Install verl itself without deps first, then pull in the vllm extra.
uv pip install --no-deps -e .
uv pip install -e .[vllm]

echo "Entering ${TARGET_DIR} ..."
cd "${TARGET_DIR}"
echo "Running gsm8k.py "
# Preprocess GSM8K into parquet files under ${data_dir}.
python3 gsm8k.py --local_save_dir "${data_dir}"

echo "===== gsm8k.py preprocessing completed! ====="

echo "===== Downloading model: ${MODEL_ID} ====="
echo "Target directory: ${MODEL_DIR}"
huggingface-cli download "${MODEL_ID}" --resume-download --local-dir "${MODEL_DIR}"
echo "===== Downloading model: ${MODEL_ID} completed! ====="

echo "===== Starting PPO Training ====="
# NOTE(review): val_files points at the *train* split; the upstream verl
# GSM8K example uses test.parquet for validation — confirm this is an
# intentional smoke-test shortcut before relying on val metrics.
python3 -m verl.trainer.main_ppo \
    data.train_files="${data_dir}/train.parquet" \
    data.val_files="${data_dir}/train.parquet" \
    data.train_batch_size=256 \
    data.max_prompt_length=512 \
    data.max_response_length=512 \
    actor_rollout_ref.model.path="${MODEL_DIR}" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    critic.optim.lr=1e-5 \
    critic.model.path="${MODEL_DIR}" \
    critic.ppo_micro_batch_size_per_gpu=4 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.logger=tensorboard \
    trainer.val_before_train=False \
    trainer.n_gpus_per_node="${n_gpus_per_node}" \
    trainer.nnodes="${nnodes}" \
    trainer.save_freq=10 \
    trainer.test_freq=10 \
    trainer.total_epochs="${train_epochs}"
echo "===== End PPO Training ====="

echo "===== Model Restoration ====="
# steps_per_epoch = 7473 samples (GSM8K: ~7473 samples) / 256 global batch size ≈ 29
step=$((29 * train_epochs))
# Checkpoints are written relative to the training cwd (${TARGET_DIR}).
merge_local_dir="${TARGET_DIR}/checkpoints/verl_examples/gsm8k/global_step_${step}/actor"
merge_target_dir="${TARGET_DIR}/checkpoints/verl_examples/gsm8k/global_step_${step}/actor_hf"

# Merge the sharded FSDP actor checkpoint into a HuggingFace-format model.
python3 "${VERL_DIR}/scripts/legacy_model_merger.py" merge \
    --backend fsdp \
    --local_dir "${merge_local_dir}" \
    --target_dir "${merge_target_dir}"

# Evaluate the merged model on GSM8K with lm-eval on 4 GPUs.
CUDA_VISIBLE_DEVICES=0,1,2,3 lm_eval --model hf \
    --model_args pretrained="${merge_target_dir}",trust_remote_code=True \
    --tasks gsm8k \
    --batch_size auto \
    --apply_chat_template True \
    --output_path results_ppo.json

echo "=====Test completed! ====="

0 commit comments

Comments
 (0)