Commit 1230263

[Bugfix] Fix InternVL2 vision embeddings process with pipeline parallel (vllm-project#8299)

1 parent: e497b8a

This commit gates InternVL2's vision-embedding merge on the first pipeline-parallel rank and extends the pipeline-parallel CI tests to cover the InternVL2 models.

2 files changed: +10 -3 lines


tests/distributed/test_pipeline_parallel.py (8 additions, 2 deletions)

```diff
@@ -32,7 +32,9 @@
         (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
         (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
         (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
+        (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"),
+        (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "ray"),
+        (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"),
     ],
 )
 @fork_new_process_for_each_test
@@ -46,6 +48,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "float16",
+        "--max-model-len",
+        "8192",
         "--pipeline-parallel-size",
         str(PP_SIZE),
         "--tensor-parallel-size",
@@ -62,7 +66,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
     tp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",
+        "float16",
+        "--max-model-len",
+        "8192",
         "--tensor-parallel-size",
         str(max(TP_SIZE, 2)),  # We only use 2 GPUs in the CI.
         "--distributed-executor-backend",
```

vllm/model_executor/models/internvl.py (2 additions, 1 deletion)

```diff
@@ -17,6 +17,7 @@

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
+from vllm.distributed import get_pp_group
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -480,7 +481,7 @@ def forward(
         **kwargs: object,
     ) -> SamplerOutput:
         image_input = self._parse_and_validate_image_input(**kwargs)
-        if image_input is not None:
+        if image_input is not None and get_pp_group().is_first_rank:
             inputs_embeds = self.language_model.model.get_input_embeddings(
                 input_ids)
             vision_embeddings = self._process_image_input(image_input)
```
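Why the get_pp_group().is_first_rank guard matters: under pipeline parallelism, only the first stage receives token IDs and owns the embedding lookup; downstream stages are handed hidden states from the previous stage, so running the vision-embedding path there would operate on inputs those ranks never receive. The toy sketch below (plain Python with stand-in helpers, not vLLM's actual API) illustrates the control flow the commit fixes.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PPGroup:
    """Stand-in for vLLM's pipeline-parallel process group."""
    rank: int
    world_size: int

    @property
    def is_first_rank(self) -> bool:
        return self.rank == 0

def embed_tokens(input_ids: List[int]) -> List[float]:
    # Stand-in for language_model.model.get_input_embeddings(input_ids).
    return [float(t) for t in input_ids]

def merge_vision(text_embeds: List[float],
                 vision_embeds: List[float]) -> List[float]:
    # Stand-in for splicing projected image features into text embeddings.
    return vision_embeds + text_embeds

def run_stage(pp: PPGroup,
              input_ids: Optional[List[int]] = None,
              image_input: Optional[List[float]] = None,
              hidden_from_prev: Optional[List[float]] = None) -> List[float]:
    # The commit adds the `pp.is_first_rank` condition: without it, a
    # non-first stage would try to embed token IDs it never receives.
    if image_input is not None and pp.is_first_rank:
        hidden = merge_vision(embed_tokens(input_ids), image_input)
    elif pp.is_first_rank:
        hidden = embed_tokens(input_ids)
    else:
        hidden = hidden_from_prev
    return [h + 1.0 for h in hidden]  # stand-in for this stage's layers

# First stage embeds and merges; the second stage continues from its output.
stage0 = run_stage(PPGroup(0, 2), input_ids=[1, 2], image_input=[9.0])
stage1 = run_stage(PPGroup(1, 2), hidden_from_prev=stage0)
```

In the real model, the merged inputs_embeds are forwarded in place of input_ids, which is why only the rank that actually performs the embedding lookup should do the merge.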
