Commit a7f5829

save work
1 parent 242590f commit a7f5829

9 files changed: +963 −213 lines changed

lightllm/server/api_cli.py

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--eos_id", nargs="+", type=int, default=None, help="eos stop token id, if None, will load from config.json"
     )
+    parser.add_argument("--tool_call_parser", type=str, default=None, help="tool call parser type")
     parser.add_argument(
         "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
     )
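
The flag added above defaults to None, so tool-call parsing stays disabled unless a parser type is named on the command line. A minimal sketch of how the option behaves under argparse; the parser name "qwen25" below is a made-up value for illustration, not one defined by this commit:

import argparse

# Reproduce just the new option; "qwen25" is a hypothetical parser name.
parser = argparse.ArgumentParser(description="lightllm api server (sketch)")
parser.add_argument("--tool_call_parser", type=str, default=None, help="tool call parser type")

args = parser.parse_args(["--tool_call_parser", "qwen25"])
assert args.tool_call_parser == "qwen25"

# Without the flag the attribute stays None, i.e. no tool-call parser is configured.
assert parser.parse_args([]).tool_call_parser is None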

lightllm/server/api_http.py

Lines changed: 5 additions & 190 deletions
@@ -45,75 +45,21 @@
 from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream
 from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
 
-from .api_models import (
-    ChatCompletionRequest,
-    UsageInfo,
-    ChatMessage,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponse,
-    DeltaMessage,
-    ChatCompletionStreamResponse,
-    ChatCompletionStreamResponseChoice,
-)
-
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.utils.envs_utils import get_unique_server_name
 from dataclasses import dataclass
 
-logger = init_logger(__name__)
-
+from .api_openai import app as openai_api
+from .api_openai import g_objs
 
-@dataclass
-class G_Objs:
-    app: FastAPI = None
-    metric_client: MetricClient = None
-    args: object = None
-    g_generate_func: Callable = None
-    g_generate_stream_func: Callable = None
-    httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None
-    shared_token_load: TokenLoad = None
-
-    def set_args(self, args):
-        self.args = args
-        from .api_lightllm import lightllm_generate, lightllm_generate_stream
-        from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl
-
-        if args.use_tgi_api:
-            self.g_generate_func = tgi_generate_impl
-            self.g_generate_stream_func = tgi_generate_stream_impl
-        else:
-            self.g_generate_func = lightllm_generate
-            self.g_generate_stream_func = lightllm_generate_stream
-
-        if args.run_mode == "pd_master":
-            self.metric_client = MetricClient(args.metric_port)
-            self.httpserver_manager = HttpServerManagerForPDMaster(
-                args,
-                metric_port=args.metric_port,
-            )
-        else:
-            init_tokenizer(args)  # for openai api
-            SamplingParams.load_generation_cfg(args.model_dir)
-            self.metric_client = MetricClient(args.metric_port)
-            self.httpserver_manager = HttpServerManager(
-                args,
-                router_port=args.router_port,
-                cache_port=args.cache_port,
-                detokenization_pub_port=args.detokenization_pub_port,
-                visual_port=args.visual_port,
-                enable_multimodal=args.enable_multimodal,
-                metric_port=args.metric_port,
-            )
-            dp_size_in_node = max(1, args.dp // args.nnodes)  # handle multi-node pure-tp mode, where 1 // 2 == 0
-            self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)
-
-
-g_objs = G_Objs()
+logger = init_logger(__name__)
 
 app = FastAPI()
 g_objs.app = app
 
+app.mount("/v1", openai_api)
+
 
 def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
     g_objs.metric_client.counter_inc("lightllm_request_failure")
@@ -222,137 +168,6 @@ async def compat_generate(request: Request) -> Response:
     return await generate(request)
 
 
-@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
-async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response:
-
-    if request.logit_bias is not None:
-        return create_error_response(
-            HTTPStatus.BAD_REQUEST,
-            "The logit_bias parameter is not currently supported",
-        )
-
-    if request.function_call != "none":
-        return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported")
-
-    created_time = int(time.time())
-
-    multimodal_params_dict = {"images": []}
-    for message in request.messages:
-        if isinstance(message.content, list):
-            texts = []
-            for content in message.content:
-                if content.type == "text" and content.text:
-                    texts.append(content.text)
-                elif content.type == "image_url" and content.image_url is not None:
-                    img = content.image_url.url
-                    if img.startswith("http://") or img.startswith("https://"):
-                        multimodal_params_dict["images"].append({"type": "url", "data": img})
-                    elif img.startswith("data:image"):
-                        # "data:image/jpeg;base64,{base64_image}"
-                        data_str = img.split(";", 1)[1]
-                        if data_str.startswith("base64,"):
-                            data = data_str[7:]
-                            multimodal_params_dict["images"].append({"type": "base64", "data": data})
-                        else:
-                            raise ValueError("Unrecognized image input.")
-                    else:
-                        raise ValueError(
-                            "Unrecognized image input. Supports local path, http url, base64, and PIL.Image."
-                        )
-
-            message.content = "\n".join(texts)
-
-    prompt = await build_prompt(request)
-    sampling_params_dict = {
-        "do_sample": request.do_sample,
-        "presence_penalty": request.presence_penalty,
-        "frequency_penalty": request.frequency_penalty,
-        "temperature": request.temperature,
-        "top_p": request.top_p,
-        "top_k": request.top_k,
-        "ignore_eos": request.ignore_eos,
-        "max_new_tokens": request.max_tokens,
-        "stop_sequences": request.stop,
-        "n": request.n,
-        "best_of": request.n,
-        "add_special_tokens": False,
-    }
-    sampling_params = SamplingParams()
-    sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict)
-
-    sampling_params.verify()
-    multimodal_params = MultimodalParams(**multimodal_params_dict)
-
-    results_generator = g_objs.httpserver_manager.generate(
-        prompt, sampling_params, multimodal_params, request=raw_request
-    )
-
-    # Non-streaming case
-    if not request.stream:
-        final_output_dict = collections.defaultdict(list)
-        count_output_tokens_dict = collections.defaultdict(lambda: 0)
-        finish_reason_dict = {}
-        prompt_tokens_dict = {}
-        completion_tokens = 0
-        async for sub_req_id, request_output, metadata, finish_status in results_generator:
-            from .req_id_generator import convert_sub_id_to_group_id
-
-            group_request_id = convert_sub_id_to_group_id(sub_req_id)
-            count_output_tokens_dict[sub_req_id] += 1
-            final_output_dict[sub_req_id].append(request_output)
-            if finish_status.is_finished():
-                finish_reason_dict[sub_req_id] = finish_status.get_finish_reason()
-                prompt_tokens_dict[sub_req_id] = metadata["prompt_tokens"]
-        choices = []
-        sub_ids = list(final_output_dict.keys())[: request.n]
-        for i in range(request.n):
-            sub_req_id = sub_ids[i]
-            prompt_tokens = prompt_tokens_dict[sub_req_id]
-            completion_tokens = count_output_tokens_dict[sub_req_id]
-            usage = UsageInfo(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens,
-            )
-            chat_message = ChatMessage(role="assistant", content="".join(final_output_dict[sub_req_id]))
-            choice = ChatCompletionResponseChoice(
-                index=i, message=chat_message, finish_reason=finish_reason_dict[sub_req_id]
-            )
-            choices.append(choice)
-        resp = ChatCompletionResponse(
-            id=group_request_id, created=created_time, model=request.model, choices=choices, usage=usage
-        )
-        return resp
-
-    if sampling_params.n != 1:
-        raise Exception("stream api only support n = 1")
-
-    # Streaming case
-    async def stream_results() -> AsyncGenerator[bytes, None]:
-        finish_reason = None
-        from .req_id_generator import convert_sub_id_to_group_id
-
-        async for sub_req_id, request_output, metadata, finish_status in results_generator:
-            group_request_id = convert_sub_id_to_group_id(sub_req_id)
-
-            delta_message = DeltaMessage(role="assistant", content=request_output)
-            if finish_status.is_finished():
-                finish_reason = finish_status.get_finish_reason()
-            stream_choice = ChatCompletionStreamResponseChoice(
-                index=0, delta=delta_message, finish_reason=finish_reason
-            )
-            stream_resp = ChatCompletionStreamResponse(
-                id=group_request_id,
-                created=created_time,
-                model=request.model,
-                choices=[stream_choice],
-            )
-            yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8")
-
-    background_tasks = BackgroundTasks()
-    return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
-
-
 @app.get("/tokens")
 @app.post("/tokens")
 async def tokens(request: Request):
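
The deleted /v1/chat/completions handler moves into a separate FastAPI sub-application (api_openai) that the main app mounts under /v1, so OpenAI-compatible routes live in their own module while api_http.py keeps the LightLLM-native endpoints. A self-contained sketch of that mounting pattern, with placeholder route bodies rather than the code moved by this commit:

from fastapi import FastAPI

# Sub-app playing the role of api_openai.app.
openai_api = FastAPI()

@openai_api.post("/chat/completions")
async def chat_completions() -> dict:
    # Placeholder; the real handler builds the prompt, sampling params,
    # and (after this commit) tool-call parsing.
    return {"object": "chat.completion"}

# The main app keeps the native endpoints and mounts the sub-app, so the
# route above is served at POST /v1/chat/completions.
app = FastAPI()
app.mount("/v1", openai_api)

@app.get("/tokens")
async def tokens() -> dict:
    return {"ntokens": 0}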

lightllm/server/api_models.py

Lines changed: 50 additions & 0 deletions
@@ -20,6 +20,34 @@ class Message(BaseModel):
     content: Union[str, List[MessageContent]]
 
 
+class Function(BaseModel):
+    """Function descriptions."""
+
+    description: Optional[str] = Field(default=None, examples=[None])
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+class Tool(BaseModel):
+    """Function wrapper."""
+
+    type: str = Field(default="function", examples=["function"])
+    function: Function
+
+
+class ToolChoiceFuncName(BaseModel):
+    """The name of tool choice function."""
+
+    name: Optional[str] = None
+
+
+class ToolChoice(BaseModel):
+    """The tool choice definition."""
+
+    function: ToolChoiceFuncName
+    type: Literal["function"] = Field(default="function", examples=["function"])
+
+
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
@@ -35,6 +63,12 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
 
+    # OpenAI Adaptive parameters for tool call
+    tools: Optional[List[Tool]] = Field(default=None, examples=[None])
+    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
+        default="auto", examples=["none"]
+    )  # noqa
+
     # Additional parameters supported by LightLLM
     do_sample: Optional[bool] = False
     top_k: Optional[int] = -1
@@ -44,6 +78,21 @@ class ChatCompletionRequest(BaseModel):
     character_settings: Optional[List[Dict[str, str]]] = None
 
 
+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: str
+    type: Literal["function"] = "function"
+    function: FunctionResponse
+
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: Optional[int] = 0
@@ -53,6 +102,7 @@ class UsageInfo(BaseModel):
 class ChatMessage(BaseModel):
     role: str
     content: str
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
 
 
 class ChatCompletionResponseChoice(BaseModel):