
Commit 6318acd

sangchengmengshihaobai authored and committed
add openai-api-image
1 parent 1d67b1e · commit 6318acd

File tree: 3 files changed (+61, -4 lines)


lightllm/server/api_http.py

Lines changed: 39 additions & 1 deletion
@@ -19,15 +19,20 @@
 import asyncio
 import collections
 import time
+import json
 import uvloop
+import requests
+import base64
 import os
+from io import BytesIO
 import pickle
 from .build_prompt import build_prompt, init_tokenizer
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 import ujson as json
 from http import HTTPStatus
 import uuid
+from PIL import Image
 import multiprocessing as mp
 from typing import AsyncGenerator, Union
 from typing import Callable
@@ -40,6 +45,7 @@
 from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
 from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.server.embed_cache.utils import image2base64
 
 from .api_models import (
     ChatCompletionRequest,
@@ -230,6 +236,38 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
         return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported")
 
     created_time = int(time.time())
+
+    multimodal_params_dict = {"images": []}
+    for message in request.messages:
+        if isinstance(message.content, list):
+            texts = []
+            for content in message.content:
+                if content.type == 'text' and content.text:
+                    texts.append(content.text)
+                elif content.type == 'image_url' and content.image_url is not None:
+                    img = content.image_url.url
+                    if img.startswith("http://") or img.startswith("https://"):
+                        response = requests.get(img, stream=True, timeout=2)
+                        data = image2base64(response.raw)
+                    elif img.startswith("file://"):
+                        data = image2base64(img[7:])
+                    elif img.startswith("data:image"):
+                        # "data:image/jpeg;base64,{base64_image}"
+                        data_str = img.split(";", 1)[1]
+                        if data_str.startswith("base64,"):
+                            data = data_str[7:]
+                        else:
+                            raise ValueError("Unrecognized image input.")
+                    else:
+                        raise ValueError("Unrecognized image input. Supports local path, http url, base64, and PIL.Image.")
+
+                    multimodal_params_dict["images"].append({
+                        "type": "base64",
+                        "data": data
+                    })
+
+            message.content = "\n".join(texts)
+
     prompt = await build_prompt(request)
     sampling_params_dict = {
         "do_sample": request.do_sample,
@@ -249,7 +287,7 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
     sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict)
 
     sampling_params.verify()
-    multimodal_params = MultimodalParams(images=[])
+    multimodal_params = MultimodalParams(**multimodal_params_dict)
 
     results_generator = g_objs.httpserver_manager.generate(
         prompt, sampling_params, multimodal_params, request=raw_request
lightllm/server/api_models.py

Lines changed: 13 additions & 2 deletions
@@ -5,10 +5,21 @@
 import uuid
 
 
+class ImageURL(BaseModel):
+    url: str
+
+class MessageContent(BaseModel):
+    type: str
+    text: Optional[str] = None
+    image_url: Optional[ImageURL] = None
+
+class Message(BaseModel):
+    role: str
+    content: Union[str, List[MessageContent]]
+
 class ChatCompletionRequest(BaseModel):
-    # The openai api native parameters
     model: str
-    messages: List[Dict[str, str]]
+    messages: List[Message]
     function_call: Optional[str] = "none"
     temperature: Optional[float] = 1
     top_p: Optional[float] = 1.0
lightllm/server/embed_cache/utils.py

Lines changed: 9 additions & 1 deletion
@@ -1,8 +1,9 @@
+import base64
 import torch
 import numpy as np
 from io import BytesIO
 import multiprocessing.shared_memory as shm
-
+from PIL import Image
 
 def tensor2bytes(t):
     # t = t.cpu().numpy().tobytes()
@@ -12,6 +13,13 @@ def tensor2bytes(t):
     buf.seek(0)
     return buf.read()
 
+def image2base64(img_str: str):
+    image_obj = Image.open(img_str)
+    if image_obj.format is None:
+        raise ValueError("No image format found.")
+    buffer = BytesIO()
+    image_obj.save(buffer, format=image_obj.format)
+    return base64.b64encode(buffer.getvalue()).decode('utf-8')
 
 def bytes2tensor(b):
     # return torch.from_numpy(np.frombuffer(b, dtype=np.float16)).cuda()
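
A quick usage sketch for the new helper. The file path is hypothetical, and as used in api_http.py the argument can also be a file-like object such as response.raw, since PIL's Image.open accepts either:

```python
# Sketch: round-trip an image through image2base64; "cat.jpg" is a hypothetical path.
import base64
from io import BytesIO
from PIL import Image
from lightllm.server.embed_cache.utils import image2base64

b64 = image2base64("cat.jpg")                     # also works with file-like objects
img = Image.open(BytesIO(base64.b64decode(b64)))  # decode back to verify
print(img.format, img.size)                       # format is preserved because save() reuses image_obj.format
```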
