
Commit a79e522

[Model] Support pp for qwen2-vl (vllm-project#8696)
1 parent: 3e83c12

4 files changed, 46 insertions(+), 14 deletions(-)

tests/distributed/test_pipeline_parallel.py
Lines changed: 8 additions & 0 deletions

@@ -8,6 +8,8 @@
 import os
 
 import pytest
+from packaging import version
+from transformers import __version__ as transformers_version
 
 from vllm.logger import init_logger
 
@@ -37,6 +39,7 @@
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
         (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
+        (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
     ],
 )
 @fork_new_process_for_each_test
@@ -46,6 +49,11 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
 
+    # Skip tests that require transformers>=4.45.0
+    if "Qwen2-VL" in MODEL_NAME and version.parse(
+            transformers_version) < version.parse("4.45.0.dev0"):
+        pytest.skip("This test requires transformers>=4.45.0")
+
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
vllm/config.py
Lines changed: 1 addition & 0 deletions

@@ -51,6 +51,7 @@
     "Qwen2ForCausalLM",
     "Qwen2MoeForCausalLM",
     "QWenLMHeadModel",
+    "Qwen2VLForConditionalGeneration",
 ]
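
This list gates which architectures may be loaded with pipeline_parallel_size > 1. A simplified sketch of how such an allowlist check can be wired up; the helper name and error message here are illustrative, not the actual vllm/config.py code:

    _PP_SUPPORTED_MODELS = [
        # ... existing entries ...
        "Qwen2VLForConditionalGeneration",
    ]

    def check_pp_supported(architectures, pipeline_parallel_size):
        # Hypothetical helper: reject PP > 1 for architectures not on the list.
        if pipeline_parallel_size > 1 and not any(
                arch in _PP_SUPPORTED_MODELS for arch in architectures):
            raise NotImplementedError(
                f"Pipeline parallelism is not supported for {architectures}.")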

vllm/model_executor/models/qwen2.py
Lines changed: 15 additions & 7 deletions

@@ -49,7 +49,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 
 
 class Qwen2MLP(nn.Module):
@@ -235,11 +235,16 @@ def __init__(
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-        )
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Qwen2DecoderLayer(config=config,
@@ -248,7 +253,10 @@ def __init__(
             prefix=f"{prefix}.layers",
         )
 
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
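
For readers unfamiliar with PPMissingLayer: it is a placeholder for modules owned by a different pipeline rank, so each rank only materializes the weights it actually runs. A minimal sketch of the idea, assuming simple pass-through semantics (the real helper lives in vllm's model utils):

    import torch.nn as nn

    class PPMissingLayer(nn.Module):
        # Stand-in for a layer that lives on another pipeline rank: it
        # holds no parameters and passes its input through unchanged.
        def forward(self, *args, **kwargs):
            return args[0]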

vllm/model_executor/models/qwen2_vl.py
Lines changed: 22 additions & 7 deletions

@@ -45,7 +45,7 @@
 from vllm.attention.selector import (_Backend, backend_name_to_enum,
                                      get_global_forced_attn_backend)
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
@@ -68,6 +68,9 @@
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.processor import get_processor
 
+from .utils import (PPMissingLayer, is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory)
+
 logger = init_logger(__name__)
 
 # === Vision Inputs === #
@@ -856,15 +859,21 @@ def __init__(self,
 
         self.model = Qwen2Model(config, cache_config, quant_config)
 
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config)
         else:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
+            self.lm_head = PPMissingLayer()
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
 
     def _validate_and_reshape_mm_tensor(self,
                                         mm_input: Union[torch.Tensor,
@@ -979,7 +988,8 @@ def forward(
         image_input = self._parse_and_validate_image_input(**kwargs)
         video_input = self._parse_and_validate_video_input(**kwargs)
 
-        if image_input is None and video_input is None:
+        if (image_input is None
+                and video_input is None) or not get_pp_group().is_first_rank:
             inputs_embeds = None
         else:
             if getattr(self.config, "rope_scaling", {}).get("type",
@@ -1015,6 +1025,7 @@ def forward(
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
+            intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
         return hidden_states
@@ -1055,6 +1066,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -1081,6 +1094,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
             except KeyError:
                 print(params_dict.keys())
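
The make_empty_intermediate_tensors hook tells the engine what a non-final rank forwards to its successor; here Qwen2-VL reuses the standard "hidden_states"/"residual" pair. A rough sketch of what such a factory can look like (the real vllm helper wraps the result in an IntermediateTensors object rather than returning a plain dict):

    from typing import List

    import torch

    def make_empty_intermediate_tensors_factory(keys: List[str],
                                                hidden_size: int):
        # Each non-final pipeline rank hands buffers of shape
        # (batch_size, hidden_size) to the next rank; this factory builds
        # zero-filled placeholders used for profiling and scheduling.
        def make_empty_intermediate_tensors(batch_size: int,
                                            dtype: torch.dtype,
                                            device: torch.device):
            return {
                key: torch.zeros((batch_size, hidden_size),
                                 dtype=dtype,
                                 device=device)
                for key in keys
            }

        return make_empty_intermediate_tensors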
