45 | 45 | from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream |
46 | 46 | from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size |
47 | 47 |
48 | | -from .api_models import ( |
49 | | - ChatCompletionRequest, |
50 | | - UsageInfo, |
51 | | - ChatMessage, |
52 | | - ChatCompletionResponseChoice, |
53 | | - ChatCompletionResponse, |
54 | | - DeltaMessage, |
55 | | - ChatCompletionStreamResponse, |
56 | | - ChatCompletionStreamResponseChoice, |
57 | | -) |
58 | | - |
59 | 48 | from lightllm.utils.log_utils import init_logger |
60 | 49 | from lightllm.server.metrics.manager import MetricClient |
61 | 50 | from lightllm.utils.envs_utils import get_unique_server_name |
62 | 51 | from dataclasses import dataclass |
63 | 52 |
64 | | -logger = init_logger(__name__) |
65 | | - |
| 53 | +from .api_openai import app as openai_api |
| 54 | +from .api_openai import g_objs |
66 | 55 |
67 | | -@dataclass |
68 | | -class G_Objs: |
69 | | - app: FastAPI = None |
70 | | - metric_client: MetricClient = None |
71 | | - args: object = None |
72 | | - g_generate_func: Callable = None |
73 | | - g_generate_stream_func: Callable = None |
74 | | - httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None |
75 | | - shared_token_load: TokenLoad = None |
76 | | - |
77 | | - def set_args(self, args): |
78 | | - self.args = args |
79 | | - from .api_lightllm import lightllm_generate, lightllm_generate_stream |
80 | | - from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl |
81 | | - |
82 | | - if args.use_tgi_api: |
83 | | - self.g_generate_func = tgi_generate_impl |
84 | | - self.g_generate_stream_func = tgi_generate_stream_impl |
85 | | - else: |
86 | | - self.g_generate_func = lightllm_generate |
87 | | - self.g_generate_stream_func = lightllm_generate_stream |
88 | | - |
89 | | - if args.run_mode == "pd_master": |
90 | | - self.metric_client = MetricClient(args.metric_port) |
91 | | - self.httpserver_manager = HttpServerManagerForPDMaster( |
92 | | - args, |
93 | | - metric_port=args.metric_port, |
94 | | - ) |
95 | | - else: |
96 | | - init_tokenizer(args) # for openai api |
97 | | - SamplingParams.load_generation_cfg(args.model_dir) |
98 | | - self.metric_client = MetricClient(args.metric_port) |
99 | | - self.httpserver_manager = HttpServerManager( |
100 | | - args, |
101 | | - router_port=args.router_port, |
102 | | - cache_port=args.cache_port, |
103 | | - detokenization_pub_port=args.detokenization_pub_port, |
104 | | - visual_port=args.visual_port, |
105 | | - enable_multimodal=args.enable_multimodal, |
106 | | - metric_port=args.metric_port, |
107 | | - ) |
108 | | - dp_size_in_node = max(1, args.dp // args.nnodes) # handle multi-node pure-TP mode, where args.dp // args.nnodes can be 1 // 2 == 0
109 | | - self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node) |
110 | | - |
111 | | - |
112 | | -g_objs = G_Objs() |
| 56 | +logger = init_logger(__name__) |
113 | 57 |
114 | 58 | app = FastAPI() |
115 | 59 | g_objs.app = app |
116 | 60 |
| 61 | +app.mount("/v1", openai_api) |
| 62 | + |
117 | 63 |
118 | 64 | def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse: |
119 | 65 | g_objs.metric_client.counter_inc("lightllm_request_failure") |
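
The hunk above swaps the inline G_Objs container and the api_models imports for imports from .api_openai, then mounts that module's FastAPI app under /v1. Below is a minimal sketch of how a FastAPI sub-application mount exposes the sub-app's routes under the mount prefix; the module contents and route path are assumptions for illustration, and only the app.mount("/v1", ...) call mirrors the diff.

# Minimal sketch of a FastAPI sub-application mount (assumed route path;
# not the actual lightllm .api_openai module).
from fastapi import FastAPI
from fastapi.testclient import TestClient

openai_api = FastAPI()  # stands in for `from .api_openai import app as openai_api`

@openai_api.post("/chat/completions")
async def chat_completions() -> dict:
    # Placeholder handler; the real one builds a ChatCompletionResponse.
    return {"object": "chat.completion"}

app = FastAPI()               # stands in for the main HTTP server app
app.mount("/v1", openai_api)  # same call as in the diff

# Routes defined on the sub-app are served under the mount prefix, so the
# previous /v1/chat/completions path keeps resolving.
client = TestClient(app)
print(client.post("/v1/chat/completions").status_code)  # 200
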
@@ -222,137 +168,6 @@ async def compat_generate(request: Request) -> Response: |
222 | 168 | return await generate(request) |
223 | 169 |
224 | 170 |
225 | | -@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) |
226 | | -async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: |
227 | | - |
228 | | - if request.logit_bias is not None: |
229 | | - return create_error_response( |
230 | | - HTTPStatus.BAD_REQUEST, |
231 | | - "The logit_bias parameter is not currently supported", |
232 | | - ) |
233 | | - |
234 | | - if request.function_call != "none": |
235 | | - return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported") |
236 | | - |
237 | | - created_time = int(time.time()) |
238 | | - |
239 | | - multimodal_params_dict = {"images": []} |
240 | | - for message in request.messages: |
241 | | - if isinstance(message.content, list): |
242 | | - texts = [] |
243 | | - for content in message.content: |
244 | | - if content.type == "text" and content.text: |
245 | | - texts.append(content.text) |
246 | | - elif content.type == "image_url" and content.image_url is not None: |
247 | | - img = content.image_url.url |
248 | | - if img.startswith("http://") or img.startswith("https://"): |
249 | | - multimodal_params_dict["images"].append({"type": "url", "data": img}) |
250 | | - elif img.startswith("data:image"): |
251 | | - # "data:image/jpeg;base64,{base64_image}" |
252 | | - data_str = img.split(";", 1)[1] |
253 | | - if data_str.startswith("base64,"): |
254 | | - data = data_str[7:] |
255 | | - multimodal_params_dict["images"].append({"type": "base64", "data": data}) |
256 | | - else: |
257 | | - raise ValueError("Unrecognized image input.") |
258 | | - else: |
259 | | - raise ValueError( |
260 | | - "Unrecognized image input. Supports local path, http url, base64, and PIL.Image." |
261 | | - ) |
262 | | - |
263 | | - message.content = "\n".join(texts) |
264 | | - |
265 | | - prompt = await build_prompt(request) |
266 | | - sampling_params_dict = { |
267 | | - "do_sample": request.do_sample, |
268 | | - "presence_penalty": request.presence_penalty, |
269 | | - "frequency_penalty": request.frequency_penalty, |
270 | | - "temperature": request.temperature, |
271 | | - "top_p": request.top_p, |
272 | | - "top_k": request.top_k, |
273 | | - "ignore_eos": request.ignore_eos, |
274 | | - "max_new_tokens": request.max_tokens, |
275 | | - "stop_sequences": request.stop, |
276 | | - "n": request.n, |
277 | | - "best_of": request.n, |
278 | | - "add_special_tokens": False, |
279 | | - } |
280 | | - sampling_params = SamplingParams() |
281 | | - sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict) |
282 | | - |
283 | | - sampling_params.verify() |
284 | | - multimodal_params = MultimodalParams(**multimodal_params_dict) |
285 | | - |
286 | | - results_generator = g_objs.httpserver_manager.generate( |
287 | | - prompt, sampling_params, multimodal_params, request=raw_request |
288 | | - ) |
289 | | - |
290 | | - # Non-streaming case |
291 | | - if not request.stream: |
292 | | - final_output_dict = collections.defaultdict(list) |
293 | | - count_output_tokens_dict = collections.defaultdict(lambda: 0) |
294 | | - finish_reason_dict = {} |
295 | | - prompt_tokens_dict = {} |
296 | | - completion_tokens = 0 |
297 | | - async for sub_req_id, request_output, metadata, finish_status in results_generator: |
298 | | - from .req_id_generator import convert_sub_id_to_group_id |
299 | | - |
300 | | - group_request_id = convert_sub_id_to_group_id(sub_req_id) |
301 | | - count_output_tokens_dict[sub_req_id] += 1 |
302 | | - final_output_dict[sub_req_id].append(request_output) |
303 | | - if finish_status.is_finished(): |
304 | | - finish_reason_dict[sub_req_id] = finish_status.get_finish_reason() |
305 | | - prompt_tokens_dict[sub_req_id] = metadata["prompt_tokens"] |
306 | | - choices = [] |
307 | | - sub_ids = list(final_output_dict.keys())[: request.n] |
308 | | - for i in range(request.n): |
309 | | - sub_req_id = sub_ids[i] |
310 | | - prompt_tokens = prompt_tokens_dict[sub_req_id] |
311 | | - completion_tokens = count_output_tokens_dict[sub_req_id] |
312 | | - usage = UsageInfo( |
313 | | - prompt_tokens=prompt_tokens, |
314 | | - completion_tokens=completion_tokens, |
315 | | - total_tokens=prompt_tokens + completion_tokens, |
316 | | - ) |
317 | | - chat_message = ChatMessage(role="assistant", content="".join(final_output_dict[sub_req_id])) |
318 | | - choice = ChatCompletionResponseChoice( |
319 | | - index=i, message=chat_message, finish_reason=finish_reason_dict[sub_req_id] |
320 | | - ) |
321 | | - choices.append(choice) |
322 | | - resp = ChatCompletionResponse( |
323 | | - id=group_request_id, created=created_time, model=request.model, choices=choices, usage=usage |
324 | | - ) |
325 | | - return resp |
326 | | - |
327 | | - if sampling_params.n != 1: |
328 | | - raise Exception("stream api only support n = 1") |
329 | | - |
330 | | - # Streaming case |
331 | | - async def stream_results() -> AsyncGenerator[bytes, None]: |
332 | | - finish_reason = None |
333 | | - from .req_id_generator import convert_sub_id_to_group_id |
334 | | - |
335 | | - async for sub_req_id, request_output, metadata, finish_status in results_generator: |
336 | | - group_request_id = convert_sub_id_to_group_id(sub_req_id) |
337 | | - |
338 | | - delta_message = DeltaMessage(role="assistant", content=request_output) |
339 | | - if finish_status.is_finished(): |
340 | | - finish_reason = finish_status.get_finish_reason() |
341 | | - stream_choice = ChatCompletionStreamResponseChoice( |
342 | | - index=0, delta=delta_message, finish_reason=finish_reason |
343 | | - ) |
344 | | - stream_resp = ChatCompletionStreamResponse( |
345 | | - id=group_request_id, |
346 | | - created=created_time, |
347 | | - model=request.model, |
348 | | - choices=[stream_choice], |
349 | | - ) |
350 | | - yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") |
351 | | - |
352 | | - background_tasks = BackgroundTasks() |
353 | | - return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) |
354 | | - |
355 | | - |
356 | 171 | @app.get("/tokens") |
357 | 172 | @app.post("/tokens") |
358 | 173 | async def tokens(request: Request): |
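
The long deletion in this hunk removes the inline /v1/chat/completions handler; given the app.mount("/v1", openai_api) call earlier in the diff, that endpoint is now expected to be served by the .api_openai sub-app at the same external path. A hedged client-side sketch follows, with host, port, and model name as assumptions and request fields taken from the removed handler:

# Usage sketch (assumed host/port/model); the endpoint path and field names
# follow the removed handler above, not a verified server configuration.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "demo-model",
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.7,
        "max_tokens": 64,
        "stream": False,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])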