Skip to content

Commit a9cfd75

Browse files
authored
[FSDP, VLM] feat: true on policy for VLM (#1056)
1 parent aed7a41 commit a9cfd75

File tree

7 files changed

+677
-97
lines changed

7 files changed

+677
-97
lines changed

docker/patch/latest/sglang.patch

Lines changed: 491 additions & 47 deletions
Large diffs are not rendered by default.

examples/geo3k_vlm/run_geo3k_vlm.py

Lines changed: 8 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import json
21
import os
3-
import subprocess
42

53
import slime.utils.misc as U
4+
from slime.utils.external_utils.command_utils import execute_train, get_default_wandb_args
65

76
MODEL_NAME = os.environ.get("SLIME_SCRIPT_MODEL_NAME", "Qwen3-VL-2B-Instruct")
87
assert MODEL_NAME in {"Qwen2.5-VL-3B-Instruct", "Qwen3-VL-2B-Instruct", "Qwen3-VL-4B-Instruct", "Qwen3-VL-8B-Instruct"}
@@ -12,19 +11,6 @@
1211
MASTER_ADDR = os.environ.get("MASTER_ADDR", "127.0.0.1")
1312

1413

15-
def detect_nvlink():
16-
"""Detect if NVLink is available on the system."""
17-
try:
18-
result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
19-
nvlink_count = result.stdout.count("NVLink")
20-
has_nvlink = 1 if nvlink_count > 0 else 0
21-
print(f"HAS_NVLINK: {has_nvlink} (detected {nvlink_count} NVLink references)")
22-
return has_nvlink
23-
except Exception as e:
24-
print(f"Failed to detect NVLink: {e}")
25-
return 0
26-
27-
2814
def prepare():
2915
U.exec_command("mkdir -p /root/models /root/datasets")
3016
U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}")
@@ -34,8 +20,6 @@ def prepare():
3420

3521

3622
def execute():
37-
# Detect NVLink for optimized NCCL settings
38-
has_nvlink = detect_nvlink()
3923

4024
ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} "
4125

@@ -57,7 +41,7 @@ def execute():
5741

5842
eval_args = (
5943
# "--eval-interval 20 "
60-
"--eval-prompt-data geo3k-test /root/datasets/geo3k_imgurl/test.parquet "
44+
"--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet "
6145
"--n-samples-per-eval-prompt 1 "
6246
"--eval-max-response-len 4096 "
6347
"--eval-top-k 1 "
@@ -100,14 +84,6 @@ def execute():
10084
"--attn-implementation flash_attention_3 "
10185
)
10286

103-
wandb_args = (
104-
"--use-wandb "
105-
"--wandb-project geo3k-vlm "
106-
"--wandb-group geo3k-vlm "
107-
"--wandb-key ${WANDB_API_KEY} "
108-
"--disable-wandb-random-suffix "
109-
)
110-
11187
misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate "
11288

11389
# misc_args += (
@@ -139,7 +115,7 @@ def execute():
139115
f"{fsdp_args} "
140116
f"{eval_args} "
141117
f"{misc_args} "
142-
f"{wandb_args} "
118+
f"{get_default_wandb_args(__file__)} "
143119
# f"{true_on_policy_args} "
144120
)
145121

@@ -164,27 +140,12 @@ def execute():
164140
f"ray start --head --node-ip-address {MASTER_ADDR} --num-gpus {NUM_GPUS} "
165141
f"--disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265"
166142
)
167-
168-
# Prepare runtime environment
169-
runtime_env_json = json.dumps(
170-
{
171-
"env_vars": {
172-
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
173-
"NCCL_NVLS_ENABLE": str(has_nvlink),
174-
# **true_on_policy_envs,
175-
# "SGLANG_DUMPER_ENABLE": "0",
176-
# "SGLANG_TEMP_UTILS_ENABLE_DEBUG_PRINT": "0",
177-
}
178-
}
179-
)
180-
181143
# Submit Ray job
182-
U.exec_command(
183-
f"export no_proxy=127.0.0.1 && export PYTHONBUFFERED=16 && "
184-
f'ray job submit --address="http://127.0.0.1:8265" '
185-
f"--runtime-env-json='{runtime_env_json}' "
186-
f"-- python3 /root/slime/train.py "
187-
f"{train_args}"
144+
execute_train(
145+
train_args=train_args,
146+
num_gpus_per_node=NUM_GPUS,
147+
megatron_model_type=None,
148+
extra_env_vars={},
188149
)
189150

190151

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# True On-Policy between Training and Inference for VLM
2+
3+
This example demonstrates true on-policy training with the Qwen3-VL dense model on FSDP. The core concepts and expected observations are the same as in [true_on_policy](../true_on_policy/README.md).
4+
5+
## Usage
6+
7+
```bash
8+
python examples/true_on_policy_vlm/run_simple.py
9+
```
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import os
2+
3+
import slime.utils.misc as U
4+
from slime.utils.external_utils.command_utils import execute_train, get_default_wandb_args
5+
6+
# Script configuration, read once from the environment at import time.
MODEL_NAME = os.getenv("SLIME_SCRIPT_MODEL_NAME", "Qwen3-VL-2B-Instruct")
# Only these checkpoints are accepted by this example.
assert MODEL_NAME in (
    "Qwen2.5-VL-3B-Instruct",
    "Qwen3-VL-2B-Instruct",
    "Qwen3-VL-4B-Instruct",
    "Qwen3-VL-8B-Instruct",
)

NUM_GPUS = int(os.getenv("SLIME_SCRIPT_NUM_GPUS", "1"))
# Non-zero means a Ray cluster is managed outside this script (see `execute`).
EXTERNAL_RAY = int(os.getenv("SLIME_SCRIPT_EXTERNAL_RAY", "0"))
MASTER_ADDR = os.getenv("MASTER_ADDR", "127.0.0.1")
12+
13+
14+
def prepare():
    """Create the local model/dataset directories and fetch both with `hf download`.

    Every step is issued as a command string through `U.exec_command`.
    """
    U.exec_command("mkdir -p /root/models /root/datasets")

    # Model checkpoint goes under /root/models/<MODEL_NAME>.
    model_dir = f"/root/models/{MODEL_NAME}"
    U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir {model_dir}")

    # The dataset is stored under the repo's short name (the part after the owner).
    dataset_repo = "chenhegu/geo3k_imgurl"
    _owner, short_name = dataset_repo.split("/")
    dataset_dir = f"/root/datasets/{short_name}"
    U.exec_command(f"hf download --repo-type dataset {dataset_repo} --local-dir {dataset_dir}")
20+
21+
22+
def execute():
    """Assemble the training command line and launch the true on-policy VLM run.

    The function:
      1. builds the CLI argument groups (checkpoint, rollout, eval, GRPO,
         optimizer, SGLang, FSDP, CI checks, misc, wandb, true on-policy);
      2. kills leftover sglang/slime/redis processes (and Ray, unless
         ``EXTERNAL_RAY`` is set);
      3. starts a local Ray head node unless ``EXTERNAL_RAY`` is set;
      4. hands the arguments and the determinism-related environment
         variables to ``execute_train``.
    """
    ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} "

    rollout_args = (
        "--prompt-data /root/datasets/geo3k_imgurl/train.parquet "
        "--input-key problem "
        "--label-key answer "
        '--multimodal-keys \'{"image": "images"}\' '
        "--apply-chat-template "
        "--rollout-shuffle "
        "--rm-type math "
        "--num-rollout 3000 "
        "--rollout-batch-size 64 "
        "--n-samples-per-prompt 8 "
        "--rollout-max-response-len 4096 "
        "--rollout-temperature 0.8 "
        "--global-batch-size 512 "
    )

    eval_args = (
        # "--eval-interval 20 "
        "--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet "
        "--n-samples-per-eval-prompt 1 "
        "--eval-max-response-len 4096 "
        "--eval-top-k 1 "
    )

    grpo_args = (
        "--advantage-estimator grpo "
        # "--use-kl-loss "
        "--kl-loss-coef 0.00 "
        "--kl-loss-type low_var_kl "
        "--kl-coef 0.00 "
        "--entropy-coef 0.00 "
        "--eps-clip 0.2 "
        "--eps-clip-high 0.28 "
    )

    optimizer_args = (
        "--optimizer adam "
        "--lr 1e-6 "
        "--lr-decay-style constant "
        "--weight-decay 0.1 "
        "--adam-beta1 0.9 "
        "--adam-beta2 0.98 "
    )

    # CUDA-graph batch sizes: 1, 2, 4, 8, then every multiple of 8 from 16 to 256.
    cuda_graph_batch_sizes = [1, 2, 4, 8] + list(range(16, 257, 8))
    sglang_args = (
        "--rollout-num-gpus-per-engine 1 "
        "--sglang-mem-fraction-static 0.6 "
        f"--sglang-cuda-graph-bs {' '.join(map(str, cuda_graph_batch_sizes))} "
    )

    fsdp_args = (
        # Set to true for FULL_STATE_DICT mode, false for SHARDED_STATE_DICT mode (default)
        # "--fsdp-full-params "  # Uncomment this line to enable full params mode
        # Set the bucket size for weight update
        "--update-weight-buffer-size 536870912 "  # 512MB
        "--train-backend fsdp "
        "--gradient-checkpointing "
        "--sglang-attention-backend fa3 "
        "--attn-implementation flash_attention_3 "
    )

    ci_args = (
        "--ci-test "
        "--ci-disable-kl-checker "
        "--ci-metric-checker-key eval/geo3k "
        "--ci-metric-checker-threshold 0.5 "  # loose threshold at 60 step
    )

    misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate "

    # misc_args += (
    #     "--use-dynamic-batch-size "
    #     # TODO pick a good value
    #     "--max-tokens-per-gpu 2048 "
    # )

    true_on_policy_args = (
        "--sglang-enable-deterministic-inference "
        "--sglang-rl-on-policy-target fsdp "
        "--deterministic-mode "
        "--true-on-policy-mode "
    )
    true_on_policy_envs = {
        # TODO note: "Ring" in original RL PR, "allreduce:tree" in SGLang
        # "NCCL_ALGO": "Ring",
        "NCCL_ALGO": "allreduce:tree",
        "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        "SGLANG_VLM_CACHE_SIZE_MB": "0",
    }

    train_args = (
        f"{ckpt_args} "
        f"{rollout_args} "
        f"{optimizer_args} "
        f"{grpo_args} "
        f"{sglang_args} "
        f"{fsdp_args} "
        f"{ci_args} "
        f"{eval_args} "
        f"{misc_args} "
        f"{get_default_wandb_args(__file__)} "
        f"{true_on_policy_args} "
    )

    # Kill existing processes. Ray is left alone when it is managed externally.
    # The trailing `true` keeps the command's exit status at zero even when a
    # `pkill` finds nothing to kill.
    U.exec_command(
        "pkill -9 sglang; "
        "sleep 3; "
        f"{'' if EXTERNAL_RAY else 'ray stop --force; '}"
        f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}"
        "pkill -9 slime; "
        "sleep 3; "
        f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}"
        "pkill -9 slime; "
        "pkill -9 redis; "
        "true; "
    )

    if not EXTERNAL_RAY:
        # Start Ray.
        # PYTHONUNBUFFERED is the variable CPython actually reads (any
        # non-empty value disables stdout/stderr buffering); the previous
        # spelling, PYTHONBUFFERED, is not a recognised variable.
        U.exec_command(
            "export PYTHONUNBUFFERED=16 && "
            f"ray start --head --node-ip-address {MASTER_ADDR} --num-gpus {NUM_GPUS} "
            "--disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265"
        )

    # Submit Ray job
    execute_train(
        train_args=train_args,
        num_gpus_per_node=NUM_GPUS,
        megatron_model_type=None,
        extra_env_vars=true_on_policy_envs,
    )
161+
162+
163+
# Script entry point: download the assets first, then launch training.
if __name__ == "__main__":
    prepare()
    execute()

slime/backends/fsdp_utils/actor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
147147
return int(getattr(self.args, "start_rollout_id", 0))
148148

149149
def get_model_cls(self):
150-
if self.args.multimodal_keys:
150+
# Vision models have `vision_config` in the config
151+
if hasattr(self.hf_config, "vision_config"):
151152
from transformers import AutoModelForVision2Seq
152153

153154
return AutoModelForVision2Seq

slime/utils/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
7373
for type_name, data_key in multimodal_keys.items():
7474
mt = MultimodalTypes.get(type_name)
7575
if mt:
76-
multimodals[mt.placeholder] = (mt, data.get(data_key).tolist())
76+
multimodals[mt.placeholder] = (mt, list(data.get(data_key)))
7777

7878
pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")"
7979

slime/utils/processing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,5 @@ def encode_image_for_rollout_engine(image) -> str:
7272
buffer = io.BytesIO()
7373
if image.mode != "RGB":
7474
image = image.convert("RGB")
75-
image.save(buffer, format="JPEG")
75+
image.save(buffer, format="PNG")
7676
return base64.b64encode(buffer.getvalue()).decode("utf-8")

0 commit comments

Comments
 (0)