Commit 59931b0

function call in reasoning content (#1158)
1 parent e10e433 commit 59931b0

File tree

2 files changed: +44 -43 lines changed


lightllm/server/api_models.py

Lines changed: 41 additions & 41 deletions
@@ -68,6 +68,46 @@ class ResponseFormat(BaseModel):
     json_schema: Optional[JsonSchemaResponseFormat] = None
 
 
+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: Optional[str] = None
+    index: Optional[int] = None
+    type: Literal["function"] = "function"
+    function: FunctionResponse
+
+
+class ChatCompletionMessageGenericParam(BaseModel):
+    role: Literal["system", "assistant", "tool", "function"]
+    content: Union[str, List[MessageContent], None] = Field(default=None)
+    tool_call_id: Optional[str] = None
+    name: Optional[str] = None
+    reasoning_content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+
+    @field_validator("role", mode="before")
+    @classmethod
+    def _normalize_role(cls, v):
+        if isinstance(v, str):
+            v_lower = v.lower()
+            if v_lower not in {"system", "assistant", "tool", "function"}:
+                raise ValueError(
+                    "'role' must be one of 'system', 'assistant', 'tool', or 'function' (case-insensitive)."
+                )
+            return v_lower
+        raise ValueError("'role' must be a string")
+
+
+ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, Message]
+
+
 class CompletionRequest(BaseModel):
     model: str
     # prompt: string or tokens
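
Not part of the diff, just a usage note: because the role validator runs with mode="before", the incoming role string is lower-cased before the Literal check, so mixed-case roles are accepted and normalized. A minimal sketch, assuming the models are importable from lightllm.server.api_models:

from lightllm.server.api_models import ChatCompletionMessageGenericParam

# "Assistant" is lower-cased by _normalize_role before the Literal check,
# so it validates and is stored as "assistant"; roles outside the allowed
# set (e.g. "user") fail validation for this model.
msg = ChatCompletionMessageGenericParam(role="Assistant", content="done")
print(msg.role)  # -> "assistant"
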
@@ -137,7 +177,7 @@ def apply_loaded_defaults(cls, data: Any):
 
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Message]
+    messages: List[ChatCompletionMessageParam]
     function_call: Optional[str] = "none"
     temperature: Optional[float] = 1
     top_p: Optional[float] = 1.0
@@ -212,46 +252,6 @@ def apply_loaded_defaults(cls, data: Any):
         return data
 
 
-class FunctionResponse(BaseModel):
-    """Function response."""
-
-    name: Optional[str] = None
-    arguments: Optional[str] = None
-
-
-class ToolCall(BaseModel):
-    """Tool call response."""
-
-    id: Optional[str] = None
-    index: Optional[int] = None
-    type: Literal["function"] = "function"
-    function: FunctionResponse
-
-
-class ChatCompletionMessageGenericParam(BaseModel):
-    role: Literal["system", "assistant", "tool", "function"]
-    content: Union[str, List[MessageContent], None] = Field(default=None)
-    tool_call_id: Optional[str] = None
-    name: Optional[str] = None
-    reasoning_content: Optional[str] = None
-    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
-
-    @field_validator("role", mode="before")
-    @classmethod
-    def _normalize_role(cls, v):
-        if isinstance(v, str):
-            v_lower = v.lower()
-            if v_lower not in {"system", "assistant", "tool", "function"}:
-                raise ValueError(
-                    "'role' must be one of 'system', 'assistant', 'tool', or 'function' (case-insensitive)."
-                )
-            return v_lower
-        raise ValueError("'role' must be a string")
-
-
-ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, Message]
-
-
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: Optional[int] = 0
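
The net effect of the api_models.py change is that ChatCompletionRequest.messages now accepts the new union type, so an assistant turn in the history can carry reasoning_content and tool_calls, followed by a matching tool message. A hedged sketch, assuming pydantic v2 (TypeAdapter) and purely illustrative tool values:

from pydantic import TypeAdapter

from lightllm.server.api_models import ChatCompletionMessageParam

# Validate single history entries against the union type now used by
# ChatCompletionRequest.messages; the old List[Message] annotation could not
# represent reasoning_content or tool_calls on an assistant turn.
adapter = TypeAdapter(ChatCompletionMessageParam)

assistant_turn = adapter.validate_python(
    {
        "role": "assistant",
        "content": None,
        "reasoning_content": "I should call the weather tool first.",
        "tool_calls": [
            {
                "id": "call_0",
                "index": 0,
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    }
)

tool_turn = adapter.validate_python(
    {"role": "tool", "tool_call_id": "call_0", "content": "18°C and clear"}
)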

lightllm/server/api_openai.py

Lines changed: 3 additions & 2 deletions
@@ -260,6 +260,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
 
         finish_reason = finish_reason_dict[sub_req_id]
         text = "".join(final_output_dict[sub_req_id])
+        full_text = text
 
         # Handle reasoning content
         reasoning_text = None
@@ -284,14 +285,14 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
         tool_calls = None
         tool_choice = request.tool_choice
         tools = request.tools
-        if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
+        if tool_choice != "none" and any([i in full_text for i in TOOLS_TAG_LIST]):
            if finish_reason == "stop":
                finish_reason = "tool_calls"
            try:
                # Provide a default value for tool_call_parser
                tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
                parser = FunctionCallParser(tools, tool_parser)
-                full_normal_text, call_info_list = parser.parse_non_stream(text)
+                full_normal_text, call_info_list = parser.parse_non_stream(full_text)
                tool_calls = []
                history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
                for call_info in call_info_list:
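
What the api_openai.py change addresses: the generated string is snapshotted as full_text before the reasoning content is separated out, so tool-call tags that the model emitted inside its reasoning section still reach the TOOLS_TAG_LIST scan and parser.parse_non_stream. A toy illustration of the ordering problem, with made-up delimiters (not lightllm's actual reasoning or tool parser):

# Toy sketch: assumed <think>/<tool_call> delimiters, purely illustrative.
TOOLS_TAG_LIST = ["<tool_call>"]  # stand-in for the real tag list

generated = '<think>Need the weather tool: <tool_call>{"name": "get_weather"}</tool_call></think>Sunny.'

full_text = generated                                       # snapshot taken before any stripping (the fix)
reasoning_text, _, text = generated.partition("</think>")   # naive reasoning/content split

# Scanning only the stripped text misses the call; the preserved full text finds it.
print(any(tag in text for tag in TOOLS_TAG_LIST))       # False
print(any(tag in full_text for tag in TOOLS_TAG_LIST))  # True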
