Skip to content

Commit 9b50fb7

Browse files
authored
[model] feat: add qwen3-4b grpo script on ASCEND NPU A3 (verl-project#4432)
### What does this PR do?

Adds `examples/grpo_trainer/run_qwen3-4b_npu.sh`, a GRPO training script for Qwen3-4B on Ascend NPU (A3).

### Test

The figure below shows the comparison curve of the critic_reward_mean metric.

<img width="1790" height="948" alt="image" src="https://github.com/user-attachments/assets/01df9bed-f888-470d-936c-eb335acd57e9" />

### API and Usage Example

```sh
# install jemalloc
sudo apt update
sudo apt install libjemalloc2
# run the script
bash examples/grpo_trainer/run_qwen3-4b_npu.sh
```
1 parent 615aa67 commit 9b50fb7

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed

docs/ascend_tutorial/ascend_quick_start.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ verl 中昇腾暂不支持生态库如下:
235235
+-----------------------+-------------------------+-------------------+-------------------+--------------------------+
236236
| GRPO | Qwen2.5-VL-32B-instruct | FSDP | vllm-ascend | Atlas 200T A2 Box16 |
237237
+-----------------------+-------------------------+-------------------+-------------------+--------------------------+
238+
| GRPO | Qwen3-4B | FSDP | vllm-ascend | Atlas 800T A3 |
239+
+-----------------------+-------------------------+-------------------+-------------------+--------------------------+
238240
| GRPO | Qwen3-8B | FSDP | vllm-ascend | Atlas 200T A2 Box16 |
239241
+-----------------------+-------------------------+-------------------+-------------------+--------------------------+
240242
| GRPO | Qwen3-32B | FSDP | vllm-ascend | Atlas 200T A2 Box16 |
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
set -xeuo pipefail
2+
source /usr/local/Ascend/ascend-toolkit/set_env.sh
3+
source /usr/local/Ascend/nnal/atb/set_env.sh
4+
5+
# 使用v1引擎
6+
export VLLM_USE_V1=1
7+
# 指定vllm 版本
8+
export VLLM_VERSION=0.9.1
9+
10+
# 开启二级流水
11+
export TASK_QUEUE_ENABLE=2
12+
# 开启细绑核
13+
export CPU_AFFINITY_CONF=1
14+
# 使用jemalloc优化内存访问(依赖安装jemalloc)
15+
export LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libjemalloc.so.2${LD_PRELOAD:+:$LD_PRELOAD}"
16+
17+
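# A3 machine, single node with 8 cards; 16 NPU devices are configured below (presumably 2 dies per card — TODO confirm)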
18+
trainer_n_gpus_per_node=16
19+
trainer_nnodes=1
20+
trainer_project_name='verl_grpo_example_gsm8k'
21+
# Experiment name; it is also embedded in the checkpoint directory, the
# TensorBoard directory and the log file name derived below, so it must not
# contain stray shell punctuation (the original value ended in a literal '}').
trainer_experiment_name="qwen3_4b_grpo_8npu"
22+
23+
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
24+
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-4B"}
25+
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${trainer_project_name}/${trainer_experiment_name}"}
26+
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
27+
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}
28+
29+
export TENSORBOARD_DIR="${RAY_DATA_HOME}/tensorboard_dir/${trainer_project_name}/${trainer_experiment_name}"
30+
mkdir -p "${RAY_DATA_HOME}/logs/${trainer_project_name}"
31+
LOG_PATH="${RAY_DATA_HOME}/logs/${trainer_project_name}/${trainer_experiment_name}.log"
32+
33+
use_dynamic_bsz=True
34+
35+
python3 -m verl.trainer.main_ppo \
36+
algorithm.adv_estimator=grpo \
37+
data.train_files=${TRAIN_FILE} \
38+
data.val_files=${TEST_FILE} \
39+
data.train_batch_size=512 \
40+
data.max_prompt_length=1024 \
41+
data.max_response_length=1024 \
42+
data.filter_overlong_prompts=True \
43+
data.truncation='error' \
44+
actor_rollout_ref.model.path=${MODEL_PATH} \
45+
actor_rollout_ref.actor.optim.lr=5e-7 \
46+
actor_rollout_ref.model.use_remove_padding=True \
47+
actor_rollout_ref.actor.entropy_coeff=0.001 \
48+
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
49+
actor_rollout_ref.actor.use_kl_loss=True \
50+
actor_rollout_ref.actor.kl_loss_coef=0.001 \
51+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
52+
actor_rollout_ref.actor.use_torch_compile=False \
53+
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
54+
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=3000 \
55+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
56+
actor_rollout_ref.actor.fsdp_config.param_offload=True \
57+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
58+
actor_rollout_ref.rollout.enforce_eager=True \
59+
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
60+
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096 \
61+
actor_rollout_ref.rollout.enable_chunked_prefill=False \
62+
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
63+
actor_rollout_ref.rollout.name=vllm \
64+
actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
65+
actor_rollout_ref.rollout.n=5 \
66+
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
67+
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \
68+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
69+
actor_rollout_ref.ref.use_torch_compile=True \
70+
algorithm.kl_ctrl.kl_coef=0.001 \
71+
trainer.critic_warmup=0 \
72+
trainer.project_name=${trainer_project_name} \
73+
trainer.experiment_name=${trainer_experiment_name} \
74+
trainer.logger=['console','tensorboard'] \
75+
trainer.default_local_dir=${CKPTS_DIR} \
76+
trainer.n_gpus_per_node=$trainer_n_gpus_per_node \
77+
trainer.nnodes=$trainer_nnodes \
78+
trainer.save_freq=-1 \
79+
trainer.test_freq=5 \
80+
trainer.total_epochs=15 \
81+
trainer.val_before_train=False \
82+
trainer.device=npu 2>&1 | tee ${LOG_PATH}

verl/models/transformers/npu_patch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
239239
# Patches for Qwen3 Model
240240
modeling_qwen3.Qwen3RMSNorm.forward = rms_norm_forward_npu
241241
modeling_qwen3.Qwen3MLP.forward = silu_forward_npu
242+
modeling_qwen3.apply_rotary_pos_emb = apply_rotary_pos_emb_npu
242243

243244
# Patches for Qwen3 MoE Model
244245
modeling_qwen3_moe.Qwen3MoeRMSNorm.forward = rms_norm_forward_npu

0 commit comments

Comments
 (0)