
Commit cab746c — optimize code structure
1 parent 7040607

5 files changed: 82 additions, 60 deletions


lightllm/server/api_cli.py (7 additions, 1 deletion)

```diff
@@ -104,7 +104,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--eos_id", nargs="+", type=int, default=None, help="eos stop token id, if None, will load from config.json"
     )
-    parser.add_argument("--tool_call_parser", type=str, default=None, help="tool call parser type")
+    parser.add_argument(
+        "--tool_call_parser",
+        type=str,
+        choices=["qwen25", "llama3", "mistral"],
+        default=None,
+        help="tool call parser type",
+    )
     parser.add_argument(
         "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
     )
```
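The move from a bare string argument to an explicit `choices` list means an unsupported parser name now fails at argument-parsing time instead of surfacing later as a runtime error. A minimal standalone sketch of that behavior (not part of the commit, stdlib argparse only):

```python
# Illustrative only: reproduces the new --tool_call_parser validation in isolation.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--tool_call_parser",
    type=str,
    choices=["qwen25", "llama3", "mistral"],
    default=None,
    help="tool call parser type",
)

args = parser.parse_args(["--tool_call_parser", "qwen25"])
print(args.tool_call_parser)  # qwen25

# An unsupported value now exits immediately at startup:
#   parser.parse_args(["--tool_call_parser", "unknown"])
#   error: argument --tool_call_parser: invalid choice: 'unknown'
#   (choose from 'qwen25', 'llama3', 'mistral')
```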

lightllm/server/api_http.py (60 additions, 6 deletions)
```diff
@@ -25,7 +25,6 @@
 import os
 from io import BytesIO
 import pickle
-from .build_prompt import build_prompt, init_tokenizer
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 import ujson as json
```
```diff
@@ -44,22 +43,71 @@
 from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
 from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream
 from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
-
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.utils.envs_utils import get_unique_server_name
 from dataclasses import dataclass
 
-from .api_openai import app as openai_api
-from .api_openai import g_objs
+from .api_openai import chat_completions_impl
+from .api_models import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+)
+from .build_prompt import build_prompt, init_tokenizer
 
 logger = init_logger(__name__)
 
+
+@dataclass
+class G_Objs:
+    app: FastAPI = None
+    metric_client: MetricClient = None
+    args: object = None
+    g_generate_func: Callable = None
+    g_generate_stream_func: Callable = None
+    httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None
+    shared_token_load: TokenLoad = None
+
+    def set_args(self, args):
+        self.args = args
+        from .api_lightllm import lightllm_generate, lightllm_generate_stream
+        from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl
+
+        if args.use_tgi_api:
+            self.g_generate_func = tgi_generate_impl
+            self.g_generate_stream_func = tgi_generate_stream_impl
+        else:
+            self.g_generate_func = lightllm_generate
+            self.g_generate_stream_func = lightllm_generate_stream
+
+        if args.run_mode == "pd_master":
+            self.metric_client = MetricClient(args.metric_port)
+            self.httpserver_manager = HttpServerManagerForPDMaster(
+                args,
+                metric_port=args.metric_port,
+            )
+        else:
+            init_tokenizer(args)  # for openai api
+            SamplingParams.load_generation_cfg(args.model_dir)
+            self.metric_client = MetricClient(args.metric_port)
+            self.httpserver_manager = HttpServerManager(
+                args,
+                router_port=args.router_port,
+                cache_port=args.cache_port,
+                detokenization_pub_port=args.detokenization_pub_port,
+                visual_port=args.visual_port,
+                enable_multimodal=args.enable_multimodal,
+                metric_port=args.metric_port,
+            )
+        dp_size_in_node = max(1, args.dp // args.nnodes)  # compat for multi-node pure-tp mode, where args.dp // args.nnodes can be 1 // 2 == 0
+        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)
+
+
+g_objs = G_Objs()
+
 app = FastAPI()
 g_objs.app = app
 
-app.mount("/v1", openai_api)
-
 
 def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
     g_objs.metric_client.counter_inc("lightllm_request_failure")
```
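The relocated `G_Objs` container is a process-wide singleton: `set_args` selects the TGI or lightllm generate implementations and builds the appropriate manager once, and every route handler then reads from the same instance. A hedged sketch of the intended startup order (the actual launcher lives elsewhere in lightllm and may differ; `args` stands in for the parsed CLI namespace):

```python
# Hypothetical startup sketch, not part of this commit.
import uvicorn

from lightllm.server.api_http import app, g_objs


def start_http_server(args):
    g_objs.set_args(args)  # picks generate funcs, builds httpserver_manager and metric_client
    uvicorn.run(app, host=args.host, port=args.port)
```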
```diff
@@ -274,3 +322,9 @@ async def startup_event():
     loop.create_task(g_objs.httpserver_manager.handle_loop())
     logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}")
     return
+
+
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response:
+    resp = await chat_completions_impl(request, raw_request)
+    return resp
```
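Because the route is now registered directly on the main `app` rather than on a mounted sub-application, the public URL is unchanged: clients still POST to `/v1/chat/completions`. A hedged client-side sketch (host, port, and model name are illustrative):

```python
# Illustrative client call; assumes a lightllm server listening on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=60,
)
print(resp.status_code, resp.json())
```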

lightllm/server/api_openai.py (4 additions, 52 deletions)
```diff
@@ -50,63 +50,15 @@
 logger = init_logger(__name__)
 
 
-@dataclass
-class G_Objs:
-    app: FastAPI = None
-    metric_client: MetricClient = None
-    args: object = None
-    g_generate_func: Callable = None
-    g_generate_stream_func: Callable = None
-    httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None
-    shared_token_load: TokenLoad = None
-
-    def set_args(self, args):
-        self.args = args
-        from .api_lightllm import lightllm_generate, lightllm_generate_stream
-        from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl
-
-        if args.use_tgi_api:
-            self.g_generate_func = tgi_generate_impl
-            self.g_generate_stream_func = tgi_generate_stream_impl
-        else:
-            self.g_generate_func = lightllm_generate
-            self.g_generate_stream_func = lightllm_generate_stream
-
-        if args.run_mode == "pd_master":
-            self.metric_client = MetricClient(args.metric_port)
-            self.httpserver_manager = HttpServerManagerForPDMaster(
-                args,
-                metric_port=args.metric_port,
-            )
-        else:
-            init_tokenizer(args)  # for openai api
-            SamplingParams.load_generation_cfg(args.model_dir)
-            self.metric_client = MetricClient(args.metric_port)
-            self.httpserver_manager = HttpServerManager(
-                args,
-                router_port=args.router_port,
-                cache_port=args.cache_port,
-                detokenization_pub_port=args.detokenization_pub_port,
-                visual_port=args.visual_port,
-                enable_multimodal=args.enable_multimodal,
-                metric_port=args.metric_port,
-            )
-        dp_size_in_node = max(1, args.dp // args.nnodes)  # compat for multi-node pure-tp mode, where args.dp // args.nnodes can be 1 // 2 == 0
-        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)
-
-
-g_objs = G_Objs()
-
-app = FastAPI()
-
-
 def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
+    from .api_http import g_objs
+
     g_objs.metric_client.counter_inc("lightllm_request_failure")
     return JSONResponse({"message": message}, status_code=status_code.value)
 
 
-@app.post("/chat/completions", response_model=ChatCompletionResponse)
-async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response:
+async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response:
+    from .api_http import g_objs
 
     if request.logit_bias is not None:
         return create_error_response(
```
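The two function-level `from .api_http import g_objs` imports are what make the new layout work: `api_http` imports `api_openai` at module load time, so importing back at module level would create a circular import. Deferring the import to call time resolves it, because both modules are fully initialized by the time a request arrives. A minimal two-file sketch of the pattern (hypothetical module names, not lightllm code):

```python
# a.py (hypothetical) — imports b at module level, then defines shared state:
#
#     from b import handler_impl
#     state = "ready"
#
# b.py (hypothetical) — a module-level `from a import state` here would fail,
# because a.py has not finished executing when b is first imported.
def handler_impl():
    from a import state  # deferred: resolved per call, after a.py has fully loaded
    return state
```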

lightllm/server/function_call_parser.py (11 additions, 1 deletion)
```diff
@@ -1,4 +1,14 @@
-# Adaptive from SGlang Repo [https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/function_call_parser.py]
+# Adapted from SGLang [https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/function_call_parser.py]
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import json
 import logging
```
