
Commit b23fcd1

[WIP][FSDP] Support FSDP for Qwen3Next (#1116)
1 parent bc0e70a commit b23fcd1

3 files changed: 481 additions, 0 deletions

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# Training Qwen3-Next-80B-A3B on 8xH100

## Environment Setup

Environment setup, model download, data preparation, and checkpoint conversion are the same as for the Qwen3-4B model. See [Example: Qwen3-4B](./qwen3-4B.md) and replace the Qwen3-4B parts with Qwen3-Next-80B-A3B-Instruct.

The complete procedure for converting the Hugging Face checkpoint to torch_dist format starts by downloading the weights (the conversion command itself appears in the Megatron section below):

```bash
export BASE_FOLDER=./models/
# Download the model weights (Qwen3-Next-80B-A3B-Thinking)
hf download Qwen/Qwen3-Next-80B-A3B-Thinking --local-dir ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking
```

```shell
cd slime/
pip install -e .

# (for acceleration) build flash-linear-attention from source
cd ..   # and find a proper folder
git clone https://github.com/fla-org/flash-linear-attention
cd flash-linear-attention
git checkout 9714c595
pip install -e .

# prebuilt causal-conv1d wheel (CUDA 12 / torch 2.8 / cp312)
wget https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.4/causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
pip install ./causal_conv1d-1.5.4+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
```
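
Optionally, verify that both acceleration packages import correctly. This is a minimal sanity check; the import names `fla` and `causal_conv1d` are inferred from the repositories above:

```bash
python -c "import fla, causal_conv1d; print('flash-linear-attention and causal-conv1d OK')"
```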

## [Optional] Fix a bug in triton compilation on Blackwell (sm100)

See the discussion in https://github.com/triton-lang/triton/issues/8695
and https://github.com/fla-org/flash-linear-attention/issues/638.

We need to apply a patch to work around the bug.
Go to the flash-linear-attention folder you just installed and apply the following patch:

```diff
diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index c5119dcf..838f5e4e 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -198,7 +198,14 @@ def prepare_wy_repr_bwd_kernel(
         b_A += tl.dot(b_kb, tl.trans(b_k))
         b_dkb = tl.dot(b_dA, b_k)
         b_db += tl.sum(b_dkb * b_k, 1)
-        b_dk += tl.dot(tl.trans(b_dA), b_kb)
+        b_dk += tl.inline_asm_elementwise(
+            asm="mov.f32 $0, $1;",
+            constraints="=r,r",
+            args=[tl.dot(tl.trans(b_dA), b_kb)],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
         b_dk += b_dkb * b_b[:, None]
         tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
         tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))

```

Save it as `patch.diff` (remember to keep the trailing empty line in the file!) and run `git apply patch.diff`.
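
For example, assuming `patch.diff` was saved inside the flash-linear-attention checkout, `git apply --check` does a dry run that fails without modifying any files if the patch does not apply cleanly (e.g. if whitespace was lost while copying):

```bash
cd flash-linear-attention            # the checkout from the installation step above
git apply --check patch.diff && git apply patch.diff
```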

## Run Training (Megatron)

**Blackwell is not supported yet.**

Convert the model weights:

```bash
source scripts/models/qwen3-next-80B-A3B.sh
PYTHONPATH=/root/Megatron-LM/ torchrun --nproc-per-node 8 \
    tools/convert_hf_to_torch_dist.py \
    ${MODEL_ARGS[@]} \
    --hf-checkpoint /root/Qwen3-Next-80B-A3B-Thinking/ \
    --save /root/Qwen3-Next-80B-A3B-Thinking_torch_dist/
```

Single node, 8 GPUs:

```bash
cd /root/slime
export BASE_FOLDER=/root
export MASTER_ADDR=127.0.0.1
bash scripts/run-qwen3-next-80B-A3B-8gpus.sh
```

If GPU memory runs out, consider disabling `--accumulate-allreduce-grads-in-fp32` and enabling `--grad-reduce-in-bf16`.
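
A sketch of the corresponding change in `MISC_ARGS` of the launch script included in this commit (only the two gradient-reduction flags differ from the committed version):

```bash
MISC_ARGS=(
    --attention-dropout 0.0
    --hidden-dropout 0.0
    # --accumulate-allreduce-grads-in-fp32   # disabled here to save GPU memory
    --grad-reduce-in-bf16                    # reduce gradients in bf16 instead of fp32
    --attention-softmax-in-fp32
    --attention-backend flash
    --moe-token-dispatcher-type alltoall
)
```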

Multi-node (4x8):

```bash
cd /root/slime
export BASE_FOLDER=/root
export MASTER_ADDR=your_master_addr
bash scripts/run-qwen3-next-80B-A3B.sh
```
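
The launch script starts Ray workers over SSH by reading node IPs from `/root/mpi_rack_hostfile`, using only the first whitespace-separated field of each line and skipping the entry that matches `$MLP_WORKER_0_HOST`. A minimal sketch of such a hostfile, with placeholder IPs:

```bash
cat > /root/mpi_rack_hostfile <<'EOF'
10.0.0.1
10.0.0.2
10.0.0.3
10.0.0.4
EOF
```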

## Run Training (FSDP)

```bash
export BASE_FOLDER=./models/
export MASTER_ADDR=127.0.0.1

bash scripts/run-qwen3-next-80B-A3B-fsdp.sh
```
Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
#!/bin/bash

# for rerunning the task: kill any processes left over from a previous run
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

# raise an error if BASE_FOLDER is not set
if [ -z "${BASE_FOLDER}" ]; then
    echo "BASE_FOLDER is not set. Please set it to the base directory of your checkpoints."
    exit 1
fi

if [ -z "${MASTER_ADDR}" ]; then
    echo "MASTER_ADDR is not set. Please set it to the master node address."
    exit 1
fi

# prevent python/ray from buffering stdout/stderr
export PYTHONUNBUFFERED=1

# detect NVLink; the result feeds NCCL_NVLS_ENABLE in the Ray runtime env below
NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/models/qwen3-next-80B-A3B.sh"

CKPT_ARGS=(
    --hf-checkpoint ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking
    --ref-load ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking_torch_dist
    --load ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking_slime/
    --save ${BASE_FOLDER}/Qwen3-Next-80B-A3B-Thinking_slime/
    --save-interval 20
)

ROLLOUT_ARGS=(
    --prompt-data ${BASE_FOLDER}/dapo-math-17k/dapo-math-17k.jsonl
    --input-key prompt
    --label-key label
    --apply-chat-template
    --rollout-shuffle
    --rm-type deepscaler
    --num-rollout 300
    --rollout-batch-size 16
    --n-samples-per-prompt 4
    --rollout-max-response-len 8192
    --rollout-temperature 0.8

    --global-batch-size 64
    --balance-data
)

EVAL_ARGS=(
    --eval-interval 20
    --eval-prompt-data aime ${BASE_FOLDER}/aime-2024/aime-2024.jsonl
    --n-samples-per-eval-prompt 2
    --eval-max-response-len 16384
    --eval-top-p 0.7
)

PERF_ARGS=(
    --tensor-model-parallel-size 1
    --sequence-parallel
    --pipeline-model-parallel-size 6
    --context-parallel-size 1
    --expert-model-parallel-size 1
    --expert-tensor-parallel-size 1

    --recompute-granularity full
    --recompute-method uniform
    --recompute-num-layers 1

    # --micro-batch-size 1
    --use-dynamic-batch-size
    --max-tokens-per-gpu 2048
)

GRPO_ARGS=(
    --advantage-estimator gspo
    # --use-kl-loss
    --kl-loss-coef 0.00
    --kl-loss-type low_var_kl
    --kl-coef 0.00
    --entropy-coef 0.00
    --eps-clip 4e-4
)

OPTIMIZER_ARGS=(
    --optimizer adam
    --lr 1e-6
    --lr-decay-style constant
    --weight-decay 0.1
    --adam-beta1 0.9
    --adam-beta2 0.98

    --optimizer-cpu-offload
    --overlap-cpu-optimizer-d2h-h2d
    --use-precision-aware-optimizer
)

WANDB_ARGS=(
    # --use-wandb
    # --wandb-project slime-dev
    # --wandb-group qwen3-next-80B-A3B-test
    # --wandb-key ${WANDB_KEY}
)

SGLANG_ARGS=(
    --rollout-num-gpus-per-engine 2
    --rollout-num-gpus 2
    --sglang-mem-fraction-static 0.8
    --sglang-ep-size 1

    --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 128)

    # mtp
    # --sglang-speculative-algorithm EAGLE
    # --sglang-speculative-num-steps 2
    # --sglang-speculative-eagle-topk 1
    # --sglang-speculative-num-draft-tokens 3
    # --sglang-enable-draft-weights-cpu-backup
    #
    # --sglang-max-running-requests 512
)

MISC_ARGS=(
    # default dropout in megatron is 0.1
    --attention-dropout 0.0
    --hidden-dropout 0.0
    # should be good for model performance
    --accumulate-allreduce-grads-in-fp32
    # --grad-reduce-in-bf16
    --attention-softmax-in-fp32
    # need to comment this when using model with MLA
    --attention-backend flash

    --moe-token-dispatcher-type alltoall
    # --moe-enable-deepep
    # --debug-rollout-only
)

# launch the master node of ray in container
export no_proxy="127.0.0.1,${MASTER_ADDR}"
ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

# start Ray workers on the remaining nodes listed in the hostfile (the master node itself is skipped)
for WORKER_IP in $(awk '{print $1}' /root/mpi_rack_hostfile); do
    if [[ "$WORKER_IP" == "$MLP_WORKER_0_HOST" ]]; then
        continue
    fi
    echo "Starting Ray worker on ${WORKER_IP}"
    ssh root@"${WORKER_IP}" \
        "pkill -9 sglang ; ray stop --force ; pkill -9 python ; ray start --address=${MASTER_ADDR}:6379 --num-gpus 8 --node-ip-address ${WORKER_IP} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" &
done
wait

# Build the runtime environment JSON with proper variable substitution
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
    \"no_proxy\": \"${no_proxy}\",
    \"MASTER_ADDR\": \"${MASTER_ADDR}\"
  }
}"

ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json="${RUNTIME_ENV_JSON}" \
    -- python3 train.py \
    --actor-num-nodes 1 \
    --actor-num-gpus-per-node 6 \
    ${MODEL_ARGS[@]} \
    ${CKPT_ARGS[@]} \
    ${ROLLOUT_ARGS[@]} \
    ${OPTIMIZER_ARGS[@]} \
    ${GRPO_ARGS[@]} \
    ${WANDB_ARGS[@]} \
    ${PERF_ARGS[@]} \
    ${EVAL_ARGS[@]} \
    ${SGLANG_ARGS[@]} \
    ${MISC_ARGS[@]}
