
Commit 187fde9

feat(robo2vlm_sft): support qwen3 vl model sft in rlinf (RLinf#781)
Signed-off-by: FxxxxU <fu18801374388@163.com>
Parent commit: 08330be

File tree

13 files changed: +354 −55 lines


.github/workflows/sft-e2e-tests.yml

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ jobs:
           export UV_PYTHON_INSTALL_DIR=/workspace/dataset/.uv_python
           export MEGATRON_PATH=/workspace/dataset/Megatron-LM
           bash requirements/install.sh agentic
+          uv pip install transformers==4.57.1

       - name: SFT Robo2vlm train test
         timeout-minutes: 20
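The pin above likely exists because the Qwen3-VL model classes only ship in recent transformers releases. As a rough sketch (the ">= 4.57.0" floor is an assumption for illustration, not stated in this commit), such a requirement could be checked like this:

```python
# Sketch of a version gate; the (4, 57, 0) floor for Qwen3-VL support
# is an assumption for illustration, not taken from this commit.
def version_tuple(version: str) -> tuple:
    """Parse a dotted version string like '4.57.1' into a comparable tuple."""
    return tuple(int(part) for part in version.split(".")[:3])

def supports_qwen3_vl(transformers_version: str) -> bool:
    """Return True if the given transformers version meets the assumed floor."""
    return version_tuple(transformers_version) >= (4, 57, 0)
```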

docs/source-en/rst_source/examples/embodied/sft_vlm.rst

Lines changed: 9 additions & 9 deletions

@@ -1,16 +1,16 @@
-VLM Supervised Fine-Tuning (SFT)
+VLM Supervised Fine-Tuning
 ================================

 This document explains how to run **full-parameter supervised fine-tuning (Full-parameter SFT)** for VLM models in RLinf.

 This tutorial mainly focuses on two files:

 - Launch script: ``examples/sft/run_vlm_sft.sh``
-- Training config: ``examples/sft/config/custom_sft_vlm.yaml``
+- Training config: ``examples/sft/config/qwen2_5_sft_vlm.yaml``

 Launch Script: ``examples/sft/run_vlm_sft.sh``

-- The script uses ``examples/sft/config/custom_sft_vlm.yaml`` by default.
+- The script uses ``examples/sft/config/qwen2_5_sft_vlm.yaml`` by default.
 - Logs are redirected to: ``<repo>/logs/<timestamp>/``
 - Actual command:

@@ -21,7 +21,7 @@ Launch Script: ``examples/sft/run_vlm_sft.sh``
       --config-name <your_config_name> \
       runner.logger.log_path=<auto_generated_log_dir>

-Config Template: ``examples/sft/config/custom_sft_vlm.yaml``
+Config Template: ``examples/sft/config/qwen2_5_sft_vlm.yaml``

 The VLM config structure is similar to other RLinf training configs.
 You mainly need to adapt ``data`` and ``actor.model`` for your VLM use case.

@@ -35,11 +35,11 @@ Preparation Before Running
    ``https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct``.
 3. Prepare Robo2VLM dataset:
    ``https://huggingface.co/datasets/keplerccc/Robo2VLM-1``.
-4. Edit ``examples/sft/config/custom_sft_vlm.yaml`` and run
+4. Edit ``examples/sft/config/qwen2_5_sft_vlm.yaml`` and run
    ``examples/sft/run_vlm_sft.sh``.

-Example YAML
-------------
+Example of Qwen2_5_VL_4B SFT
+----------------------------

 Important note: after downloading Robo2VLM, train and eval parquet files are mixed in one directory
 (e.g., ``train-00000-of-00262.parquet`` and ``test-0000X-of-00003.parquet``).

@@ -153,7 +153,7 @@ Run from repository root:

 Notes:

-- If no argument is provided, the script uses ``custom_sft_vlm`` by default.
+- If no argument is provided, the script uses ``qwen2_5_sft_vlm`` by default.
 - If your config name is different (e.g., ``my_vlm_config.yaml``), pass it as an argument:

 .. code:: bash

@@ -230,7 +230,7 @@ Update these fields first:
 - ``convertor.ckpt_path``: path to ``full_weights.pt``
 - ``convertor.save_path``: output HF model directory
 - ``model.model_path``: base model path
-- ``model.model_type``: model type (e.g., ``qwen2.5_vl``)
+- ``model.model_type``: model type (e.g., ``qwen2.5_vl``, ``qwen3_vl``, or ``qwen3_vl_moe``)

 Run:
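The "important note" in the doc above says the downloaded Robo2VLM directory mixes train and eval parquet shards. A minimal sketch of splitting them, assuming shards are distinguishable purely by their ``train-``/``test-`` filename prefixes (the helper name and target directories are hypothetical):

```python
import shutil
from pathlib import Path

def split_parquet_shards(src_dir: str, train_dir: str, eval_dir: str) -> None:
    """Move train-*.parquet and test-*.parquet shards into separate folders."""
    src, train, eval_ = Path(src_dir), Path(train_dir), Path(eval_dir)
    train.mkdir(parents=True, exist_ok=True)
    eval_.mkdir(parents=True, exist_ok=True)
    for shard in src.glob("*.parquet"):
        if shard.name.startswith("train-"):
            shutil.move(str(shard), str(train / shard.name))
        elif shard.name.startswith("test-"):
            shutil.move(str(shard), str(eval_ / shard.name))
```

After splitting, point ``data.train_data_paths`` and ``data.eval_data_paths`` at the two separate folders.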

docs/source-zh/rst_source/examples/embodied/sft_vlm.rst

Lines changed: 8 additions & 7 deletions

@@ -6,13 +6,13 @@ VLM supervised fine-tuning
 This tutorial focuses on two files:

 - Launch script: ``examples/sft/run_vlm_sft.sh``
-- Training config: ``examples/sft/config/custom_sft_vlm.yaml``
+- Training config: ``examples/sft/config/qwen2_5_sft_vlm.yaml``

 ----------------------

 Launch script: ``examples/sft/run_vlm_sft.sh``

-- The script currently uses the config yaml file ``examples/sft/config/custom_sft_vlm.yaml`` by default
+- The script currently uses the config yaml file ``examples/sft/config/qwen2_5_sft_vlm.yaml`` by default
 - Redirected output goes to: ``<repo>/logs/<timestamp>/``
 - Actual command executed:

@@ -23,7 +23,7 @@ VLM supervised fine-tuning
       --config-name <your_config_name> \
       runner.logger.log_path=<auto_generated_log_dir>

-Config template: ``examples/sft/config/custom_sft_vlm.yaml``
+Config template: ``examples/sft/config/qwen2_5_sft_vlm.yaml``

 The VLM config has basically the same structure as other RL training configs in RLinf; set the values under ``data`` and ``actor.model`` to your VLM scenario.

@@ -33,9 +33,10 @@ VLM supervised fine-tuning
 1. Prepare the environment and download the official RLinf image ``rlinf/rlinf:math-rlinf0.1-torch2.6.0-sglang0.4.6.post5-vllm0.8.5-megatron0.13.0-te2.1``
 2. Prepare the model weights directory; download from ``https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct``
 3. Prepare the Robo2VLM dataset directory: ``https://huggingface.co/datasets/keplerccc/Robo2VLM-1``
-4. Edit ``examples/sft/config/custom_sft_vlm.yaml`` and run the script ``examples/sft/run_vlm_sft.sh``
+4. Edit ``examples/sft/config/qwen2_5_sft_vlm.yaml`` and run the script ``examples/sft/run_vlm_sft.sh``

-Below is an example yaml file
+Example of Qwen2.5-VL-4B SFT
+--------------------------------

 Note that after downloading the Robo2VLM dataset, the train and evaluate data are placed together, named ``train-00000-of-00262.parquet`` and ``test-0000X-of-00003.parquet``; they must be separated into different folders, otherwise RLinf will read the entire dataset directly.

@@ -148,7 +149,7 @@ VLM supervised fine-tuning
 Notes:

-- If no argument is passed, the script defaults to ``custom_sft_vlm``
+- If no argument is passed, the script defaults to ``qwen2_5_sft_vlm``
 - If your file name is different, e.g. ``my_vlm_config.yaml``, pass it as an argument:

 .. code:: bash

@@ -228,7 +229,7 @@ loss curve:
 - ``convertor.ckpt_path``: points to ``full_weights.pt``
 - ``convertor.save_path``: output HF weights directory
 - ``model.model_path``: path to the original base model
-- ``model.model_type``: the corresponding model type (e.g., qwen2.5_vl)
+- ``model.model_type``: the corresponding model type (e.g., ``qwen2.5_vl``, ``qwen3_vl``, or ``qwen3_vl_moe``)

 Run command:
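The converter fields listed in the hunk above can be sanity-checked before launching a conversion. A minimal sketch, assuming the supported type strings are exactly the three named in the doc (the helper name and flat-dict shape are hypothetical, not RLinf's actual API):

```python
# Hypothetical pre-flight check for the checkpoint-conversion settings.
SUPPORTED_MODEL_TYPES = {"qwen2.5_vl", "qwen3_vl", "qwen3_vl_moe"}

def validate_convertor_cfg(cfg: dict) -> None:
    """Raise ValueError if a required conversion field is missing or invalid."""
    for key in ("ckpt_path", "save_path", "model_path", "model_type"):
        if not cfg.get(key):
            raise ValueError(f"missing convertor field: {key}")
    if cfg["model_type"] not in SUPPORTED_MODEL_TYPES:
        raise ValueError(f"unsupported model_type: {cfg['model_type']}")
```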

examples/sft/config/custom_sft_openpi.yaml

Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ actor:
   # Override the default values in training_backend/fsdp
   fsdp_config:
     strategy: "fsdp"
-    sharding_strategy: "no_shard"
+    sharding_strategy: "full_shard"
     use_orig_params: False
     gradient_checkpointing: False  # for openpi, gradient checkpointing is not supported, please do not change this value
   mixed_precision:
Lines changed: 2 additions & 3 deletions

@@ -75,12 +75,11 @@ actor:
     total_training_steps: ${runner.max_epochs}
     lr_warmup_steps: 200

-  # Override the default values in training_backend/fsdp
   fsdp_config:
     strategy: "fsdp"
-    sharding_strategy: "no_shard"
+    sharding_strategy: "full_shard"
     use_orig_params: False
-    gradient_checkpointing: False  # for openpi, gradient checkpointing is not supported, please do not change this value
+    gradient_checkpointing: False
   mixed_precision:
     param_dtype: bf16
     reduce_dtype: fp32
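The ``no_shard`` → ``full_shard`` change switches FSDP from replicating parameters on every data-parallel rank (DDP-like) to partitioning them evenly. A rough sketch of the per-rank parameter-memory difference (this simplified model is an illustration, not RLinf or PyTorch code):

```python
# Simplified per-rank parameter memory under the two FSDP sharding modes.
# "no_shard" keeps a full replica per rank; "full_shard" divides parameters
# evenly across ranks (ceil to cover uneven splits).
def fsdp_param_bytes_per_rank(total_param_bytes: int, world_size: int,
                              sharding_strategy: str) -> int:
    if sharding_strategy == "no_shard":
        return total_param_bytes
    if sharding_strategy == "full_shard":
        return -(-total_param_bytes // world_size)  # ceil division
    raise ValueError(f"unknown sharding_strategy: {sharding_strategy}")
```

For example, an 8 GB parameter set on 8 ranks costs 8 GB per rank under ``no_shard`` but roughly 1 GB per rank under ``full_shard``, at the price of all-gathering shards during compute.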
Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
+defaults:
+  - override hydra/job_logging: stdout
+
+hydra:
+  run:
+    dir: .
+  output_subdir: null
+
+cluster:
+  num_nodes: 1
+  component_placement:
+    actor: all
+
+runner:
+  task_type: sft
+  logger:
+    log_path: ${runner.output_dir}/${runner.experiment_name}
+    project_name: rlinf
+    experiment_name: ${runner.experiment_name}
+    logger_backends: ["tensorboard"] # wandb, swanlab
+
+  # the sft runner uses len(dataset) as one epoch
+  max_epochs: 6000
+  max_steps: -1
+  # interval at which to eval the model
+  val_check_interval: 1000
+  # interval at which to save the model
+  save_interval: 1000
+  experiment_name: qwen3_vl_sft
+  output_dir: ../results
+  resume_dir: null
+
+data:
+  type: vlm
+  dataset_name: "robo2vlmsft"
+  apply_chat_template: True
+  use_chat_template: True
+  # if train_data_paths is None, the sft code will just eval the model
+  train_data_paths: "/path/to/Robo2VLM-1"
+  eval_data_paths: "/path/to/Robo2VLM-1"
+  prompt_key: "question"
+  choice_key: "choices"
+  answer_key: "correct_answer"
+  image_keys: ["image"]
+  max_prompt_length: 1024
+  lazy_loading: false
+  num_workers: 16
+  answer_separator: ""
+
+algorithm:
+  adv_type: gae
+
+actor:
+  group_name: "ActorGroup"
+  training_backend: "fsdp"
+  micro_batch_size: 4
+  eval_batch_size: 4
+  global_batch_size: 256
+  seed: 42
+
+  model:
+    model_type: "qwen3_vl"
+    precision: fp32
+    model_path: "/path/to/Qwen3-VL-4B-Instruct"
+    is_lora: False
+
+  optim:
+    lr: 1e-5
+    adam_beta1: 0.9
+    adam_beta2: 0.999
+    adam_eps: 1.0e-08
+    weight_decay: 0.01
+    clip_grad: 1.0
+    lr_scheduler: "cosine"
+    total_training_steps: ${runner.max_epochs}
+    lr_warmup_steps: 200
+
+  fsdp_config:
+    strategy: "fsdp"
+    sharding_strategy: "full_shard"
+    use_orig_params: False
+    gradient_checkpointing: False
+    mixed_precision:
+      param_dtype: bf16
+      reduce_dtype: fp32
+      buffer_dtype: bf16
+
+reward:
+  use_reward_model: False
+
+critic:
+  use_critic_model: False
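The batch fields in the new config above (``micro_batch_size: 4``, ``global_batch_size: 256``) imply a gradient-accumulation factor that depends on the data-parallel world size. A sketch of the arithmetic (the exact formula RLinf uses internally is an assumption here):

```python
# Assumed relation: global batch = micro batch * dp ranks * accumulation steps.
def grad_accum_steps(global_batch_size: int, micro_batch_size: int,
                     dp_world_size: int) -> int:
    """Number of micro-batches accumulated before each optimizer step."""
    per_step = micro_batch_size * dp_world_size
    if global_batch_size % per_step != 0:
        raise ValueError("global batch must be divisible by micro_batch * dp ranks")
    return global_batch_size // per_step
```

With the config above on 8 data-parallel GPUs, each rank would accumulate 256 / (4 × 8) = 8 micro-batches per optimizer step.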

examples/sft/config/robotwin_sft_openpi.yaml

Lines changed: 1 addition & 1 deletion

@@ -72,7 +72,7 @@ actor:
   # Override the default values in training_backend/fsdp
   fsdp_config:
     strategy: "fsdp"
-    sharding_strategy: "no_shard"
+    sharding_strategy: "full_shard"
     use_orig_params: False
     gradient_checkpointing: False  # for openpi, gradient checkpointing is not supported, please do not change this value
   mixed_precision:

examples/sft/run_vlm_sft.sh

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ export SRC_FILE="${VLM_PATH}/train_vlm_sft.py"
 export PYTHONPATH=${REPO_PATH}:${LIBERO_REPO_PATH}:$PYTHONPATH

 if [ -z "$1" ]; then
-  CONFIG_NAME="custom_sft_vlm"
+  CONFIG_NAME="qwen2_5_sft_vlm"
 else
   CONFIG_NAME=$1
 fi
```
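The default-argument branch above is equivalent to POSIX ``${1:-default}`` parameter expansion; a sketch of the same selection logic in one line:

```shell
# Use the first positional argument as the config name, falling back to
# the default config when no argument is given.
CONFIG_NAME="${1:-qwen2_5_sft_vlm}"
echo "config: ${CONFIG_NAME}"
```

Run with no arguments this prints `config: qwen2_5_sft_vlm`; run as `./run_vlm_sft.sh my_vlm_config` it prints the passed name instead.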

rlinf/hybrid_engines/fsdp/fsdp_model_manager.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -193,28 +193,32 @@ def _optimize_with_liger_kernel(self, model: torch.nn.Module) -> None:
193193
from liger_kernel.transformers import (
194194
apply_liger_kernel_to_qwen2,
195195
apply_liger_kernel_to_qwen2_5_vl,
196+
apply_liger_kernel_to_qwen3_moe,
197+
apply_liger_kernel_to_qwen3_vl,
198+
apply_liger_kernel_to_qwen3_vl_moe,
196199
)
197200

201+
LIGER_COMMON_KWARGS = {
202+
"rope": True,
203+
"rms_norm": True,
204+
"swiglu": True,
205+
"fused_linear_cross_entropy": True,
206+
}
207+
208+
_liger_func_by_model = {
209+
SupportedModel.QWEN2_5: apply_liger_kernel_to_qwen2,
210+
SupportedModel.QWEN2_5_VL: apply_liger_kernel_to_qwen2_5_vl,
211+
SupportedModel.QWEN2_5_VL_SFT: apply_liger_kernel_to_qwen2_5_vl,
212+
SupportedModel.QWEN3_VL_SFT: apply_liger_kernel_to_qwen3_vl,
213+
SupportedModel.QWEN3_MOE: apply_liger_kernel_to_qwen3_moe,
214+
SupportedModel.QWEN3_VL_MOE_SFT: apply_liger_kernel_to_qwen3_vl_moe,
215+
}
216+
198217
MODEL_LIGER_KERNEL_APPLY_FUNC = {
199-
SupportedModel.QWEN2_5: (
200-
apply_liger_kernel_to_qwen2,
201-
{
202-
"rope": True,
203-
"rms_norm": True,
204-
"swiglu": True,
205-
"fused_linear_cross_entropy": True,
206-
},
207-
),
208-
SupportedModel.QWEN2_5_VL: (
209-
apply_liger_kernel_to_qwen2_5_vl,
210-
{
211-
"rope": True,
212-
"rms_norm": True,
213-
"swiglu": True,
214-
"fused_linear_cross_entropy": True,
215-
},
216-
),
218+
model_type: (apply_fn, dict(LIGER_COMMON_KWARGS))
219+
for model_type, apply_fn in _liger_func_by_model.items()
217220
}
221+
218222
model_type = get_supported_model(
219223
self._cfg.model.get("model_type", "").lower()
220224
)
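The refactor above deduplicates the repeated kwargs dict and builds each registry entry with ``dict(LIGER_COMMON_KWARGS)`` rather than sharing a single dict object. A self-contained sketch of why the per-entry copy matters, using stand-in names instead of the real liger-kernel apply functions:

```python
COMMON_KWARGS = {"rope": True, "rms_norm": True,
                 "swiglu": True, "fused_linear_cross_entropy": True}

# Stand-ins for the real apply_liger_kernel_to_* functions.
_apply_funcs = {"qwen2_5_vl": lambda **kw: kw, "qwen3_vl": lambda **kw: kw}

# Each entry gets its own kwargs copy, so mutating one model's kwargs
# cannot leak into the other entries.
REGISTRY = {name: (fn, dict(COMMON_KWARGS)) for name, fn in _apply_funcs.items()}

REGISTRY["qwen3_vl"][1]["swiglu"] = False  # hypothetical per-model tweak
```

Had the comprehension reused ``COMMON_KWARGS`` directly, the tweak would have flipped ``swiglu`` for every model in the registry.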
