Commit 80e0528: Megatron VLM Support w/ SFT (2/N) (#1150)

1 parent 9bf1e36 commit 80e0528

4 files changed: +274 -6 lines changed
examples/geo3k_vlm/README.md

Lines changed: 35 additions & 0 deletions
@@ -6,6 +6,38 @@ Training VLMs with FSDP or Megatron on single-turn reasoning task using GRPO on
 <img src="fsdp_vs_megatron.png" alt="FSDP vs Megatron Reward Plot" width="800">
 </p>
 
+## Data Preparation (For SFT Training)
+
+The [geo3k_imgurl](https://huggingface.co/datasets/chenhegu/geo3k_imgurl) dataset contains:
+- `problem`: The math problem text (string)
+- `answer`: The answer (string, e.g., "270")
+- `images`: Image data (list)
+
+For SFT training, we need to wrap the `answer` field in `\boxed{}` format and build the chat `messages`. You can use the following script to do both:
+
+```python
+from datasets import load_dataset
+import pandas as pd
+
+ds = load_dataset("chenhegu/geo3k_imgurl", split="train")
+
+def format_answer(answer: str) -> str:
+    """Format answer to include \\boxed{} format."""
+    return f"Answer: \\boxed{{{answer}}}"
+
+def process_sample(sample):
+    formatted_answer = f"Answer: \\boxed{{{sample['answer']}}}"
+
+    sample["messages"] = [
+        {"role": "user", "content": sample["problem"]},
+        {"role": "assistant", "content": formatted_answer}
+    ]
+    return sample
+
+ds = ds.map(process_sample)
+ds.to_parquet("/root/datasets/geo3k_imgurl/train_formatted.parquet")
+```
+
 ## Reproduce
 
 ```bash
@@ -19,6 +51,9 @@ SLIME_SCRIPT_TRAIN_BACKEND=fsdp ./examples/geo3k_vlm/run_geo3k_vlm.sh
 
 # With different model
 SLIME_SCRIPT_MODEL_NAME=Qwen3-VL-4B-Instruct ./examples/geo3k_vlm/run_geo3k_vlm.sh
+
+# SFT
+./examples/geo3k_vlm/run_geo3k_vlm_sft.sh
 ```
 
 ### Configuration
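
The parquet written by the data preparation script is the file the new SFT launcher below reads: `DATASET_LOCAL_NAME` resolves to `geo3k_imgurl`, so `--prompt-data /root/datasets/${DATASET_LOCAL_NAME}/train_formatted.parquet` points at exactly this output. A quick optional sanity check before launching, assuming pandas is available (this snippet is illustrative, not part of the commit):

```python
import pandas as pd

df = pd.read_parquet("/root/datasets/geo3k_imgurl/train_formatted.parquet")
print(df.columns.tolist())     # should include "messages" and "images"
print(df.iloc[0]["messages"])  # user problem + assistant "Answer: \boxed{...}"
```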
examples/geo3k_vlm/run_geo3k_vlm_sft.sh

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+TRAIN_BACKEND=${SLIME_SCRIPT_TRAIN_BACKEND:-"megatron"}
+MODEL_NAME=${SLIME_SCRIPT_MODEL_NAME:-"Qwen3-VL-8B-Instruct"}
+DATASET_NAME=${SLIME_SCRIPT_DATASET_NAME:-"chenhegu/geo3k_imgurl"}
+NUM_GPUS=${SLIME_SCRIPT_NUM_GPUS:-8}
+DATASET_LOCAL_NAME=$(basename "$DATASET_NAME")
+
+# Validate MODEL_NAME
+VALID_MODELS="
+Qwen3-VL-2B-Instruct
+Qwen3-VL-4B-Instruct
+Qwen3-VL-8B-Instruct
+Qwen3-VL-2B-Thinking
+Qwen3-VL-4B-Thinking
+Qwen3-VL-8B-Thinking
+Qwen3-VL-30B-A3B-Instruct
+Qwen3-VL-235B-A22B-Instruct
+Qwen3-VL-30B-A3B-Thinking
+Qwen3-VL-235B-A22B-Thinking
+"
+if ! echo "$VALID_MODELS" | grep -qw "$MODEL_NAME"; then
+    echo "Error: MODEL_NAME must be one of: $VALID_MODELS"
+    exit 1
+fi
+
+MODEL_NAME_LOWER=$(echo "$MODEL_NAME" | tr '[:upper:]' '[:lower:]')
+
+# External Ray flag
+if [ -z "$SLIME_SCRIPT_EXTERNAL_RAY" ] || [ "$SLIME_SCRIPT_EXTERNAL_RAY" = "0" ]; then
+    USE_EXTERNAL_RAY=0
+else
+    USE_EXTERNAL_RAY=1
+fi
+
+# Cleanup
+pkill -9 sglang
+sleep 3
+if [ "$USE_EXTERNAL_RAY" = "0" ]; then
+    ray stop --force
+    pkill -9 ray
+fi
+pkill -9 slime
+sleep 3
+if [ "$USE_EXTERNAL_RAY" = "0" ]; then
+    pkill -9 ray
+fi
+pkill -9 slime
+pkill -9 redis
+
+set -ex
+
+export PYTHONUNBUFFERED=16
+
+# Detect NVLink
+NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
+if [ "$NVLINK_COUNT" -gt 0 ]; then
+    HAS_NVLINK=1
+else
+    HAS_NVLINK=0
+fi
+echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
+
+# Download model and dataset
+mkdir -p /root/models /root/datasets
+if [ ! -d "/root/models/${MODEL_NAME}" ]; then
+    hf download Qwen/${MODEL_NAME} --local-dir /root/models/${MODEL_NAME}
+fi
+if [ ! -d "/root/datasets/${DATASET_LOCAL_NAME}" ]; then
+    hf download --repo-type dataset ${DATASET_NAME} --local-dir /root/datasets/${DATASET_LOCAL_NAME}
+fi
+
+# Common args
+CKPT_ARGS=(
+    --hf-checkpoint /root/models/${MODEL_NAME}
+    --load /root/models/${MODEL_NAME}
+    --rotary-base 5000000
+)
+
+SFT_ARGS=(
+    --rollout-function-path slime.rollout.sft_rollout.generate_rollout
+    --prompt-data /root/datasets/${DATASET_LOCAL_NAME}/train_formatted.parquet
+    --input-key messages
+    --apply-chat-template
+    --rollout-shuffle
+    --num-epoch 3000
+    --rollout-batch-size 128
+    --global-batch-size 128
+
+    --loss-type sft_loss
+    --calculate-per-token-loss
+    --disable-compute-advantages-and-returns
+    --debug-train-only
+)
+
+# required for vlm datasets
+MULTIMODAL_KEYS='{"image": "images"}'
+
+
+OPTIMIZER_ARGS=(
+    --optimizer adam
+    --lr 1e-5
+    --lr-decay-style cosine
+    --min-lr 1e-6
+    --lr-warmup-fraction 0.1
+    --weight-decay 0.1
+    --adam-beta1 0.9
+    --adam-beta2 0.95
+)
+
+if [ -n "$WANDB_API_KEY" ]; then
+    WANDB_ARGS=(
+        --use-wandb
+        --wandb-project slime-geo3k-vlm-sft
+        --wandb-group ${MODEL_NAME_LOWER}-${TRAIN_BACKEND}
+        --wandb-key ${WANDB_API_KEY}
+        --disable-wandb-random-suffix
+    )
+else
+    WANDB_ARGS=()
+fi
+
+# Backend-specific args
+if [ "$TRAIN_BACKEND" = "fsdp" ]; then
+    BACKEND_ARGS=(
+        --train-backend fsdp
+        --gradient-checkpointing
+        --attn-implementation flash_attention_3
+        --update-weight-buffer-size 536870912
+    )
+else
+    # megatron backend (default)
+    BACKEND_ARGS=(
+        --train-backend megatron
+        --tensor-model-parallel-size 4
+        --sequence-parallel
+        --pipeline-model-parallel-size 1
+        --context-parallel-size 1
+        --expert-model-parallel-size 1
+        --expert-tensor-parallel-size 1
+        --recompute-granularity full
+        --recompute-method uniform
+        --recompute-num-layers 1
+        --use-dynamic-batch-size
+        --max-tokens-per-gpu 4096
+        --attention-dropout 0.0
+        --hidden-dropout 0.0
+        --accumulate-allreduce-grads-in-fp32
+        --attention-softmax-in-fp32
+        --attention-backend flash
+        --megatron-to-hf-mode bridge
+    )
+
+    # get MODEL_ARGS from scripts/models for megatron backend
+    SLIME_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." &>/dev/null && pwd)"
+    MODEL_ARGS_FILE=$(echo "$MODEL_NAME" | sed 's/-Instruct//g; s/-Thinking//g; s/Qwen3-VL-/qwen3-/g; s/-2B/-1.7B/g')
+    source "${SLIME_DIR}/scripts/models/${MODEL_ARGS_FILE}.sh"
+fi
+
+# Start Ray if not using external Ray
+if [ "$USE_EXTERNAL_RAY" = "0" ]; then
+    export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+    export no_proxy="127.0.0.1,${MASTER_ADDR}"
+    ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+fi
+
+# Build runtime env
+RUNTIME_ENV_JSON="{
+    \"env_vars\": {
+        \"PYTHONPATH\": \"/root/Megatron-LM/\",
+        \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
+        \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
+    }
+}"
+
+ray job submit --address="http://127.0.0.1:8265" \
+    --runtime-env-json="${RUNTIME_ENV_JSON}" \
+    -- python3 train_async.py \
+    --actor-num-nodes 1 \
+    --actor-num-gpus-per-node ${NUM_GPUS} \
+    --multimodal-keys "${MULTIMODAL_KEYS}" \
+    ${MODEL_ARGS[@]} \
+    ${CKPT_ARGS[@]} \
+    ${SFT_ARGS[@]} \
+    ${EVAL_ARGS[@]} \
+    ${OPTIMIZER_ARGS[@]} \
+    ${WANDB_ARGS[@]} \
+    ${BACKEND_ARGS[@]}
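
Like the GRPO launcher in the README above, this script is configured entirely through the `SLIME_SCRIPT_*` environment variables read at its top. For the Megatron backend, `MODEL_ARGS` is sourced from `scripts/models/`: the `sed` mapping turns, e.g., `Qwen3-VL-8B-Instruct` into `qwen3-8B` and the 2B variants into `qwen3-1.7B`. Example invocations (paths and names as defined in the script):

```bash
# Defaults: Megatron backend, Qwen3-VL-8B-Instruct, 8 GPUs
./examples/geo3k_vlm/run_geo3k_vlm_sft.sh

# FSDP backend with a smaller model on 4 GPUs
SLIME_SCRIPT_TRAIN_BACKEND=fsdp \
SLIME_SCRIPT_MODEL_NAME=Qwen3-VL-4B-Instruct \
SLIME_SCRIPT_NUM_GPUS=4 \
./examples/geo3k_vlm/run_geo3k_vlm_sft.sh

# Reuse an already-running Ray cluster instead of starting one
SLIME_SCRIPT_EXTERNAL_RAY=1 ./examples/geo3k_vlm/run_geo3k_vlm_sft.sh
```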

slime/rollout/sft_rollout.py

Lines changed: 22 additions & 5 deletions
@@ -1,15 +1,15 @@
 import logging
 
-from transformers import AutoTokenizer
-
 from slime.utils.mask_utils import MultiTurnLossMaskGenerator
+from slime.utils.processing_utils import load_processor, load_tokenizer, prepare_model_inputs
 
 __all__ = ["generate_rollout"]
 
 logger = logging.getLogger(__name__)
 
 
 TOKENIZER = None
+PROCESSOR = None
 MASK_GENERATOR = None
 SAMPLE_PRINTED = False
 
@@ -29,9 +29,12 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
     assert not evaluation
     assert args.rollout_global_dataset
 
-    global TOKENIZER, MASK_GENERATOR, SAMPLE_PRINTED
+    global TOKENIZER, PROCESSOR, MASK_GENERATOR, SAMPLE_PRINTED
     if TOKENIZER is None:
-        TOKENIZER = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
+        TOKENIZER = load_tokenizer(args.hf_checkpoint, trust_remote_code=True)
+
+    if PROCESSOR is None:
+        PROCESSOR = load_processor(args.hf_checkpoint, trust_remote_code=True)
 
     if MASK_GENERATOR is None:
         MASK_GENERATOR = MultiTurnLossMaskGenerator(TOKENIZER, tokenizer_type=args.loss_mask_type)
@@ -42,7 +45,21 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
         (sample,) = sample
         messages = sample.prompt
         tools = sample.metadata.get("tools", None)
-        token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)
+
+        input_ids, extra_info = prepare_model_inputs(
+            messages, TOKENIZER, PROCESSOR, sample.metadata,
+            args.apply_chat_template, args.apply_chat_template_kwargs
+        )
+
+        has_multimodal = bool(extra_info.get("images") or extra_info.get("videos"))
+        if has_multimodal:
+            sample.multimodal_inputs = extra_info["multimodal_inputs"]
+            token_ids, loss_mask = MASK_GENERATOR.get_loss_mask_with_multimodal_alignment(
+                messages, input_ids, tools=tools
+            )
+        else:
+            token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools)
+
         response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0]
 
         sample.tokens = token_ids
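
The new branch keys off `extra_info`: when `prepare_model_inputs` reports images or videos, the trained `token_ids` come from the processor's `input_ids` (which include the expanded vision tokens), so the loss mask is computed by the alignment helper added in the next file. A sketch of the two message shapes involved, with hypothetical values (the list-style `content` is what the alignment helper flattens):

```python
# Hypothetical samples, for illustration only; actual schemas come from the
# dataset and chat template.
text_only = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "Answer: \\boxed{4}"},
]

multimodal = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://example.com/geo.png"},
            {"type": "text", "text": "Find the measure of angle A."},
        ],
    },
    {"role": "assistant", "content": "Answer: \\boxed{45}"},
]
# text_only takes the get_loss_mask path; multimodal takes
# get_loss_mask_with_multimodal_alignment with the processor's input_ids.
```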

slime/utils/mask_utils.py

Lines changed: 31 additions & 1 deletion
@@ -125,7 +125,7 @@ def gen_multi_turn_loss_mask_distill_qwen(
         loss_mask = [0] * len(token_ids)
         return token_ids, loss_mask
 
-    def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> list[int]:
+    def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> tuple[list[int], list[int]]:
         if self.tokenizer_type == "qwen":
             if "<|Assistant|>" in self.tokenizer.get_added_vocab():
                 return self.gen_multi_turn_loss_mask_distill_qwen(messages, tools)
@@ -138,6 +138,36 @@ def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> list[
         else:
             raise ValueError(f"Unsupported tokenizer type: {self.tokenizer_type}")
 
+    def get_loss_mask_with_multimodal_alignment(
+        self, messages: list[dict], input_ids: list[int], tools: list[dict] = None
+    ) -> tuple[list[int], list[int]]:
+        text = []
+        for msg in messages:
+            if isinstance(msg.get("content"), list):
+                text_parts = []
+                for item in msg["content"]:
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        text_parts.append(item.get("text", ""))
+                    elif isinstance(item, str):
+                        text_parts.append(item)
+                text.append({
+                    "role": msg["role"],
+                    "content": " ".join(text_parts)
+                })
+            else:
+                text.append(msg)
+
+        _, loss_mask_text = self.get_loss_mask(text, tools=tools)
+
+        diff = len(input_ids) - len(loss_mask_text)
+        assert diff >= 0, (
+            f"input_ids (length={len(input_ids)}) is shorter than text loss_mask (length={len(loss_mask_text)}). "
+            f"Please check if processor and tokenizer tokenization are consistent."
+        )
+        loss_mask = [0] * diff + loss_mask_text
+
+        return input_ids, loss_mask
+
     def get_text_from_loss_mask(self, token_ids: list[int], loss_masks: list[int]) -> list[str]:
         selected_texts = []
         current_tokens = []
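
The alignment helper left-pads the text-only mask: the processor's `input_ids` carries extra vision tokens that the text-only tokenization lacks, and for single-turn geo3k samples the images sit in the user turn, so the supervised assistant tokens stay at the tail of the sequence in both tokenizations. A toy walk-through with invented token counts (the real length difference comes from the processor's vision-token expansion):

```python
input_ids = list(range(16))          # processor output: 10 vision + 6 text tokens
loss_mask_text = [0, 0, 0, 1, 1, 1]  # text-only mask: last 3 tokens supervised

diff = len(input_ids) - len(loss_mask_text)   # 10 extra multimodal tokens
loss_mask = [0] * diff + loss_mask_text       # left-pad so the tails stay aligned

assert len(loss_mask) == len(input_ids)
assert sum(loss_mask) == 3  # assistant tokens remain the only supervised ones
```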
