
Commit 9609327

[Core] [Bugfix]: tensor parallel with prompt embeds (vllm-project#18171)

Fixes generation from prompt embeddings when tensor parallelism is enabled: sampled token ids are now broadcast from the driver worker to all workers before the sampled-token embeddings are computed, and the basic-correctness tests gain an enable_prompt_embeds parametrization (single GPU and tensor_parallel_size=2).

Signed-off-by: Nan2018 <[email protected]>
Co-authored-by: Andrew Sansom <[email protected]>

1 parent f07a673 · commit 9609327

File tree: 4 files changed, +136 −62 lines

tests/basic_correctness/test_basic_correctness.py
tests/conftest.py
vllm/sequence.py
vllm/worker/model_runner.py

tests/basic_correctness/test_basic_correctness.py

Lines changed: 65 additions & 10 deletions
@@ -8,12 +8,13 @@
 from unittest.mock import Mock

 import pytest
+import torch

-from vllm import LLM
+from vllm import LLM, envs
 from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1

-from ..conftest import VllmRunner
+from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test

@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
     assert weak_llm() is None


+def _fix_prompt_embed_outputs(
+        vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
+        example_prompts: list[str]) -> list[tuple[list[int], str]]:
+    fixed_vllm_outputs = []
+    for vllm_output, hf_input, prompt in zip(
+            vllm_outputs, hf_model.get_inputs(example_prompts),
+            example_prompts):
+        hf_input_ids = hf_input["input_ids"].tolist()[0]
+        fixed_vllm_outputs.append(
+            (hf_input_ids + vllm_output[0][len(hf_input_ids):],
+             prompt + vllm_output[1]))
+    return fixed_vllm_outputs
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -56,8 +72,13 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    enable_prompt_embeds: bool,
 ) -> None:

+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")

@@ -78,14 +99,25 @@ def test_models(

     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with torch.no_grad():
+                prompt_embeds = hf_model.get_prompt_embeddings(
+                    example_prompts)

     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
+                    enable_prompt_embeds=enable_prompt_embeds,
                     gpu_memory_utilization=0.7) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens)
+        if enable_prompt_embeds:
+            vllm_outputs = vllm_model.generate_greedy(
+                prompt_embeds, max_tokens)
+            vllm_outputs = _fix_prompt_embed_outputs(
+                vllm_outputs, hf_model, example_prompts)
+        else:
+            vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)

     check_outputs_equal(
         outputs_0_lst=hf_outputs,
@@ -108,6 +140,7 @@ def test_models(
         ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
         ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
     ])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models_distributed(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -117,14 +150,22 @@ def test_models_distributed(
     distributed_executor_backend: str,
     attention_backend: str,
     test_suite: str,
+    enable_prompt_embeds: bool,
 ) -> None:

+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if test_suite != TARGET_TEST_SUITE:
         pytest.skip(f"Skip test for {test_suite}")

     with monkeypatch.context() as monkeypatch_context:
         if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4":  # noqa
-            # test Ray Compiled Graph
+            if enable_prompt_embeds:
+                pytest.skip(
+                    "enable_prompt_embeds does not work with ray compiled dag."
+                )
             monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
             monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")

@@ -147,12 +188,26 @@ def test_models_distributed(
                 dtype=dtype,
                 tensor_parallel_size=2,
                 distributed_executor_backend=distributed_executor_backend,
+                enable_prompt_embeds=enable_prompt_embeds,
+                gpu_memory_utilization=0.7,
         ) as vllm_model:
-            vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                      max_tokens)
-
-        with hf_runner(model, dtype=dtype) as hf_model:
-            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+            if enable_prompt_embeds:
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    with torch.no_grad():
+                        prompt_embeds = hf_model.get_prompt_embeddings(
+                            example_prompts)
+                    vllm_outputs = vllm_model.generate_greedy(
+                        prompt_embeds, max_tokens)
+                    vllm_outputs = _fix_prompt_embed_outputs(
+                        vllm_outputs, hf_model, example_prompts)
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)
+            else:
+                vllm_outputs = vllm_model.generate_greedy(
+                    example_prompts, max_tokens)
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)

         check_outputs_equal(
             outputs_0_lst=hf_outputs,
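What the new enable_prompt_embeds=True path boils down to, outside the test harness: embed the prompt with the Hugging Face model's input embedding layer and hand the resulting tensor to vLLM instead of token ids. Because an embeds prompt carries no token ids, _fix_prompt_embed_outputs above splices the HF tokenizer's input ids (and the prompt text) back onto vLLM's outputs so that check_outputs_equal compares like with like. The following is a minimal standalone sketch, not part of the commit; it assumes the {"prompt_embeds": ...} prompt form that accompanies enable_prompt_embeds=True and picks distilbert/distilgpt2 (one of the models in the distributed test matrix) purely for illustration.

# Minimal sketch (assumptions noted above): prompt-embeds generation with
# tensor parallelism, mirroring what test_models_distributed now exercises.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "distilbert/distilgpt2"  # illustrative model choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    # Same trick as HfRunner.get_prompt_embeddings: push the token ids
    # through the model's input embedding layer.
    prompt_embeds = hf_model.get_input_embeddings()(input_ids).squeeze(0)

llm = LLM(model=model_name,
          enable_prompt_embeds=True,
          tensor_parallel_size=2)  # the combination this commit fixes
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)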

tests/conftest.py

Lines changed: 9 additions & 0 deletions
@@ -430,6 +430,15 @@ def get_inputs(

         return all_inputs

+    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
+        all_inputs = self.get_inputs(prompts)
+        embeddings = []
+        for inputs in all_inputs:
+            input_ids = self.wrap_device(inputs)["input_ids"]
+            embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
+            embeddings.append(embedding)
+        return embeddings
+
     def classify(self, prompts: list[str]) -> list[str]:
         # output is final logits
         all_inputs = self.get_inputs(prompts)
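Each tensor returned by get_prompt_embeddings is simply the tokenized prompt pushed through the model's input embedding layer, i.e. one [seq_len, hidden_size] tensor per prompt. A standalone sketch with plain transformers (illustrative only; distilbert/distilgpt2 is just an example model):

# What get_prompt_embeddings returns for a single prompt, sketched with
# plain Hugging Face transformers.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")

input_ids = tok("The capital of France is", return_tensors="pt").input_ids
with torch.no_grad():
    # get_input_embeddings() returns the token-embedding module; calling it
    # on the ids gives [1, seq_len, hidden_size], squeezed to 2-D here.
    embeds = model.get_input_embeddings()(input_ids).squeeze(0)
print(embeds.shape)  # e.g. torch.Size([5, 768]) for distilgpt2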

vllm/sequence.py

Lines changed: 7 additions & 7 deletions
(The two docstring hunks below only strip trailing whitespace, which is why the removed and added lines look identical; the __repr__ hunk adds the missing ", " separator before logprobs.)

@@ -112,12 +112,12 @@ class RequestMetrics:
                           will include model forward, block/sync across
                           workers, cpu-gpu sync time and sampling time.
         spec_token_acceptance_counts: number of accepted speculative tokens at
-                                      each position; the first token is from 
+                                      each position; the first token is from
                                       the target model and is always accepted;
-                                      e.g., when it's [10, 8, 4, 2] for a req, 
+                                      e.g., when it's [10, 8, 4, 2] for a req,
                                       it means there were 10 forward passes in
-                                      total, and there were 8, 4, 2 accepted 
-                                      tokens at 1st, 2nd, 3rd speculation step. 
+                                      total, and there were 8, 4, 2 accepted
+                                      tokens at 1st, 2nd, 3rd speculation step.
     """
     arrival_time: float
     last_token_time: float
@@ -714,9 +714,9 @@ class SequenceGroup:
         trace_headers: OpenTelemetry trace headers.
         prompt_adapter_request: Prompt Adapter request.
         priority: User-defined priority of the request.
-        draft_size: The number of speculative tokens plus one from the target 
+        draft_size: The number of speculative tokens plus one from the target
                     model; equal to max number of tokens a step can generate
-                    for single-draft speculative decoding but larger than 
+                    for single-draft speculative decoding but larger than
                     that for multi-draft SD (currently not supported).
     """

@@ -1123,7 +1123,7 @@ def __repr__(self) -> str:
             self.output_embed.shape if self.output_embed is not None else None
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
-                f"output_embed.shape={output_embed_shape}"
+                f"output_embed.shape={output_embed_shape}, "
                 f"logprobs={self.logprobs})")

     def __eq__(self, other: object) -> bool:

vllm/worker/model_runner.py

Lines changed: 55 additions & 45 deletions
@@ -23,7 +23,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.distributed.kv_transfer import get_kv_transfer_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              graph_capture)
@@ -872,23 +872,23 @@ def build(self) -> ModelInputForGPU:
         """
         # Combine and flatten intermediate data.
         input_tokens = list[int]()
-        inputs_embeds_lst = list[torch.Tensor]()
+        inputs_embeds_list = list[torch.Tensor]()
         token_types = list[int]()
         for inter_data in self.inter_data_list:
             for cur_input_tokens in inter_data.input_tokens:
                 input_tokens.extend(cur_input_tokens)
             for cur_token_types in inter_data.token_types:
                 token_types.extend(cur_token_types)
             if inter_data.inputs_embeds is not None:
-                inputs_embeds_lst.append(
+                inputs_embeds_list.append(
                     inter_data.inputs_embeds.to(
                         dtype=self.runner.model_config.dtype,
                         device=self.runner.device))
         inputs_embeds: Optional[torch.Tensor]
-        if len(inputs_embeds_lst) == 0:
+        if len(inputs_embeds_list) == 0:
             inputs_embeds = None
         else:
-            inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
+            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
                 dtype=self.runner.model_config.dtype,
                 device=self.runner.device)
             assert len(inputs_embeds) == len(input_tokens)
@@ -1893,50 +1893,60 @@ def execute_model(
         logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)

-        if not self.is_driver_worker:
-            return []
+        if self.is_driver_worker:
+            if model_input.async_callback is not None:
+                model_input.async_callback()

-        if model_input.async_callback is not None:
-            model_input.async_callback()
+            # Sample the next token.
+            assert isinstance(self.sampler, Sampler)
+            orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
+            if model_input.inputs_embeds is not None:
+                self.sampler.include_gpu_probs_tensor = True

-        # Sample the next token.
-        assert isinstance(self.sampler, Sampler)
-        orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
-        if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = True
-
-        output: SamplerOutput = self.sampler(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time
-                and output is not None):
-            model_forward_end.synchronize()
-            model_forward_time = model_forward_start.elapsed_time(
-                model_forward_end)
-            orig_model_forward_time = 0.0
-            if intermediate_tensors is not None:
-                orig_model_forward_time = intermediate_tensors.tensors.get(
-                    "model_forward_time", torch.tensor(0.0)).item()
-            # If there are multiple workers, we are still tracking the latency
-            # from the start time of the driver worker to the end time of the
-            # driver worker. The model forward time will then end up covering
-            # the communication time as well.
-            output.model_forward_time = (orig_model_forward_time +
-                                         model_forward_time)
+            output: SamplerOutput = self.sampler(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+            if (self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time
+                    and output is not None):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                # If there are multiple workers, we are still tracking the
+                # latency from the start time of the driver worker to the end
+                # time of the driver worker. The model forward time will then
+                # end up covering the communication time as well.
+                output.model_forward_time = (orig_model_forward_time +
+                                             model_forward_time)

         if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = \
-                orig_include_gpu_probs_tensor
-            if output.sampled_token_ids is not None:
-                output.sampled_token_embeds = self.model.get_input_embeddings(
-                    output.sampled_token_ids.squeeze(1))
-
-                for token_embed, sequence_group_output in zip(
-                        output.sampled_token_embeds, output.outputs):
-                    assert len(sequence_group_output.samples) == 1
-                    sequence_group_output.samples[0].output_embed = token_embed
+            if self.is_driver_worker:
+                sampled = broadcast_tensor_dict(
+                    {"token_ids": output.sampled_token_ids})
+            else:
+                sampled = broadcast_tensor_dict()
+            if sampled["token_ids"] is not None:
+                sampled_token_embeds = self.model.get_input_embeddings(
+                    sampled["token_ids"].squeeze(1))
+                if self.is_driver_worker:
+                    self.sampler.include_gpu_probs_tensor = \
+                        orig_include_gpu_probs
+
+                    output.sampled_token_embeds = sampled_token_embeds
+
+                    for token_embed, sequence_group_output in zip(
+                            output.sampled_token_embeds, output.outputs):
+                        assert len(sequence_group_output.samples) == 1
+                        sequence_group_output.samples[
+                            0].output_embed = token_embed
+
+        if not self.is_driver_worker:
+            return []

         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
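The substantive fix is the execute_model restructuring above. Previously non-driver workers hit return [] before the embedding lookup, so with prompt embeds enabled only the driver called self.model.get_input_embeddings() on the sampled token ids; under tensor parallelism that lookup is a vocab-parallel embedding whose forward pass involves collective communication, so all TP ranks need to run it together. The new code keeps every worker alive through this section: the driver broadcasts the sampled token ids with broadcast_tensor_dict, every rank computes the sampled-token embeddings, and only then do non-driver workers return []. The toy script below illustrates the broadcast-then-embed pattern with plain torch.distributed; it is a standalone sketch, not vLLM code, and the file name, ordinary nn.Embedding, and gloo backend are stand-ins chosen for illustration.

# Toy illustration of the pattern used in the fix: the driver rank produces
# token ids, broadcasts them, and then every rank runs the embedding lookup,
# so any collective op inside the embedding sees all ranks.
# Run with, e.g.:  torchrun --nproc_per_node=2 toy_broadcast_embeds.py
import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    vocab_size, hidden_size = 100, 16
    torch.manual_seed(0)  # identical embedding weights on every rank
    embedding = torch.nn.Embedding(vocab_size, hidden_size)

    if rank == 0:
        # "Driver": pretend these ids came from the sampler.
        token_ids = torch.randint(0, vocab_size, (4, 1))
    else:
        token_ids = torch.empty((4, 1), dtype=torch.int64)
    # Analogue of broadcast_tensor_dict({"token_ids": ...}): ship the sampled
    # ids from the driver to all other workers.
    dist.broadcast(token_ids, src=0)

    # Now every rank can run the (in vLLM, vocab-parallel) embedding lookup.
    token_embeds = embedding(token_ids.squeeze(1))
    print(f"rank {rank}: sampled-token embeds shape {tuple(token_embeds.shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()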
