Skip to content

Commit b50bfca

Browse files
yfwyuki-97
andauthored
feat: VLM support via megatron backend (#1115)
Signed-off-by: Yi-Fu Wu <[email protected]> Co-authored-by: Yuki Huang <[email protected]>
1 parent 8003918 commit b50bfca

21 files changed

+738
-43
lines changed

3rdparty/Megatron-Bridge-workspace/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
"packaging",
3434
"tensorboard>=2.19.0",
3535
"torch",
36-
"transformers>=4.51.3",
36+
"transformers>=4.55.0",
3737
"typing-extensions",
3838
"rich",
3939
"wandb>=0.19.10",

examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,3 @@ checkpointing:
33
checkpoint_dir: results/clevr_grpo
44
policy:
55
max_total_sequence_length: 3072
6-
env:
7-
refcoco:
8-
reward_functions:
9-
- name: format
10-
weight: 0.1
11-
- name: bbox_giou
12-
weight: 0.9
13-
kwargs:
14-
giou_penalty_thres: 1.0
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
defaults: ../../vlm_grpo_3B.yaml
2+
checkpointing:
3+
checkpoint_dir: results/clevr_grpo
4+
policy:
5+
max_total_sequence_length: 3072
6+
dtensor_cfg:
7+
enabled: false
8+
dynamic_batching:
9+
enabled: false
10+
make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}
11+
optimizer: null
12+
megatron_cfg:
13+
enabled: true
14+
empty_unused_memory_level: 1
15+
optimizer:
16+
lr: 5.0e-07
17+
min_lr: 5.0e-08
18+
scheduler:
19+
lr_warmup_iters: 50
20+
lr_warmup_init: 5.0e-08
21+
distributed_data_parallel_config:
22+
overlap_grad_reduce: false
23+
logger:
24+
wandb:
25+
name: vlm-grpo-3b-megatron

examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml renamed to examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml.disabled

File renamed without changes.

examples/configs/vlm_grpo_3B.yaml

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,70 @@ policy:
5858
context_parallel_size: 1
5959
custom_parallel_plan: null
6060

61+
megatron_cfg:
62+
enabled: false
63+
empty_unused_memory_level: 0
64+
activation_checkpointing: false
65+
converter_type: "Qwen2ForCausalLM"
66+
tensor_model_parallel_size: 1
67+
expert_tensor_parallel_size: 1
68+
expert_model_parallel_size: 1
69+
pipeline_model_parallel_size: 1
70+
num_layers_in_first_pipeline_stage: null
71+
num_layers_in_last_pipeline_stage: null
72+
context_parallel_size: 1
73+
pipeline_dtype: ${policy.precision}
74+
sequence_parallel: false
75+
freeze_moe_router: true
76+
moe_router_dtype: "fp64"
77+
moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
78+
moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
79+
moe_permute_fusion: false
80+
#gives ~20% training perf speedup with sequence packing
81+
apply_rope_fusion: True
82+
defer_fp32_logits: null
83+
84+
optimizer:
85+
optimizer: "adam"
86+
lr: 5.0e-6
87+
min_lr: 5.0e-7
88+
weight_decay: 0.01
89+
bf16: true
90+
fp16: false
91+
params_dtype: "float32"
92+
93+
#adam
94+
adam_beta1: 0.9
95+
adam_beta2: 0.999
96+
adam_eps: 1e-8
97+
98+
#sgd
99+
sgd_momentum: 0.9
100+
101+
#distributed optimizer
102+
use_distributed_optimizer: true
103+
use_precision_aware_optimizer: true
104+
105+
clip_grad: ${policy.max_grad_norm}
106+
107+
scheduler:
108+
start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
109+
end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
110+
weight_decay_incr_style: "constant"
111+
lr_decay_style: "constant"
112+
lr_decay_iters: 1000
113+
lr_warmup_iters: 13
114+
lr_warmup_init: 5.0e-7
115+
116+
distributed_data_parallel_config:
117+
grad_reduce_in_fp32: false
118+
overlap_grad_reduce: true
119+
overlap_param_gather: true
120+
average_in_collective: true
121+
use_custom_fsdp: false
122+
data_parallel_sharding_strategy: "optim_grads_params"
123+
124+
61125
# dynamic_batching improves performance by ensuring logprob and training microbatches
62126
# have a sufficent number of tokens to maximize GPU utilization. Specifically, variable length
63127
# responses are sorted by sequence length and bucketed into microbatches with a total
@@ -76,6 +140,10 @@ policy:
76140

77141
sequence_packing:
78142
enabled: False
143+
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
144+
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
145+
algorithm: "modified_first_fit_decreasing"
146+
sequence_length_round: 64
79147

80148
optimizer:
81149
name: "torch.optim.AdamW"
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
grpo:
2+
num_prompts_per_step: 8
3+
num_generations_per_prompt: 16
4+
max_rollout_turns: 1
5+
max_num_epochs: 1
6+
max_num_steps: 1000000
7+
normalize_rewards: true
8+
use_leave_one_out_baseline: true
9+
val_period: 10
10+
val_at_start: false
11+
overlong_filtering: false
12+
max_val_samples: 256
13+
val_batch_size: 256
14+
seed: 42
15+
async_grpo:
16+
enabled: false
17+
max_trajectory_age_steps: 1
18+
loss_fn:
19+
reference_policy_kl_penalty: 0.01
20+
ratio_clip_min: 0.2
21+
ratio_clip_max: 0.2
22+
ratio_clip_c: null
23+
use_on_policy_kl_approximation: false
24+
use_importance_sampling_correction: false
25+
token_level_loss: true
26+
checkpointing:
27+
enabled: true
28+
checkpoint_dir: results/clevr_grpo_${policy.model_name}
29+
metric_name: val_reward
30+
higher_is_better: true
31+
keep_top_k: 3
32+
save_period: 10
33+
checkpoint_must_save_by: null
34+
policy:
35+
model_name: Qwen/Qwen2.5-VL-3B-Instruct
36+
tokenizer:
37+
name: ${policy.model_name}
38+
train_global_batch_size: 128
39+
train_micro_batch_size: 1
40+
generation_batch_size: 32
41+
logprob_batch_size: 4
42+
max_total_sequence_length: 2048
43+
precision: bfloat16
44+
dtensor_cfg:
45+
_v2: true
46+
enabled: false
47+
cpu_offload: false
48+
sequence_parallel: false
49+
activation_checkpointing: false
50+
tensor_parallel_size: 1
51+
context_parallel_size: 1
52+
custom_parallel_plan: null
53+
dynamic_batching:
54+
enabled: false
55+
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
56+
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
57+
sequence_length_round: 64
58+
make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}
59+
max_grad_norm: 1.0
60+
sequence_packing:
61+
enabled: false
62+
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
63+
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
64+
algorithm: modified_first_fit_decreasing
65+
sequence_length_round: 64
66+
optimizer: null
67+
scheduler:
68+
- name: torch.optim.lr_scheduler.LinearLR
69+
kwargs:
70+
start_factor: 0.1
71+
end_factor: 1.0
72+
total_iters: 50
73+
- name: torch.optim.lr_scheduler.ConstantLR
74+
kwargs:
75+
factor: 1.0
76+
total_iters: 10000000000
77+
- milestones:
78+
- 50
79+
generation:
80+
backend: vllm
81+
max_new_tokens: 1024
82+
temperature: 1.0
83+
top_p: 1.0
84+
top_k: null
85+
stop_token_ids: null
86+
stop_strings: null
87+
vllm_cfg:
88+
async_engine: false
89+
precision: ${policy.precision}
90+
tensor_parallel_size: 1
91+
pipeline_parallel_size: 1
92+
expert_parallel_size: 1
93+
gpu_memory_utilization: 0.6
94+
max_model_len: ${policy.max_total_sequence_length}
95+
enforce_eager: false
96+
enable_expert_parallel: false
97+
colocated:
98+
enabled: true
99+
resources:
100+
gpus_per_node: null
101+
num_nodes: null
102+
megatron_cfg:
103+
enabled: true
104+
empty_unused_memory_level: 0
105+
activation_checkpointing: false
106+
converter_type: Qwen2ForCausalLM
107+
tensor_model_parallel_size: 1
108+
expert_tensor_parallel_size: 1
109+
expert_model_parallel_size: 1
110+
pipeline_model_parallel_size: 1
111+
num_layers_in_first_pipeline_stage: null
112+
num_layers_in_last_pipeline_stage: null
113+
context_parallel_size: 1
114+
pipeline_dtype: ${policy.precision}
115+
sequence_parallel: false
116+
freeze_moe_router: true
117+
moe_router_dtype: fp64
118+
moe_router_load_balancing_type: none
119+
moe_router_bias_update_rate: 0.0
120+
moe_permute_fusion: false
121+
apply_rope_fusion: true
122+
optimizer:
123+
optimizer: adam
124+
lr: 2.0e-07
125+
min_lr: 2.0e-07
126+
weight_decay: 0.01
127+
bf16: true
128+
fp16: false
129+
params_dtype: float32
130+
adam_beta1: 0.9
131+
adam_beta2: 0.999
132+
adam_eps: 1.0e-08
133+
sgd_momentum: 0.9
134+
use_distributed_optimizer: true
135+
use_precision_aware_optimizer: true
136+
clip_grad: ${policy.max_grad_norm}
137+
scheduler:
138+
start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
139+
end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
140+
weight_decay_incr_style: constant
141+
lr_decay_style: constant
142+
lr_decay_iters: 1000
143+
lr_warmup_iters: 50
144+
lr_warmup_init: 2.0e-08
145+
distributed_data_parallel_config:
146+
grad_reduce_in_fp32: false
147+
overlap_grad_reduce: false
148+
overlap_param_gather: true
149+
average_in_collective: true
150+
use_custom_fsdp: false
151+
data_parallel_sharding_strategy: optim_grads_params
152+
data:
153+
max_input_seq_length: ${policy.max_total_sequence_length}
154+
prompt_file: examples/prompts/clevr_cogent_cot.txt
155+
system_prompt_file: null
156+
dataset_name: clevr-cogent
157+
split: trainA
158+
shuffle: true
159+
env:
160+
clevr-cogent:
161+
num_workers: 8
162+
reward_functions:
163+
- name: format
164+
weight: 0.2
165+
- name: exact_alnum
166+
weight: 0.8
167+
geometry3k:
168+
num_workers: 8
169+
reward_functions:
170+
- name: format
171+
weight: 0.1
172+
- name: math_expr
173+
weight: 0.9
174+
refcoco:
175+
num_workers: 8
176+
reward_functions:
177+
- name: format
178+
weight: 0.1
179+
- name: bbox_giou
180+
weight: 0.9
181+
kwargs:
182+
giou_penalty_thres: 0.5
183+
logger:
184+
log_dir: logs
185+
num_val_samples_to_print: 0
186+
wandb_enabled: false
187+
tensorboard_enabled: true
188+
swanlab_enabled: false
189+
mlflow_enabled: false
190+
monitor_gpus: false
191+
wandb:
192+
project: grpo-dev
193+
name: vlm-grpo-3b-megatron
194+
tensorboard: {}
195+
gpu_monitoring:
196+
collection_interval: 10
197+
flush_interval: 10
198+
cluster:
199+
gpus_per_node: 2
200+
num_nodes: 1

examples/run_vlm_grpo.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,29 @@ def hf_data_processor(
194194

195195
length = sum(len(m["token_ids"]) for m in message_log)
196196
loss_multiplier = 1.0
197-
if length > max_seq_length:
197+
if length >= max_seq_length:
198+
# Treat truncated messages as text only
199+
vllm_kwargs = {
200+
"vllm_content": None,
201+
"vllm_images": [],
202+
}
203+
198204
# make smaller and mask out
199205
for chat_message in message_log:
200206
chat_message["token_ids"] = chat_message["token_ids"][
201207
: min(4, max_seq_length // len(message_log))
202208
]
209+
for key, value in chat_message.items():
210+
if isinstance(value, PackedTensor):
211+
chat_message[key] = PackedTensor.empty_like(value)
203212
loss_multiplier = 0.0
204-
raise NotImplementedError(
205-
"Sequence length is too long, please use a shorter sequence length"
206-
)
213+
else:
214+
# get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation
215+
# add images for vllm serving
216+
vllm_kwargs = {
217+
"vllm_content": string_formatted_dialog,
218+
"vllm_images": images,
219+
}
207220

208221
output: DatumSpec = {
209222
"message_log": message_log,
@@ -212,10 +225,7 @@ def hf_data_processor(
212225
"loss_multiplier": loss_multiplier,
213226
"idx": idx,
214227
"task_name": task_data_spec.task_name,
215-
# get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation
216-
# add images for vllm serving
217-
"vllm_content": string_formatted_dialog,
218-
"vllm_images": images,
228+
**vllm_kwargs,
219229
}
220230
return output
221231

0 commit comments

Comments
 (0)