
Commit 4bf56c7

[Multimodal][torch.compile] Add compilation config field for turning off ViT/MM compile (vllm-project#28242)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
1 parent 59b453e commit 4bf56c7
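
For orientation, here is a minimal usage sketch of the new flag. It is not taken from the diff; it assumes the offline `LLM` entry point forwards a `compilation_config` dict the same way `vllm_runner` does in the new test below.

```python
from vllm import LLM
from vllm.config.compilation import CompilationMode

# Load Qwen2.5-VL with vLLM-level compilation, but keep the ViT/MM encoder eager.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    max_model_len=2048,
    compilation_config={
        "mode": CompilationMode.VLLM_COMPILE,
        "compile_mm_encoder": False,  # new field added by this commit
    },
)
```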

File tree: 4 files changed, +60 −3 lines

- tests/compile/test_multimodal_compile.py
- vllm/config/compilation.py
- vllm/model_executor/models/qwen2_5_vl.py
- vllm/model_executor/models/transformers/utils.py

tests/compile/test_multimodal_compile.py

Lines changed: 33 additions & 1 deletion
@@ -3,10 +3,17 @@
 import pytest
 
 from vllm.compilation.counter import compilation_counter
+from vllm.config import VllmConfig
 from vllm.config.compilation import CompilationMode
 from vllm.platforms import current_platform
 
 
+def test_compile():
+    vllm_config = VllmConfig()
+    # Default configuration compiles mm encoder
+    assert vllm_config.compilation_config.compile_mm_encoder
+
+
 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
 @pytest.mark.forked
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
@@ -31,8 +38,33 @@ def test_qwen2_5_vl_compilation(vllm_runner, monkeypatch):
         vllm_runner(
             "Qwen/Qwen2.5-VL-3B-Instruct",
             max_model_len=2048,
-            gpu_memory_utilization=0.7,
+            gpu_memory_utilization=0.8,
             compilation_config={"mode": CompilationMode.VLLM_COMPILE},
         ) as _,
     ):
         pass
+
+
+# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
+@pytest.mark.forked
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
+def test_qwen2_5_vl_no_vit_compilation(vllm_runner, monkeypatch):
+    """Test that Qwen2.5-VL vision submodules are not compiled when the
+    config is passed off
+    """
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
+    with (
+        compilation_counter.expect(num_models_seen=1),
+        vllm_runner(
+            "Qwen/Qwen2.5-VL-3B-Instruct",
+            max_model_len=2048,
+            gpu_memory_utilization=0.8,
+            compilation_config={
+                "mode": CompilationMode.VLLM_COMPILE,
+                "compile_mm_encoder": False,
+            },
+        ) as _,
+    ):
+        pass

vllm/config/compilation.py

Lines changed: 8 additions & 0 deletions
@@ -150,6 +150,7 @@ class CompilationConfig:
        - [`backend`][vllm.config.CompilationConfig.backend]
        - [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
        - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
+       - [`compile_mm_encoder`][vllm.config.CompilationConfig.compile_mm_encoder]
    - CudaGraph capture:
        - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
        - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
@@ -250,6 +251,13 @@ class CompilationConfig:
    disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
    Inductor generates (fused) Triton kernels for disabled custom ops."""
    splitting_ops: list[str] | None = None
+
+    """
+    Provide control over whether to compile the multimodal encoder
+    such as Qwen2_5_vl
+    """
+    compile_mm_encoder: bool = True
+
    """A list of ops to exclude from cudagraphs, used in piecewise compilation.

    The behavior depends on use_inductor_graph_partition:
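
A small sketch of exercising the new field programmatically. It assumes `CompilationConfig` accepts keyword arguments like vLLM's other config dataclasses; only the default value of `compile_mm_encoder` comes from the diff above.

```python
from vllm.config.compilation import CompilationConfig, CompilationMode

default_cfg = CompilationConfig()
assert default_cfg.compile_mm_encoder  # the MM encoder is compiled by default

no_vit_cfg = CompilationConfig(
    mode=CompilationMode.VLLM_COMPILE,
    compile_mm_encoder=False,  # opt out of ViT/MM encoder compilation
)
assert not no_vit_cfg.compile_mm_encoder
```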

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 8 additions & 2 deletions
@@ -67,6 +67,9 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.transformers.utils import (
+    should_torch_compile_mm_vit,
+)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.evs import (
     compute_mrope_for_media,
@@ -464,6 +467,7 @@ def forward(
         "seqlens": 0,
     },
     mark_unbacked_dims={"seqlens": 0},
+    enable_if=should_torch_compile_mm_vit,
 )
 class Qwen2_5_VisionBlock(nn.Module):
     def __init__(
@@ -529,7 +533,8 @@ def forward(
 @support_torch_compile(
     dynamic_arg_dims={
         "x": 0,
-    }
+    },
+    enable_if=should_torch_compile_mm_vit,
 )
 class Qwen2_5_VisionPatchEmbed(nn.Module):
     def __init__(
@@ -560,7 +565,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 @support_torch_compile(
     dynamic_arg_dims={
         "x": 0,
-    }
+    },
+    enable_if=should_torch_compile_mm_vit,
 )
 class Qwen2_5_VisionPatchMerger(nn.Module):
     def __init__(
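
The `enable_if` argument gates compilation of each vision submodule on a predicate over the config. Below is a simplified, standalone illustration of that pattern; it is not vLLM's `support_torch_compile` implementation, and `compile_if`, `DummyConfig`, and `TinyBlock` are invented for the sketch.

```python
from typing import Callable

import torch
from torch import nn


def compile_if(predicate: Callable[[object], bool], config: object):
    """Class decorator: wrap forward in torch.compile only if predicate(config) is True."""

    def decorator(cls):
        if not predicate(config):
            return cls  # predicate says no: leave the module eager

        orig_init = cls.__init__

        def patched_init(self, *args, **kwargs):
            orig_init(self, *args, **kwargs)
            # Replace this instance's forward with a lazily-compiled version.
            self.forward = torch.compile(self.forward)

        cls.__init__ = patched_init
        return cls

    return decorator


class DummyConfig:
    compile_mm_encoder = True  # flip to False to keep TinyBlock eager


@compile_if(lambda cfg: cfg.compile_mm_encoder, DummyConfig())
class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))
```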

vllm/model_executor/models/transformers/utils.py

Lines changed: 11 additions & 0 deletions
@@ -205,3 +205,14 @@ def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
     # Dynamic rope scaling is not compatible with torch.compile
     rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
     return rope_scaling.get("rope_type") != "dynamic"
+
+
+def should_torch_compile_mm_vit(vllm_config: "VllmConfig") -> bool:
+    """
+    Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
+
+    Defaults to `True` but is disabled in the following situations:
+
+    - `compile_mm_encoder` is set to `False` in the compilation config.
+    """
+    return vllm_config.compilation_config.compile_mm_encoder
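
A quick sanity check of the helper, assuming `VllmConfig`'s default constructor (as used in the new test) and that the config object remains mutable after construction.

```python
from vllm.config import VllmConfig
from vllm.model_executor.models.transformers.utils import should_torch_compile_mm_vit

cfg = VllmConfig()
assert should_torch_compile_mm_vit(cfg)  # compile_mm_encoder defaults to True

cfg.compilation_config.compile_mm_encoder = False
assert not should_torch_compile_mm_vit(cfg)  # the ViT/MM encoder stays eager
```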
