Skip to content

Commit 11764c6

Browse files
[perf] feat: Add MFU for Qwen3-VL dense (verl-project#4753)
### What does this PR do? Add the _estimate_qwen3_vit_flop and _estimate_qwen3_vl_flops function to calculate the FLOPs of Qwen3-VL dense models. Update the test cases to verify the calculation accuracy of Qwen3-VL models. ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test The following is the output result of running the test file. <img width="1271" height="152" alt="image" src="https://github.com/user-attachments/assets/2a3d426c-bd32-4369-9c07-c8a17c60e98b" /> > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). 
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 8f41b05 commit 11764c6

File tree

6 files changed

+211
-22
lines changed

6 files changed

+211
-22
lines changed

tests/utils/test_flops_counter.py

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@
1616

1717
import pytest
1818

19-
from verl.utils.flops_counter import _DEVICE_FLOPS, FlopsCounter, get_device_flops
19+
from verl.utils.flops_counter import FlopsCounter
2020

2121
VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text", "apertus"}
2222

2323

2424
class Config:
    """Attribute-access wrapper around a (possibly nested) config dict.

    Nested dicts (e.g. ``text_config`` / ``vision_config``) are recursively
    promoted to ``Config`` instances so tests can use dotted access just like
    a HuggingFace ``PretrainedConfig``.
    """

    def __init__(self, config_dict):
        for name, entry in config_dict.items():
            setattr(self, name, Config(entry) if isinstance(entry, dict) else entry)
2830

2931

@@ -300,28 +302,101 @@ def __init__(self, config_dict):
300302
# S*(2*V*H + L*(4*H**2 + k_mlp*H*I + k_qkn*H)) * (SUM[seqlen]) + 12*SUM[seqlen**2]*L*H
301303
"expected_flops_tuple": (199154680725504 / 1e12, 732294071451648 / 1e12),
302304
},
305+
"qwen3_vl": {
306+
"config": { # Qwen/Qwen3-VL-8B
307+
"model_type": "qwen3_vl",
308+
# -------- Text config --------
309+
"text_config": {
310+
"vocab_size": 151936,
311+
"hidden_size": 4096,
312+
"intermediate_size": 12288,
313+
"num_hidden_layers": 36,
314+
"num_attention_heads": 32,
315+
"num_key_value_heads": 8,
316+
"head_dim": 128,
317+
},
318+
# -------- Vision config (ViT) --------
319+
"vision_config": {
320+
"deepstack_visual_indexes": [8, 16, 24],
321+
"num_heads": 16,
322+
"depth": 27,
323+
"hidden_size": 1152,
324+
"intermediate_size": 4304,
325+
"out_hidden_size": 4096,
326+
"spatial_merge_size": 2,
327+
"temporal_patch_size": 2,
328+
"in_channels": 3,
329+
"patch_size": 16,
330+
},
331+
},
332+
"batch_seqlens_tuple": (
333+
[512, 1024, 2048],
334+
[4096, 4096, 4096],
335+
),
336+
"images_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
337+
# -----Text-----
338+
# 6*(vocab*hidden*2
339+
# + layer*(hidden*(q+k+v+o) + hidden*inter*3)
340+
# )*token_sum
341+
# + 12*sum(seqlen^2)*layer*hidden
342+
#
343+
# -----ViT-----
344+
# patch_embed_N =hidden*temporal_patch_size*in_channels* patch_size^2
345+
# attn_linear_N =hidden*(4*hidden)
346+
# mlp_N =hidden*inter*2
347+
# merger_N =((o+hidden*spatial_merge_size^2) * (hidden*spatial_merge_size^2))
348+
# deepstack_merger_N =merger_N * 3
349+
# dense_N =patch_embed_N + (attn_linear_N + mlp_N) * 27 + deepstack_merger_N + merger_N
350+
#
351+
# 6*(151936*4096*2
352+
# + 36*(4096*(4096+1024+1024+4096) + 4096*12288*3)
353+
# )*(512+1024+2048)
354+
# + 12*(512*512+1024*1024+2048*2048)*36*4096
355+
# + 6 * dense_N * (512 + 1024 + 2048)
356+
# + 12 * (512**2 + 1024**2 + 2048**2) * 27 * 16 * 72
357+
#
358+
# 6*(151936*4096*2
359+
# + 36*(4096*(4096+1024+1024+4096) + 4096*12288*3)
360+
# )*(4096+4096+4096)
361+
# + 12*(4096*4096+4096*4096+4096*4096)*36*4096
362+
# + 6 * dense_N * (4096 + 4096 + 4096)
363+
# + 12 * (4096**2 + 4096**2 + 4096**2) * 27 * 16 * 72
364+
"expected_flops_tuple": (
365+
200250312622080 / 1e12,
366+
753976643420160 / 1e12,
367+
),
368+
},
303369
}
304370

305371

306372
@pytest.mark.parametrize(
    "config_type",
    ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text", "apertus", "gpt_oss", "qwen3_vl"],
)
def test_flops_counter(config_type: str):
    """Check FlopsCounter against hand-derived FLOPs for each model config.

    VL configs additionally carry ``images_seqlens_tuple`` and exercise the
    ``images_seqlens`` keyword of ``estimate_flops``; text-only configs call
    it without that keyword, exactly as production callers do.
    """
    test_config = CONFIG[config_type]
    flops_counter = FlopsCounter(Config(test_config["config"]))

    has_images = "images_seqlens_tuple" in test_config
    # Pad with None so both variants share one loop; None means "text-only call".
    images_tuple = (
        test_config["images_seqlens_tuple"] if has_images else (None,) * len(test_config["expected_flops_tuple"])
    )
    for batch_seqlens, images_seqlens, expected_flops in zip(
        test_config["batch_seqlens_tuple"], images_tuple, test_config["expected_flops_tuple"], strict=True
    ):
        # set delta time to 1 to get the flops
        if images_seqlens is None:
            counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1)
        else:
            counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1, images_seqlens=images_seqlens)
        message = f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}"
        print(message)
        assert math.isclose(counted_flops, expected_flops), message

verl/experimental/agent_loop/agent_loop.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,10 @@ def _compute_multi_modal_inputs(self, output, input_ids) -> dict[str, torch.Tens
660660
# We must use dict(multi_modal_inputs) to convert BatchFeature values to a new dict
661661
# because np.array() only keeps the keys for BatchFeature.
662662
multi_modal_inputs = dict(multi_modal_inputs.convert_to_tensors("pt"))
663+
image_grid_thw = multi_modal_inputs.get("image_grid_thw")
664+
if image_grid_thw is not None:
665+
images_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0])
666+
multi_modal_inputs["images_seqlens"] = images_seqlens
663667
return multi_modal_inputs
664668

665669
def _compute_position_ids(self, input_ids, attention_mask, multi_modal_inputs) -> torch.Tensor:

verl/trainer/ppo/ray_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1449,7 +1449,13 @@ def fit(self):
14491449

14501450
# compute global_valid tokens
14511451
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
1452-
1452+
# get images_seqlens
1453+
images_seqlens_all = []
1454+
for multi_modal_input in batch.non_tensor_batch["multi_modal_inputs"]:
1455+
if "image_grid_thw" not in multi_modal_input.keys():
1456+
continue
1457+
images_seqlens_all.extend(multi_modal_input["images_seqlens"].tolist())
1458+
batch.meta_info["images_seqlens"] = images_seqlens_all
14531459
with marked_timer("reward", timing_raw, color="yellow"):
14541460
# compute reward model score
14551461
if self.use_rm and "rm_scores" not in batch.batch.keys():

verl/utils/flops_counter.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,99 @@ def _estimate_qwen2_flops(config, tokens_sum, batch_seqlens, delta_time):
118118
return flops_achieved
119119

120120

121+
def _estimate_qwen3_vl_flops(config, tokens_sum, batch_seqlens, delta_time, **kargs):
122+
# qwen3_vl uses text_config and vision_config to distinguish configs of different parts.
123+
hidden_size = config.text_config.hidden_size
124+
vocab_size = config.text_config.vocab_size
125+
num_hidden_layers = config.text_config.num_hidden_layers
126+
num_key_value_heads = config.text_config.num_key_value_heads
127+
num_attention_heads = config.text_config.num_attention_heads
128+
intermediate_size = config.text_config.intermediate_size
129+
130+
head_dim = hidden_size // num_attention_heads
131+
q_size = num_attention_heads * head_dim
132+
k_size = num_key_value_heads * head_dim
133+
v_size = num_key_value_heads * head_dim
134+
135+
# non-attn per layer parm
136+
mlp_N = hidden_size * intermediate_size * 3
137+
attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
138+
emd_and_lm_head_N = vocab_size * hidden_size * 2
139+
# non-attn all_layer parm
140+
dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
141+
# non-attn all_layer & all_token fwd & bwd flops
142+
dense_N_flops = 6 * dense_N * tokens_sum
143+
144+
# qwen3_vl uses deepstack to merge visual embeds and text embeds, but it has no tensor operation.
145+
146+
# attn all_layer & all_token fwd & bwd flops
147+
seqlen_square_sum = 0
148+
for seqlen in batch_seqlens:
149+
seqlen_square_sum += seqlen * seqlen
150+
attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
151+
152+
# vit flops
153+
images_seqlens = kargs.get("images_seqlens", None)
154+
if images_seqlens is not None:
155+
vit_flops = _estimate_qwen3_vit_flop(images_seqlens, config.vision_config)
156+
else:
157+
vit_flops = 0
158+
159+
# all_layer & all_token fwd & bwd flops
160+
flops_all_token = dense_N_flops + attn_qkv_flops + vit_flops
161+
flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
162+
return flops_achieved
163+
164+
165+
def _estimate_qwen3_vit_flop(images_seqlens, config):
166+
"""
167+
Estimate the FLOPS of the vision encoder for Qwen3-VL
168+
"""
169+
170+
if config is None:
171+
return 0
172+
tokens_sum = sum(images_seqlens)
173+
174+
num_heads = config.num_heads
175+
depth = config.depth
176+
177+
dim = config.hidden_size
178+
mlp_hidden_dim = config.intermediate_size
179+
out_hidden_size = config.out_hidden_size
180+
181+
spatial_merge_size = config.spatial_merge_size
182+
183+
head_dim = dim // num_heads
184+
185+
# every vision token's patch_embed comes from a conv of (C, T, H, W) -> (dim,)
186+
patch_embed_N = dim * config.in_channels * config.temporal_patch_size * config.patch_size * config.patch_size
187+
# Qwen3 VL vision mlp does not use GLU, thus 2.
188+
mlp_N = dim * mlp_hidden_dim * 2
189+
attn_linear_N = dim * (4 * dim) # qkv and output proj
190+
merger_N = (out_hidden_size + (dim * (spatial_merge_size**2))) * (dim * (spatial_merge_size**2))
191+
192+
# Qwen3 VL uses deep stack, one merger for every deepstack layer
193+
deepstack_merger_N = merger_N * len(config.deepstack_visual_indexes)
194+
# non-attn all_layer parm
195+
dense_N = patch_embed_N + (mlp_N + attn_linear_N) * depth + deepstack_merger_N + merger_N
196+
197+
# non-attn all_layer & all_token fwd & bwd flops
198+
dense_N_flops = 6 * dense_N * tokens_sum
199+
200+
# In Qwen3 VL, full attention is used in all vision layers.
201+
full_attn_layer_num = depth
202+
203+
# full attn layer & all_token fwd & bwd flops
204+
seqlen_square_sum = 0
205+
for seqlen in images_seqlens:
206+
seqlen_square_sum += seqlen * seqlen
207+
attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_heads * full_attn_layer_num
208+
209+
vit_flops = dense_N_flops + attn_qkv_flops
210+
211+
return vit_flops
212+
213+
121214
def _estimate_deepseek_v3_flops(config, tokens_sum, batch_seqlens, delta_time):
122215
hidden_size = config.hidden_size
123216
vocab_size = config.vocab_size
@@ -397,7 +490,7 @@ def _estimate_unknown_flops(config, tokens_sum, batch_seqlens, delta_time):
397490
"qwen2_5_vl": _estimate_qwen2_flops,
398491
"qwen3": _estimate_qwen2_flops,
399492
"qwen3_moe": _estimate_qwen2_moe_flops,
400-
"qwen3_vl": _estimate_qwen2_flops,
493+
"qwen3_vl": _estimate_qwen3_vl_flops,
401494
"qwen3_vl_moe": _estimate_qwen2_moe_flops,
402495
"deepseek_v3": _estimate_deepseek_v3_flops,
403496
"minicpmv": _estimate_qwen2_flops,
@@ -429,10 +522,10 @@ def __init__(self, config: PretrainedConfig):
429522
f"zero."
430523
)
431524

432-
self.config = getattr(config, "text_config", config)
525+
self.config = config
433526

434527
# TODO: actually we can make this a static method
435-
def estimate_flops(self, batch_seqlens, delta_time):
528+
def estimate_flops(self, batch_seqlens, delta_time, **kargs):
436529
"""
437530
Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
438531
@@ -447,6 +540,10 @@ def estimate_flops(self, batch_seqlens, delta_time):
447540
"""
448541
tokens_sum = sum(batch_seqlens)
449542
func = ESTIMATE_FUNC.get(self.config.model_type, _estimate_unknown_flops)
450-
estimated_flops = func(self.config, tokens_sum, batch_seqlens, delta_time)
543+
images_seqlens = kargs.get("images_seqlens", None)
544+
if images_seqlens is not None and "vl" in func.__name__:
545+
estimated_flops = func(self.config, tokens_sum, batch_seqlens, delta_time, **kargs)
546+
else:
547+
estimated_flops = func(self.config, tokens_sum, batch_seqlens, delta_time)
451548
promised_flops = get_device_flops()
452549
return estimated_flops, promised_flops

verl/workers/fsdp_workers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,10 @@ def update_actor(self, data: DataProto):
926926
metrics = self.actor.update_policy(data=data)
927927
delta_time = timer.last
928928
global_num_tokens = data.meta_info["global_token_num"]
929-
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
929+
images_seqlens = data.meta_info.get("images_seqlens", None)
930+
estimated_flops, promised_flops = self.flops_counter.estimate_flops(
931+
global_num_tokens, delta_time, images_seqlens=images_seqlens
932+
)
930933
metrics["perf/mfu/actor"] = (
931934
estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
932935
)

verl/workers/megatron_workers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,10 @@ def update_actor(self, data: DataProto):
738738
metrics = self.actor.update_policy(dataloader=dataloader)
739739
delta_time = timer.last
740740
global_num_tokens = data.meta_info["global_token_num"]
741+
images_seqlens = data.meta_info.get("images_seqlens", None)
742+
estimated_flops, promised_flops = self.flops_counter.estimate_flops(
743+
global_num_tokens, delta_time, images_seqlens=images_seqlens
744+
)
741745
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)  # BUG(review): this pre-existing call was left in place — it runs after the new images_seqlens-aware call added above and overwrites its result, so the ViT FLOPs never reach the megatron MFU metric (and estimate_flops is computed twice); this line should be deleted in this PR.
742746
metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
743747
metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3)

0 commit comments

Comments
 (0)