
Commit 58da761

merge main
2 parents fb3e346 + 37e5071 commit 58da761

File tree: 10 files changed (+973, -155 lines)


lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ def _get_qkv(
     ) -> torch.Tensor:
         input = input.view(-1, self.embed_dim_)
         q = layer_weight.q_proj.mm(input)
-        print(q.shape, infer_state.batch_size)
         cache_kv = layer_weight.kv_proj.mm(
             input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
         ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

lightllm/server/api_cli.py

Lines changed: 7 additions & 0 deletions
@@ -104,6 +104,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--eos_id", nargs="+", type=int, default=None, help="eos stop token id, if None, will load from config.json"
     )
+    parser.add_argument(
+        "--tool_call_parser",
+        type=str,
+        choices=["qwen25", "llama3", "mistral"],
+        default=None,
+        help="tool call parser type",
+    )
     parser.add_argument(
         "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
     )
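
To make the accepted values easy to see, here is a minimal standalone sketch that repeats only the new add_argument call from the hunk above against a bare parser; it is not the real make_argument_parser, and the parsed command line is invented for illustration.

import argparse

# Bare parser carrying only the flag introduced in this commit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--tool_call_parser",
    type=str,
    choices=["qwen25", "llama3", "mistral"],
    default=None,
    help="tool call parser type",
)

args = parser.parse_args(["--tool_call_parser", "qwen25"])
print(args.tool_call_parser)  # "qwen25"; omitting the flag leaves the default of None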

lightllm/server/api_http.py

Lines changed: 8 additions & 139 deletions
@@ -25,7 +25,6 @@
 import os
 from io import BytesIO
 import pickle
-from .build_prompt import build_prompt, init_tokenizer
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 import ujson as json
@@ -44,22 +43,17 @@
 from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
 from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream
 from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
+from lightllm.utils.log_utils import init_logger
+from lightllm.server.metrics.manager import MetricClient
+from lightllm.utils.envs_utils import get_unique_server_name
+from dataclasses import dataclass
 
+from .api_openai import chat_completions_impl
 from .api_models import (
     ChatCompletionRequest,
-    UsageInfo,
-    ChatMessage,
-    ChatCompletionResponseChoice,
     ChatCompletionResponse,
-    DeltaMessage,
-    ChatCompletionStreamResponse,
-    ChatCompletionStreamResponseChoice,
 )
-
-from lightllm.utils.log_utils import init_logger
-from lightllm.server.metrics.manager import MetricClient
-from lightllm.utils.envs_utils import get_unique_server_name
-from dataclasses import dataclass
+from .build_prompt import build_prompt, init_tokenizer
 
 logger = init_logger(__name__)
 
@@ -224,133 +218,8 @@ async def compat_generate(request: Request) -> Response:
 
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response:
-
-    if request.logit_bias is not None:
-        return create_error_response(
-            HTTPStatus.BAD_REQUEST,
-            "The logit_bias parameter is not currently supported",
-        )
-
-    if request.function_call != "none":
-        return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported")
-
-    created_time = int(time.time())
-
-    multimodal_params_dict = {"images": []}
-    for message in request.messages:
-        if isinstance(message.content, list):
-            texts = []
-            for content in message.content:
-                if content.type == "text" and content.text:
-                    texts.append(content.text)
-                elif content.type == "image_url" and content.image_url is not None:
-                    img = content.image_url.url
-                    if img.startswith("http://") or img.startswith("https://"):
-                        multimodal_params_dict["images"].append({"type": "url", "data": img})
-                    elif img.startswith("data:image"):
-                        # "data:image/jpeg;base64,{base64_image}"
-                        data_str = img.split(";", 1)[1]
-                        if data_str.startswith("base64,"):
-                            data = data_str[7:]
-                            multimodal_params_dict["images"].append({"type": "base64", "data": data})
-                        else:
-                            raise ValueError("Unrecognized image input.")
-                else:
-                    raise ValueError(
-                        "Unrecognized image input. Supports local path, http url, base64, and PIL.Image."
-                    )
-
-            message.content = "\n".join(texts)
-
-    prompt = await build_prompt(request)
-    sampling_params_dict = {
-        "do_sample": request.do_sample,
-        "presence_penalty": request.presence_penalty,
-        "frequency_penalty": request.frequency_penalty,
-        "temperature": request.temperature,
-        "top_p": request.top_p,
-        "top_k": request.top_k,
-        "ignore_eos": request.ignore_eos,
-        "max_new_tokens": request.max_tokens,
-        "stop_sequences": request.stop,
-        "n": request.n,
-        "best_of": request.n,
-        "add_special_tokens": False,
-    }
-    sampling_params = SamplingParams()
-    sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict)
-
-    sampling_params.verify()
-    multimodal_params = MultimodalParams(**multimodal_params_dict)
-
-    results_generator = g_objs.httpserver_manager.generate(
-        prompt, sampling_params, multimodal_params, request=raw_request
-    )
-
-    # Non-streaming case
-    if not request.stream:
-        final_output_dict = collections.defaultdict(list)
-        count_output_tokens_dict = collections.defaultdict(lambda: 0)
-        finish_reason_dict = {}
-        prompt_tokens_dict = {}
-        completion_tokens = 0
-        async for sub_req_id, request_output, metadata, finish_status in results_generator:
-            from .req_id_generator import convert_sub_id_to_group_id
-
-            group_request_id = convert_sub_id_to_group_id(sub_req_id)
-            count_output_tokens_dict[sub_req_id] += 1
-            final_output_dict[sub_req_id].append(request_output)
-            if finish_status.is_finished():
-                finish_reason_dict[sub_req_id] = finish_status.get_finish_reason()
-                prompt_tokens_dict[sub_req_id] = metadata["prompt_tokens"]
-        choices = []
-        sub_ids = list(final_output_dict.keys())[: request.n]
-        for i in range(request.n):
-            sub_req_id = sub_ids[i]
-            prompt_tokens = prompt_tokens_dict[sub_req_id]
-            completion_tokens = count_output_tokens_dict[sub_req_id]
-            usage = UsageInfo(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens,
-            )
-            chat_message = ChatMessage(role="assistant", content="".join(final_output_dict[sub_req_id]))
-            choice = ChatCompletionResponseChoice(
-                index=i, message=chat_message, finish_reason=finish_reason_dict[sub_req_id]
-            )
-            choices.append(choice)
-        resp = ChatCompletionResponse(
-            id=group_request_id, created=created_time, model=request.model, choices=choices, usage=usage
-        )
-        return resp
-
-    if sampling_params.n != 1:
-        raise Exception("stream api only support n = 1")
-
-    # Streaming case
-    async def stream_results() -> AsyncGenerator[bytes, None]:
-        finish_reason = None
-        from .req_id_generator import convert_sub_id_to_group_id
-
-        async for sub_req_id, request_output, metadata, finish_status in results_generator:
-            group_request_id = convert_sub_id_to_group_id(sub_req_id)
-
-            delta_message = DeltaMessage(role="assistant", content=request_output)
-            if finish_status.is_finished():
-                finish_reason = finish_status.get_finish_reason()
-            stream_choice = ChatCompletionStreamResponseChoice(
-                index=0, delta=delta_message, finish_reason=finish_reason
-            )
-            stream_resp = ChatCompletionStreamResponse(
-                id=group_request_id,
-                created=created_time,
-                model=request.model,
-                choices=[stream_choice],
-            )
-            yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8")
-
-    background_tasks = BackgroundTasks()
-    return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
+    resp = await chat_completions_impl(request, raw_request)
+    return resp
 
 
 @app.get("/tokens")
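
With this change the endpoint body is reduced to a call into chat_completions_impl from .api_openai. Below is a minimal client-side sketch of exercising the route, assuming a lightllm HTTP server is already listening; the host, port, and model name are placeholders, not values taken from this commit.

import requests

# Placeholder endpoint; point this at wherever the lightllm server is actually running.
url = "http://localhost:8000/v1/chat/completions"

payload = {
    "model": "placeholder-model",  # hypothetical model name
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "stream": False,
}

resp = requests.post(url, json=payload, timeout=60)
print(resp.json())  # expected to match the ChatCompletionResponse schema in api_models.py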

lightllm/server/api_models.py

Lines changed: 51 additions & 0 deletions
@@ -20,6 +20,34 @@ class Message(BaseModel):
     content: Union[str, List[MessageContent]]
 
 
+class Function(BaseModel):
+    """Function descriptions."""
+
+    description: Optional[str] = Field(default=None, examples=[None])
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+class Tool(BaseModel):
+    """Function wrapper."""
+
+    type: str = Field(default="function", examples=["function"])
+    function: Function
+
+
+class ToolChoiceFuncName(BaseModel):
+    """The name of tool choice function."""
+
+    name: Optional[str] = None
+
+
+class ToolChoice(BaseModel):
+    """The tool choice definition."""
+
+    function: ToolChoiceFuncName
+    type: Literal["function"] = Field(default="function", examples=["function"])
+
+
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
@@ -35,6 +63,12 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
 
+    # OpenAI Adaptive parameters for tool call
+    tools: Optional[List[Tool]] = Field(default=None, examples=[None])
+    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
+        default="auto", examples=["none"]
+    )  # noqa
+
     # Additional parameters supported by LightLLM
     do_sample: Optional[bool] = False
     top_k: Optional[int] = -1
@@ -44,6 +78,21 @@
     character_settings: Optional[List[Dict[str, str]]] = None
 
 
+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: str
+    type: Literal["function"] = "function"
+    function: FunctionResponse
+
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: Optional[int] = 0
@@ -53,6 +102,7 @@
 class ChatMessage(BaseModel):
     role: str
     content: str
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
 
 
 class ChatCompletionResponseChoice(BaseModel):
@@ -77,6 +127,7 @@ def ensure_id_is_str(cls, v):
 class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
 
 
 class ChatCompletionStreamResponseChoice(BaseModel):
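
To show how the new request-side models compose, here is a small sketch built only from the fields visible in this diff. The weather tool is invented for illustration, and it assumes Message accepts the usual role/content dicts and that ChatCompletionRequest requires nothing beyond model and messages.

from lightllm.server.api_models import ChatCompletionRequest, Function, Tool

# Hypothetical tool definition, following the Function/Tool models added above.
get_weather = Tool(
    function=Function(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    )
)

req = ChatCompletionRequest(
    model="placeholder-model",  # hypothetical model name
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=[get_weather],
    tool_choice="auto",  # the default; "required", "none", or a ToolChoice object also validate
)
print(req.tools[0].function.name)  # "get_weather"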
