|
25 | 25 | import os |
26 | 26 | from io import BytesIO |
27 | 27 | import pickle |
28 | | -from .build_prompt import build_prompt, init_tokenizer |
29 | 28 | |
30 | 29 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) |
31 | 30 | import ujson as json |
|
44 | 43 | from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster |
45 | 44 | from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream |
46 | 45 | from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size |
| 46 | +from lightllm.utils.log_utils import init_logger |
| 47 | +from lightllm.server.metrics.manager import MetricClient |
| 48 | +from lightllm.utils.envs_utils import get_unique_server_name |
| 49 | +from dataclasses import dataclass |
47 | 50 | |
| 51 | +from .api_openai import chat_completions_impl |
48 | 52 | from .api_models import ( |
49 | 53 | ChatCompletionRequest, |
50 | | - UsageInfo, |
51 | | - ChatMessage, |
52 | | - ChatCompletionResponseChoice, |
53 | 54 | ChatCompletionResponse, |
54 | | - DeltaMessage, |
55 | | - ChatCompletionStreamResponse, |
56 | | - ChatCompletionStreamResponseChoice, |
57 | 55 | ) |
58 | | - |
59 | | -from lightllm.utils.log_utils import init_logger |
60 | | -from lightllm.server.metrics.manager import MetricClient |
61 | | -from lightllm.utils.envs_utils import get_unique_server_name |
62 | | -from dataclasses import dataclass |
| 56 | +from .build_prompt import build_prompt, init_tokenizer |
63 | 57 | |
64 | 58 | logger = init_logger(__name__) |
65 | 59 | |
@@ -224,133 +218,8 @@ async def compat_generate(request: Request) -> Response: |
224 | 218 | |
225 | 219 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) |
226 | 220 | async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: |
227 | | - |
228 | | - if request.logit_bias is not None: |
229 | | - return create_error_response( |
230 | | - HTTPStatus.BAD_REQUEST, |
231 | | - "The logit_bias parameter is not currently supported", |
232 | | - ) |
233 | | - |
234 | | - if request.function_call != "none": |
235 | | - return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported") |
236 | | - |
237 | | - created_time = int(time.time()) |
238 | | - |
239 | | - multimodal_params_dict = {"images": []} |
240 | | - for message in request.messages: |
241 | | - if isinstance(message.content, list): |
242 | | - texts = [] |
243 | | - for content in message.content: |
244 | | - if content.type == "text" and content.text: |
245 | | - texts.append(content.text) |
246 | | - elif content.type == "image_url" and content.image_url is not None: |
247 | | - img = content.image_url.url |
248 | | - if img.startswith("http://") or img.startswith("https://"): |
249 | | - multimodal_params_dict["images"].append({"type": "url", "data": img}) |
250 | | - elif img.startswith("data:image"): |
251 | | - # "data:image/jpeg;base64,{base64_image}" |
252 | | - data_str = img.split(";", 1)[1] |
253 | | - if data_str.startswith("base64,"): |
254 | | - data = data_str[7:] |
255 | | - multimodal_params_dict["images"].append({"type": "base64", "data": data}) |
256 | | - else: |
257 | | - raise ValueError("Unrecognized image input.") |
258 | | - else: |
259 | | - raise ValueError( |
260 | | - "Unrecognized image input. Supports local path, http url, base64, and PIL.Image." |
261 | | - ) |
262 | | - |
263 | | - message.content = "\n".join(texts) |
264 | | - |
265 | | - prompt = await build_prompt(request) |
266 | | - sampling_params_dict = { |
267 | | - "do_sample": request.do_sample, |
268 | | - "presence_penalty": request.presence_penalty, |
269 | | - "frequency_penalty": request.frequency_penalty, |
270 | | - "temperature": request.temperature, |
271 | | - "top_p": request.top_p, |
272 | | - "top_k": request.top_k, |
273 | | - "ignore_eos": request.ignore_eos, |
274 | | - "max_new_tokens": request.max_tokens, |
275 | | - "stop_sequences": request.stop, |
276 | | - "n": request.n, |
277 | | - "best_of": request.n, |
278 | | - "add_special_tokens": False, |
279 | | - } |
280 | | - sampling_params = SamplingParams() |
281 | | - sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict) |
282 | | - |
283 | | - sampling_params.verify() |
284 | | - multimodal_params = MultimodalParams(**multimodal_params_dict) |
285 | | - |
286 | | - results_generator = g_objs.httpserver_manager.generate( |
287 | | - prompt, sampling_params, multimodal_params, request=raw_request |
288 | | - ) |
289 | | - |
290 | | - # Non-streaming case |
291 | | - if not request.stream: |
292 | | - final_output_dict = collections.defaultdict(list) |
293 | | - count_output_tokens_dict = collections.defaultdict(lambda: 0) |
294 | | - finish_reason_dict = {} |
295 | | - prompt_tokens_dict = {} |
296 | | - completion_tokens = 0 |
297 | | - async for sub_req_id, request_output, metadata, finish_status in results_generator: |
298 | | - from .req_id_generator import convert_sub_id_to_group_id |
299 | | - |
300 | | - group_request_id = convert_sub_id_to_group_id(sub_req_id) |
301 | | - count_output_tokens_dict[sub_req_id] += 1 |
302 | | - final_output_dict[sub_req_id].append(request_output) |
303 | | - if finish_status.is_finished(): |
304 | | - finish_reason_dict[sub_req_id] = finish_status.get_finish_reason() |
305 | | - prompt_tokens_dict[sub_req_id] = metadata["prompt_tokens"] |
306 | | - choices = [] |
307 | | - sub_ids = list(final_output_dict.keys())[: request.n] |
308 | | - for i in range(request.n): |
309 | | - sub_req_id = sub_ids[i] |
310 | | - prompt_tokens = prompt_tokens_dict[sub_req_id] |
311 | | - completion_tokens = count_output_tokens_dict[sub_req_id] |
312 | | - usage = UsageInfo( |
313 | | - prompt_tokens=prompt_tokens, |
314 | | - completion_tokens=completion_tokens, |
315 | | - total_tokens=prompt_tokens + completion_tokens, |
316 | | - ) |
317 | | - chat_message = ChatMessage(role="assistant", content="".join(final_output_dict[sub_req_id])) |
318 | | - choice = ChatCompletionResponseChoice( |
319 | | - index=i, message=chat_message, finish_reason=finish_reason_dict[sub_req_id] |
320 | | - ) |
321 | | - choices.append(choice) |
322 | | - resp = ChatCompletionResponse( |
323 | | - id=group_request_id, created=created_time, model=request.model, choices=choices, usage=usage |
324 | | - ) |
325 | | - return resp |
326 | | - |
327 | | - if sampling_params.n != 1: |
328 | | - raise Exception("stream api only support n = 1") |
329 | | - |
330 | | - # Streaming case |
331 | | - async def stream_results() -> AsyncGenerator[bytes, None]: |
332 | | - finish_reason = None |
333 | | - from .req_id_generator import convert_sub_id_to_group_id |
334 | | - |
335 | | - async for sub_req_id, request_output, metadata, finish_status in results_generator: |
336 | | - group_request_id = convert_sub_id_to_group_id(sub_req_id) |
337 | | - |
338 | | - delta_message = DeltaMessage(role="assistant", content=request_output) |
339 | | - if finish_status.is_finished(): |
340 | | - finish_reason = finish_status.get_finish_reason() |
341 | | - stream_choice = ChatCompletionStreamResponseChoice( |
342 | | - index=0, delta=delta_message, finish_reason=finish_reason |
343 | | - ) |
344 | | - stream_resp = ChatCompletionStreamResponse( |
345 | | - id=group_request_id, |
346 | | - created=created_time, |
347 | | - model=request.model, |
348 | | - choices=[stream_choice], |
349 | | - ) |
350 | | - yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") |
351 | | - |
352 | | - background_tasks = BackgroundTasks() |
353 | | - return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) |
| 221 | + resp = await chat_completions_impl(request, raw_request) |
| 222 | + return resp |
354 | 223 | |
355 | 224 | |
356 | 225 | @app.get("/tokens") |
|
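For context, this commit only changes the call site: the inline OpenAI-compatible implementation above is deleted and `/v1/chat/completions` now delegates to `chat_completions_impl` from `.api_openai`, which is not part of this diff. Below is a minimal skeleton of what that function presumably covers, reconstructed from the removed lines; the call site fixes the argument list, but the return annotation and body are assumptions.

```python
# Sketch of api_openai.py; this module is not shown in the commit, so the body
# below is an assumption reconstructed from the inline code removed above.
from fastapi import Request
from fastapi.responses import Response

from .api_models import ChatCompletionRequest


async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response:
    """Presumed new home of the logic deleted from chat_completions():

    1. reject unsupported parameters (logit_bias, a non-"none" function_call),
    2. flatten list-style message content, collecting http(s) image URLs and
       "data:image/...;base64,..." payloads into the multimodal params,
    3. call build_prompt(request) and fill SamplingParams from the request
       fields (temperature, top_p, top_k, max_tokens, stop, n, ...),
    4. drive httpserver_manager.generate() and assemble either a
       ChatCompletionResponse (non-streaming) or an SSE StreamingResponse.
    """
    raise NotImplementedError  # placeholder; see the removed lines above for the original logic
```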
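The endpoint still accepts a ChatCompletionRequest and declares ChatCompletionResponse as its response model, so the request shape a client sends should be unchanged by the refactor. A hedged client-side example, based on the multimodal parsing visible in the removed code (host, port, and model name are placeholders):

```python
# Example call against /v1/chat/completions. The accepted image formats, plain
# http(s) URLs and "data:image/...;base64,..." data URIs, come from the parsing
# logic shown in this diff; host, port, and model name are placeholders.
import requests

payload = {
    "model": "placeholder-model",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
            ],
        }
    ],
    "max_tokens": 128,
    "stream": False,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```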