Skip to content

Commit 300f446

Browse files
committed
fix parser
1 parent 41f1418 commit 300f446

File tree

1 file changed

+146
-30
lines changed

1 file changed

+146
-30
lines changed

fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py

Lines changed: 146 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,18 @@
1414

1515
import json
1616
import re
17+
import uuid
1718
from collections.abc import Sequence
1819
from typing import Union
1920

20-
from fastdeploy.entrypoints.chat_utils import random_tool_call_id
21+
import partial_json_parser
22+
23+
24+
def random_tool_call_id() -> str:
25+
"""Generate a random tool call ID"""
26+
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
27+
28+
2129
from fastdeploy.entrypoints.openai.protocol import (
2230
ChatCompletionRequest,
2331
DeltaFunctionCall,
@@ -53,8 +61,6 @@ def __init__(self, tokenizer):
5361
self.tool_call_start_token: str = "<tool_call>"
5462
self.tool_call_end_token: str = "</tool_call>"
5563

56-
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
57-
5864
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
5965
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
6066
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
@@ -67,9 +73,7 @@ def __init__(self, tokenizer):
6773
"The model tokenizer must be passed to the ToolCallParser constructor during construction."
6874
)
6975

70-
def extract_tool_calls(
71-
self, model_output: str, request: ChatCompletionRequest, model_status: str
72-
) -> ExtractedToolCallInformation:
76+
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
7377
"""
7478
Extract the tool calls from a complete model response.
7579
Supports XML-style formats with newlines:
@@ -81,31 +85,144 @@ def extract_tool_calls(
8185
3. Only name and arguments field without content: {"name": "get_weather", "argume
8286
"""
8387

84-
extract_content = model_output
85-
if model_status == "tool_call_start":
86-
extract_content = "<tool_call>" + model_output
8788
try:
88-
if self.tool_call_start_token not in extract_content:
89-
return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
90-
function_call_tuples = self.tool_call_regex.findall(extract_content)
91-
92-
raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples]
93-
94-
tool_calls = [
95-
ToolCall(
96-
type="function",
97-
function=FunctionCall(
98-
name=function_call["name"],
99-
# function call args are JSON but as a string
100-
arguments=json.dumps(function_call["arguments"], ensure_ascii=False),
101-
),
89+
tool_calls = []
90+
91+
# Check for invalid <response> tags before tool calls
92+
if re.search(r"<response>[\s\S]*?</response>\s*(?=<tool_call>)", model_output):
93+
data_processor_logger.error("Invalid format: <response> tags found before <tool_call>")
94+
return ExtractedToolCallInformation(tools_called=False, content=model_output)
95+
96+
function_call_arr = []
97+
remaining_text = model_output
98+
99+
while True:
100+
# 查找下一个tool_call块
101+
tool_call_pos = remaining_text.find("<tool_call>")
102+
if tool_call_pos == -1:
103+
break
104+
105+
# 提取tool_call开始位置后的内容
106+
tool_content_start = tool_call_pos + len("<tool_call>")
107+
tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
108+
109+
tool_json = ""
110+
if tool_content_end == -1:
111+
# 处理未闭合的tool_call块(截断情况)
112+
tool_json = remaining_text[tool_content_start:].strip()
113+
remaining_text = "" # 没有更多内容需要处理
114+
else:
115+
# 处理完整的tool_call块
116+
tool_json = remaining_text[tool_content_start:tool_content_end].strip()
117+
remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
118+
119+
if not tool_json:
120+
continue
121+
122+
# 处理JSON内容
123+
tool_json = tool_json.strip()
124+
if not tool_json.startswith("{"):
125+
tool_json = "{" + tool_json
126+
if not tool_json.endswith("}"):
127+
tool_json = tool_json + "}"
128+
129+
try:
130+
# 首先尝试标准JSON解析
131+
try:
132+
tool_data = json.loads(tool_json)
133+
134+
if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
135+
function_call_arr.append(
136+
{
137+
"name": tool_data["name"],
138+
"arguments": tool_data["arguments"],
139+
"_is_complete": True, # 明确标记为完整解析
140+
}
141+
)
142+
continue
143+
except json.JSONDecodeError:
144+
pass
145+
146+
# 标准解析失败时尝试partial_json_parser
147+
from partial_json_parser.core.options import Allow
148+
149+
try:
150+
tool_data = {}
151+
flags = Allow.ALL & ~Allow.STR
152+
153+
# 解析name字段
154+
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
155+
if name_match:
156+
tool_data["name"] = name_match.group(1)
157+
158+
# 解析arguments字段
159+
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
160+
if args_match:
161+
try:
162+
tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
163+
except:
164+
tool_data["arguments"] = None
165+
166+
if isinstance(tool_data, dict):
167+
function_call_arr.append(
168+
{
169+
"name": tool_data.get("name", ""),
170+
"arguments": tool_data.get("arguments", {}),
171+
"_is_partial": True, # 标记为部分解析
172+
}
173+
)
174+
except Exception as e:
175+
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
176+
continue
177+
except Exception as e:
178+
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
179+
continue
180+
181+
if not function_call_arr:
182+
data_processor_logger.error("No valid tool calls found")
183+
return ExtractedToolCallInformation(tools_called=False, content=model_output)
184+
185+
tool_calls = []
186+
all_complete = True # 初始设为True,只要有一个不完整就变为False
187+
188+
for tool_call in function_call_arr:
189+
# 记录工具调用解析状态
190+
is_complete = tool_call.get("_is_complete", False)
191+
is_partial = tool_call.get("_is_partial", False)
192+
193+
# 只要有一个不完整就认为整体不完整
194+
if not is_complete or is_partial:
195+
all_complete = False
196+
197+
# 处理参数序列化
198+
tool_args = tool_call.get("arguments", {})
199+
if not isinstance(tool_args, dict):
200+
tool_args = {}
201+
202+
try:
203+
args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
204+
except:
205+
args_str = "{}"
206+
207+
tool_calls.append(
208+
ToolCall(
209+
type="function",
210+
id=random_tool_call_id(),
211+
function=FunctionCall(
212+
name=tool_call.get("name", ""),
213+
arguments=args_str,
214+
),
215+
)
102216
)
103-
for function_call in raw_function_calls
104-
]
105-
return ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content="")
106-
except Exception:
107-
data_processor_logger.error("Error in extracting tool call from response.")
108-
return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
217+
218+
# 只有当所有工具调用都明确标记为complete时才返回tools_called=True
219+
return ExtractedToolCallInformation(
220+
tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
221+
)
222+
223+
except Exception as e:
224+
data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
225+
return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
109226

110227
def extract_tool_calls_streaming(
111228
self,
@@ -116,7 +233,6 @@ def extract_tool_calls_streaming(
116233
current_token_ids: Sequence[int],
117234
delta_token_ids: Sequence[int],
118235
request: dict,
119-
model_status: str,
120236
) -> Union[DeltaMessage, None]:
121237

122238
if self.tool_call_start_token_id not in current_token_ids:

0 commit comments

Comments
 (0)