 from unittest.mock import Mock

 import pytest
+import torch

-from vllm import LLM
+from vllm import LLM, envs
 from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1

-from ..conftest import VllmRunner
+from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test

@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
     assert weak_llm() is None


+def _fix_prompt_embed_outputs(
+        vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
+        example_prompts: list[str]) -> list[tuple[list[int], str]]:
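+    # Splice the HF-tokenized prompt ids over the prompt portion of each
+    # vLLM output and prepend the prompt text, so outputs generated from
+    # prompt embeddings line up with the format of hf_model.generate_greedy.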
+    fixed_vllm_outputs = []
+    for vllm_output, hf_input, prompt in zip(
+            vllm_outputs, hf_model.get_inputs(example_prompts),
+            example_prompts):
+        hf_input_ids = hf_input["input_ids"].tolist()[0]
+        fixed_vllm_outputs.append(
+            (hf_input_ids + vllm_output[0][len(hf_input_ids):],
+             prompt + vllm_output[1]))
+    return fixed_vllm_outputs
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -56,8 +72,13 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    enable_prompt_embeds: bool,
 ) -> None:

+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")

@@ -78,14 +99,25 @@ def test_models(

         with hf_runner(model, dtype=dtype) as hf_model:
             hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
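+            # For the prompt-embeds case, also compute the prompt embeddings
+            # with the HF model; they are fed to vLLM below.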
+            if enable_prompt_embeds:
+                with torch.no_grad():
+                    prompt_embeds = hf_model.get_prompt_embeddings(
+                        example_prompts)

         with VllmRunner(model,
                         max_model_len=8192,
                         dtype=dtype,
                         enforce_eager=enforce_eager,
+                        enable_prompt_embeds=enable_prompt_embeds,
                         gpu_memory_utilization=0.7) as vllm_model:
-            vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                      max_tokens)
+            if enable_prompt_embeds:
+                vllm_outputs = vllm_model.generate_greedy(
+                    prompt_embeds, max_tokens)
+                vllm_outputs = _fix_prompt_embed_outputs(
+                    vllm_outputs, hf_model, example_prompts)
+            else:
+                vllm_outputs = vllm_model.generate_greedy(
+                    example_prompts, max_tokens)

         check_outputs_equal(
             outputs_0_lst=hf_outputs,
@@ -108,6 +140,7 @@ def test_models(
         ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
         ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
     ])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models_distributed(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -117,14 +150,22 @@ def test_models_distributed(
     distributed_executor_backend: str,
     attention_backend: str,
     test_suite: str,
+    enable_prompt_embeds: bool,
 ) -> None:

+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if test_suite != TARGET_TEST_SUITE:
         pytest.skip(f"Skip test for {test_suite}")

     with monkeypatch.context() as monkeypatch_context:
         if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4":  # noqa
-            # test Ray Compiled Graph
+            if enable_prompt_embeds:
+                pytest.skip(
+                    "enable_prompt_embeds does not work with ray compiled dag."
+                )
             monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
             monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")

@@ -147,12 +188,26 @@ def test_models_distributed(
                 dtype=dtype,
                 tensor_parallel_size=2,
                 distributed_executor_backend=distributed_executor_backend,
+                enable_prompt_embeds=enable_prompt_embeds,
+                gpu_memory_utilization=0.7,
         ) as vllm_model:
-            vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                      max_tokens)
-
-        with hf_runner(model, dtype=dtype) as hf_model:
-            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
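+            # With prompt embeds enabled, the HF model stays open for the
+            # whole branch: it supplies the embeddings fed to vLLM, the
+            # tokenized prompts used to fix up the vLLM outputs, and the
+            # HF baseline outputs.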
+            if enable_prompt_embeds:
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    with torch.no_grad():
+                        prompt_embeds = hf_model.get_prompt_embeddings(
+                            example_prompts)
+                    vllm_outputs = vllm_model.generate_greedy(
+                        prompt_embeds, max_tokens)
+                    vllm_outputs = _fix_prompt_embed_outputs(
+                        vllm_outputs, hf_model, example_prompts)
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)
+            else:
+                vllm_outputs = vllm_model.generate_greedy(
+                    example_prompts, max_tokens)
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)

         check_outputs_equal(
             outputs_0_lst=hf_outputs,