Skip to content

Commit 29424ef

Browse files
sangchengmengshihaobai
authored and committed
[add]openai_api_support_image
1 parent 6234bd3 commit 29424ef

File tree

4 files changed

+64
-4
lines changed

4 files changed

+64
-4
lines changed

lightllm/server/api_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
278278
parser.add_argument(
279279
"--vit_quant_type",
280280
type=str,
281-
default=None,
281+
default="none",
282282
help="""Quantization method: ppl-w4a16-128 | flashllm-w6a16
283283
| ao-int4wo-[32,64,128,256] | ao-int8wo | ao-fp8w8a16 | ao-fp6w6a16
284284
| vllm-w8a8 | vllm-fp8w8a8""",

lightllm/server/api_http.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,29 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
231231
return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported")
232232

233233
created_time = int(time.time())
234+
235+
multimodal_params_dict = {"images": []}
236+
for message in request.messages:
237+
if isinstance(message.content, list):
238+
texts = []
239+
for content in message.content:
240+
if content.type == 'text' and content.text:
241+
texts.append(content.text)
242+
elif content.type == 'image_url' and content.image_url is not None:
243+
for img in content.image_url.url:
244+
data_str = img.data
245+
prefix = "base64,"
246+
idx = data_str.find(prefix)
247+
if idx != -1:
248+
data_str = data_str[idx + len(prefix):]
249+
multimodal_params_dict["images"].append({
250+
"type": "base64",
251+
"data": data_str
252+
})
253+
254+
message.content = "\n".join(texts)
255+
# print(multimodal_params_dict)
256+
234257
prompt = await build_prompt(request)
235258
sampling_params_dict = {
236259
"do_sample": request.do_sample,
@@ -250,7 +273,7 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
250273
sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict)
251274

252275
sampling_params.verify()
253-
multimodal_params = MultimodalParams(images=[])
276+
multimodal_params = MultimodalParams(**multimodal_params_dict)
254277

255278
results_generator = g_objs.httpserver_manager.generate(
256279
prompt, sampling_params, multimodal_params, request=raw_request

lightllm/server/api_models.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,47 @@
44
from typing import Dict, List, Optional, Union, Literal
55
import uuid
66

7+
class ImageData(BaseModel):
    """A single image payload attached to a chat message.

    Attributes:
        type: Encoding scheme of the payload (currently "base64").
        data: The encoded image content as a string.
    """

    type: str
    data: str
10+
11+
class ImageURL(BaseModel):
    """Container for one or more image payloads referenced by a message part."""

    url: List[ImageData]

    @field_validator("url", mode="before")
    def ensure_list(cls, v):
        """Coerce a bare string, a single dict, or a mixed list into the
        ``List[ImageData]`` shape expected by the ``url`` field.

        A plain string is wrapped as base64 data; a dict lacking an explicit
        "type" key gets "base64" filled in (mutating the dict in place).
        Any other value is passed through untouched for pydantic to reject.
        """

        def _coerce(item):
            # Bare string -> assume base64-encoded image bytes.
            if isinstance(item, str):
                return {"type": "base64", "data": item}
            # Dict without an encoding marker -> default to base64.
            if isinstance(item, dict):
                item.setdefault("type", "base64")
            return item

        if isinstance(v, list):
            return [_coerce(item) for item in v]
        if isinstance(v, (str, dict)):
            return [_coerce(v)]
        return v
35+
36+
class MessageContent(BaseModel):
    """One typed entry of an OpenAI-style structured message content list.

    Attributes:
        type: Discriminator for the part kind (e.g. "text" or "image_url").
        text: Present when ``type`` is "text".
        image_url: Present when ``type`` is "image_url".
    """

    type: str
    text: Optional[str] = None
    image_url: Optional[ImageURL] = None
40+
41+
class Message(BaseModel):
    """A single chat message: either plain text or a list of typed parts.

    Attributes:
        role: Speaker role (e.g. "user", "assistant", "system").
        content: Plain-text body, or a list of ``MessageContent`` parts for
            multimodal (text + image) messages.
    """

    role: str
    content: Union[str, List[MessageContent]]
744

845
class ChatCompletionRequest(BaseModel):
9-
# The openai api native parameters
1046
model: str
11-
messages: List[Dict[str, str]]
47+
messages: List[Message]
1248
function_call: Optional[str] = "none"
1349
temperature: Optional[float] = 1
1450
top_p: Optional[float] = 1.0

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
class VisualModelRpcServer(rpyc.Service):
2424
def exposed_init_model(self, kvargs):
25+
kvargs = obtain(kvargs)
2526
import torch
2627
import torch.distributed as dist
2728

0 commit comments

Comments
 (0)