@@ -4,8 +4,9 @@
 by the model.
 """
 from argparse import Namespace
-from typing import List
+from typing import List, NamedTuple, Optional
 
+from PIL.Image import Image
 from transformers import AutoProcessor, AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -19,7 +20,15 @@
 ]
 
 
-def load_qwenvl_chat(question: str, image_urls: List[str]):
+class ModelRequestData(NamedTuple):
+    llm: LLM
+    prompt: str
+    stop_token_ids: Optional[List[int]]
+    image_data: List[Image]
+    chat_template: Optional[str]
+
+
+def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
     model_name = "Qwen/Qwen-VL-Chat"
     llm = LLM(
         model=model_name,
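Reviewer note: the payoff of the new NamedTuple is at the call sites, which previously had to unpack a five-element tuple positionally and discard unused slots with `_`. A minimal standalone sketch of the pattern (the field names mirror the diff; the values are made up):

from typing import List, NamedTuple, Optional


class ModelRequestData(NamedTuple):
    prompt: str
    stop_token_ids: Optional[List[int]]


# Before: positional unpacking, where every caller must remember the
# slot order and discard unused fields.
prompt, stop_token_ids = "<|user|>Hi", None

# After: fields are accessed by name, so adding a field later does not
# break existing call sites.
req_data = ModelRequestData(prompt="<|user|>Hi", stop_token_ids=None)
print(req_data.prompt, req_data.stop_token_ids)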
@@ -48,10 +57,16 @@ def load_qwenvl_chat(question: str, image_urls: List[str]):
 
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
-    return llm, prompt, stop_token_ids, None, chat_template
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=chat_template,
+    )
 
 
-def load_phi3v(question: str, image_urls: List[str]):
+def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
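Note that `fetch_image` is not imported in any hunk shown here; the example file presumably pulls it from `vllm.multimodal.utils` (an assumption worth verifying against the full file). With this change, every loader resolves URLs to PIL images eagerly, along these lines:

# Assumed import; not visible in the hunks above.
from vllm.multimodal.utils import fetch_image

# Hypothetical URLs, for illustration only.
urls = ["https://example.com/a.jpg", "https://example.com/b.jpg"]
image_data = [fetch_image(url) for url in urls]  # List[PIL.Image.Image]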
@@ -62,10 +77,17 @@ def load_phi3v(question: str, image_urls: List[str]):
                              for i, _ in enumerate(image_urls, start=1))
     prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
     stop_token_ids = None
-    return llm, prompt, stop_token_ids, None, None
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
 
 
-def load_internvl(question: str, image_urls: List[str]):
+def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
     model_name = "OpenGVLab/InternVL2-2B"
 
     llm = LLM(
@@ -93,10 +115,16 @@ def load_internvl(question: str, image_urls: List[str]):
     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 
-    return llm, prompt, stop_token_ids, None, None
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
 
 
-def load_qwen2_vl(question, image_urls: List[str]):
+def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     try:
         from qwen_vl_utils import process_vision_info
     except ModuleNotFoundError:
@@ -143,7 +171,13 @@ def load_qwen2_vl(question, image_urls: List[str]):
     else:
         image_data, _ = process_vision_info(messages)
 
-    return llm, prompt, stop_token_ids, image_data, None
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=image_data,
+        chat_template=None,
+    )
 
 
 model_example_map = {
@@ -155,20 +189,17 @@ def load_qwen2_vl(question, image_urls: List[str]):
 
 
 def run_generate(model, question: str, image_urls: List[str]):
-    llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
-        question, image_urls)
-    if image_data is None:
-        image_data = [fetch_image(url) for url in image_urls]
+    req_data = model_example_map[model](question, image_urls)
 
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
-                                     stop_token_ids=stop_token_ids)
+                                     stop_token_ids=req_data.stop_token_ids)
 
-    outputs = llm.generate(
+    outputs = req_data.llm.generate(
         {
-            "prompt": prompt,
+            "prompt": req_data.prompt,
             "multi_modal_data": {
-                "image": image_data
+                "image": req_data.image_data
             },
         },
         sampling_params=sampling_params)
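With every loader now populating `image_data` itself, the old `if image_data is None` fallback above is safely deleted, and a non-None `image_data` becomes an invariant of `ModelRequestData`. A hypothetical guard a defensive caller could add (not part of this change):

def check_request(req_data: "ModelRequestData") -> None:
    # Hypothetical guard: after this refactor every loader fetches its
    # images eagerly, so image_data must never be None here.
    if req_data.image_data is None:
        raise ValueError("loader returned no image_data")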
@@ -179,13 +210,12 @@ def run_generate(model, question: str, image_urls: List[str]):
 
 
 def run_chat(model: str, question: str, image_urls: List[str]):
-    llm, _, stop_token_ids, _, chat_template = model_example_map[model](
-        question, image_urls)
+    req_data = model_example_map[model](question, image_urls)
 
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=128,
-                                     stop_token_ids=stop_token_ids)
-    outputs = llm.chat(
+                                     stop_token_ids=req_data.stop_token_ids)
+    outputs = req_data.llm.chat(
         [{
             "role":
             "user",
@@ -203,7 +233,7 @@ def run_chat(model: str, question: str, image_urls: List[str]):
             ],
         }],
         sampling_params=sampling_params,
-        chat_template=chat_template,
+        chat_template=req_data.chat_template,
     )
 
     for o in outputs:
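For anyone trying the refactor locally, a minimal driver under this diff's assumptions: the loader name, the `ModelRequestData` fields, and the generate payload all come from the hunks above, while the question and image URLs are placeholders.

from vllm import SamplingParams

QUESTION = "What do these images have in common?"  # placeholder
URLS = ["https://example.com/a.jpg", "https://example.com/b.jpg"]  # placeholders

req_data = load_phi3v(QUESTION, URLS)  # any entry in model_example_map works

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=128,
                                 stop_token_ids=req_data.stop_token_ids)
outputs = req_data.llm.generate(
    {
        "prompt": req_data.prompt,
        "multi_modal_data": {
            "image": req_data.image_data
        },
    },
    sampling_params=sampling_params)
for o in outputs:
    print(o.outputs[0].text)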