25 | 25 | import os |
26 | 26 | from io import BytesIO |
27 | 27 | import pickle |
28 | | -from .build_prompt import build_prompt, init_tokenizer |
29 | 28 |
30 | 29 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) |
31 | 30 | import ujson as json |
44 | 43 | from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster |
45 | 44 | from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream |
46 | 45 | from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size |
47 | | - |
48 | 46 | from lightllm.utils.log_utils import init_logger |
49 | 47 | from lightllm.server.metrics.manager import MetricClient |
50 | 48 | from lightllm.utils.envs_utils import get_unique_server_name |
51 | 49 | from dataclasses import dataclass |
52 | 50 |
53 | | -from .api_openai import app as openai_api |
54 | | -from .api_openai import g_objs |
| 51 | +from .api_openai import chat_completions_impl |
| 52 | +from .api_models import ( |
| 53 | + ChatCompletionRequest, |
| 54 | + ChatCompletionResponse, |
| 55 | +) |
| 56 | +from .build_prompt import build_prompt, init_tokenizer |
55 | 57 |
56 | 58 | logger = init_logger(__name__) |
57 | 59 |
| 60 | + |
| 61 | +@dataclass |
| 62 | +class G_Objs: |
| 63 | + app: FastAPI = None |
| 64 | + metric_client: MetricClient = None |
| 65 | + args: object = None |
| 66 | + g_generate_func: Callable = None |
| 67 | + g_generate_stream_func: Callable = None |
| 68 | + httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None |
| 69 | + shared_token_load: TokenLoad = None |
| 70 | + |
| 71 | + def set_args(self, args): |
| 72 | + self.args = args |
| 73 | + from .api_lightllm import lightllm_generate, lightllm_generate_stream |
| 74 | + from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl |
| 75 | + |
| 76 | + if args.use_tgi_api: |
| 77 | + self.g_generate_func = tgi_generate_impl |
| 78 | + self.g_generate_stream_func = tgi_generate_stream_impl |
| 79 | + else: |
| 80 | + self.g_generate_func = lightllm_generate |
| 81 | + self.g_generate_stream_func = lightllm_generate_stream |
| 82 | + |
| 83 | + if args.run_mode == "pd_master": |
| 84 | + self.metric_client = MetricClient(args.metric_port) |
| 85 | + self.httpserver_manager = HttpServerManagerForPDMaster( |
| 86 | + args, |
| 87 | + metric_port=args.metric_port, |
| 88 | + ) |
| 89 | + else: |
| 90 | + init_tokenizer(args) # for openai api |
| 91 | + SamplingParams.load_generation_cfg(args.model_dir) |
| 92 | + self.metric_client = MetricClient(args.metric_port) |
| 93 | + self.httpserver_manager = HttpServerManager( |
| 94 | + args, |
| 95 | + router_port=args.router_port, |
| 96 | + cache_port=args.cache_port, |
| 97 | + detokenization_pub_port=args.detokenization_pub_port, |
| 98 | + visual_port=args.visual_port, |
| 99 | + enable_multimodal=args.enable_multimodal, |
| 100 | + metric_port=args.metric_port, |
| 101 | + ) |
| 102 | +            dp_size_in_node = max(1, args.dp // args.nnodes)  # handle multi-node pure-tp runs, where 1 // 2 == 0 and the value must be clamped to 1 |
| 103 | + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node) |
| 104 | + |
| 105 | + |
| 106 | +g_objs = G_Objs() |
| 107 | + |
58 | 108 | app = FastAPI() |
59 | 109 | g_objs.app = app |
60 | 110 |
61 | | -app.mount("/v1", openai_api) |
62 | | - |
63 | 111 |
64 | 112 | def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse: |
65 | 113 | g_objs.metric_client.counter_inc("lightllm_request_failure") |
@@ -274,3 +322,9 @@ async def startup_event(): |
274 | 322 | loop.create_task(g_objs.httpserver_manager.handle_loop()) |
275 | 323 | logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") |
276 | 324 | return |
| 325 | + |
| 326 | + |
| 327 | +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) |
| 328 | +async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: |
| 329 | + resp = await chat_completions_impl(request, raw_request) |
| 330 | + return resp |
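
With `/v1/chat/completions` registered directly on the main app instead of a mounted sub-application, the endpoint is reached at the same URL as before. A minimal client sketch, assuming a lightllm server listening on `localhost:8000` and an OpenAI-style request body (the host, port, model name, and payload fields below are illustrative assumptions, not taken from this diff):

```python
# Illustrative client call; assumes an OpenAI-compatible chat schema and a
# lightllm server on localhost:8000 (both are assumptions, not from this diff).
import requests

payload = {
    "model": "default",  # hypothetical model name
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
}
# Same URL as before the change: the route now lives on the main FastAPI app
# rather than on a sub-application mounted at /v1.
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())
```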