
Commit 8ca5051

[Misc] Use NamedTuple in Multi-image example (vllm-project#8705)
Signed-off-by: Alex-Brooks <[email protected]>
1 parent 06ed281 commit 8ca5051
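
For context, the change swaps each loader's positional 5-tuple return for a typed NamedTuple. A minimal, self-contained sketch of that pattern (with illustrative names, not the example's real ones):

from typing import List, NamedTuple, Optional


# Before: each loader returned a bare tuple, so callers had to remember
# the position and meaning of every slot.
def load_old():
    return "prompt text", None  # (prompt, stop_token_ids)


# After: a NamedTuple names and types each slot.
class RequestData(NamedTuple):
    prompt: str
    stop_token_ids: Optional[List[int]]


def load_new() -> RequestData:
    return RequestData(prompt="prompt text", stop_token_ids=None)


req = load_new()
print(req.prompt, req.stop_token_ids)  # named access at the call site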

1 file changed: +52 -22 lines changed

examples/offline_inference_vision_language_multi_image.py

Lines changed: 52 additions & 22 deletions
@@ -4,8 +4,9 @@
 by the model.
 """
 from argparse import Namespace
-from typing import List
+from typing import List, NamedTuple, Optional
 
+from PIL.Image import Image
 from transformers import AutoProcessor, AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -19,7 +20,15 @@
 ]
 
 
-def load_qwenvl_chat(question: str, image_urls: List[str]):
+class ModelRequestData(NamedTuple):
+    llm: LLM
+    prompt: str
+    stop_token_ids: Optional[List[str]]
+    image_data: List[Image]
+    chat_template: Optional[str]
+
+
+def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
     model_name = "Qwen/Qwen-VL-Chat"
     llm = LLM(
         model=model_name,
@@ -48,10 +57,16 @@ def load_qwenvl_chat(question: str, image_urls: List[str]):
 
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
-    return llm, prompt, stop_token_ids, None, chat_template
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=chat_template,
+    )
 
 
-def load_phi3v(question: str, image_urls: List[str]):
+def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
@@ -62,10 +77,17 @@ def load_phi3v(question: str, image_urls: List[str]):
                              for i, _ in enumerate(image_urls, start=1))
     prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
     stop_token_ids = None
-    return llm, prompt, stop_token_ids, None, None
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
 
 
-def load_internvl(question: str, image_urls: List[str]):
+def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
     model_name = "OpenGVLab/InternVL2-2B"
 
     llm = LLM(
@@ -93,10 +115,16 @@ def load_internvl(question: str, image_urls: List[str]):
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 
-    return llm, prompt, stop_token_ids, None, None
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
 
 
-def load_qwen2_vl(question, image_urls: List[str]):
+def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     try:
         from qwen_vl_utils import process_vision_info
     except ModuleNotFoundError:
@@ -143,7 +171,13 @@ def load_qwen2_vl(question, image_urls: List[str]):
     else:
         image_data, _ = process_vision_info(messages)
 
-    return llm, prompt, stop_token_ids, image_data, None
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=image_data,
+        chat_template=None,
+    )
 
 
 model_example_map = {
@@ -155,20 +189,17 @@ def load_qwen2_vl(question, image_urls: List[str]):
 
 
 def run_generate(model, question: str, image_urls: List[str]):
-    llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
-        question, image_urls)
-    if image_data is None:
-        image_data = [fetch_image(url) for url in image_urls]
+    req_data = model_example_map[model](question, image_urls)
 
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
-                                     stop_token_ids=stop_token_ids)
+                                     stop_token_ids=req_data.stop_token_ids)
 
-    outputs = llm.generate(
+    outputs = req_data.llm.generate(
         {
-            "prompt": prompt,
+            "prompt": req_data.prompt,
             "multi_modal_data": {
-                "image": image_data
+                "image": req_data.image_data
             },
         },
         sampling_params=sampling_params)
@@ -179,13 +210,12 @@ def run_generate(model, question: str, image_urls: List[str]):
 
 
 def run_chat(model: str, question: str, image_urls: List[str]):
-    llm, _, stop_token_ids, _, chat_template = model_example_map[model](
-        question, image_urls)
+    req_data = model_example_map[model](question, image_urls)
 
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
-                                     stop_token_ids=stop_token_ids)
-    outputs = llm.chat(
+                                     stop_token_ids=req_data.stop_token_ids)
+    outputs = req_data.llm.chat(
         [{
             "role":
             "user",
@@ -203,7 +233,7 @@ def run_chat(model: str, question: str, image_urls: List[str]):
             ],
         }],
         sampling_params=sampling_params,
-        chat_template=chat_template,
+        chat_template=req_data.chat_template,
     )
 
     for o in outputs:
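
A side note on why NamedTuple is a low-risk refactor here: instances are still real tuples, so any positional unpacking keeps working, and helpers like _replace return modified copies. A small standalone sketch (generic Python behavior, not code from this commit):

from typing import List, NamedTuple, Optional


class Req(NamedTuple):
    prompt: str
    stop_token_ids: Optional[List[int]]


req = Req(prompt="hi", stop_token_ids=None)

prompt, stop_ids = req  # still a tuple: positional unpacking works
updated = req._replace(stop_token_ids=[2])  # copy with one field changed
print(updated.stop_token_ids)  # -> [2]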
