Skip to content

Commit 12a3587

Browse files
authored
[Supplements and upgrades]Improvement of X1 parsers (#4172)
* reasoning_parser * reasoning_parser * reasoning_parser * reasoning_parser * reasoning_parser * reasoning_parser * reasoning_parser
1 parent dd2e844 commit 12a3587

File tree

4 files changed

+360
-119
lines changed

4 files changed

+360
-119
lines changed

fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""
12
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License"
@@ -11,6 +12,7 @@
1112
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213
# See the License for the specific language governing permissions and
1314
# limitations under the License.
15+
"""
1416

1517
import json
1618
import re
@@ -97,37 +99,37 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest)
9799
remaining_text = model_output
98100

99101
while True:
100-
# 查找下一个tool_call块
102+
# Find the next <tool_call>
101103
tool_call_pos = remaining_text.find("<tool_call>")
102104
if tool_call_pos == -1:
103105
break
104106

105-
# 提取tool_call开始位置后的内容
107+
# Extract content after <tool_call>
106108
tool_content_start = tool_call_pos + len("<tool_call>")
107109
tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
108110

109111
tool_json = ""
110112
if tool_content_end == -1:
111-
# 处理未闭合的tool_call块(截断情况)
113+
# Processing unclosed tool_call block (truncated case)
112114
tool_json = remaining_text[tool_content_start:].strip()
113-
remaining_text = "" # 没有更多内容需要处理
115+
remaining_text = "" # No more content to process
114116
else:
115-
# 处理完整的tool_call块
117+
# Processing closed </tool_call> block
116118
tool_json = remaining_text[tool_content_start:tool_content_end].strip()
117119
remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
118120

119121
if not tool_json:
120122
continue
121123

122-
# 处理JSON内容
124+
# Process tool_json
123125
tool_json = tool_json.strip()
124126
if not tool_json.startswith("{"):
125127
tool_json = "{" + tool_json
126128
if not tool_json.endswith("}"):
127129
tool_json = tool_json + "}"
128130

129131
try:
130-
# 首先尝试标准JSON解析
132+
# Parsing strategy: First try standard json.loads
131133
try:
132134
tool_data = json.loads(tool_json)
133135

@@ -136,26 +138,26 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest)
136138
{
137139
"name": tool_data["name"],
138140
"arguments": tool_data["arguments"],
139-
"_is_complete": True, # 明确标记为完整解析
141+
"_is_complete": True, # Mark as complete
140142
}
141143
)
142144
continue
143145
except json.JSONDecodeError:
144146
pass
145147

146-
# 标准解析失败时尝试partial_json_parser
148+
# Try partial_json_parser when standard parsing fails
147149
from partial_json_parser.core.options import Allow
148150

149151
try:
150152
tool_data = {}
151153
flags = Allow.ALL & ~Allow.STR
152154

153-
# 解析name字段
155+
# Parse the name field
154156
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
155157
if name_match:
156158
tool_data["name"] = name_match.group(1)
157159

158-
# 解析arguments字段
160+
# Parse the arguments field
159161
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
160162
if args_match:
161163
try:
@@ -168,7 +170,7 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest)
168170
{
169171
"name": tool_data.get("name", ""),
170172
"arguments": tool_data.get("arguments", {}),
171-
"_is_partial": True, # 标记为部分解析
173+
"_is_partial": True, # Mark as partial
172174
}
173175
)
174176
except Exception as e:
@@ -183,18 +185,18 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest)
183185
return ExtractedToolCallInformation(tools_called=False, content=model_output)
184186

185187
tool_calls = []
186-
all_complete = True # 初始设为True,只要有一个不完整就变为False
188+
all_complete = True # Initialize as all complete
187189

188190
for tool_call in function_call_arr:
189-
# 记录工具调用解析状态
191+
# Set flags
190192
is_complete = tool_call.get("_is_complete", False)
191193
is_partial = tool_call.get("_is_partial", False)
192194

193-
# 只要有一个不完整就认为整体不完整
195+
# If any tool call is incomplete or partial, mark all_complete as False
194196
if not is_complete or is_partial:
195197
all_complete = False
196198

197-
# 处理参数序列化
199+
# Process arguments
198200
tool_args = tool_call.get("arguments", {})
199201
if not isinstance(tool_args, dict):
200202
tool_args = {}
@@ -215,7 +217,7 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest)
215217
)
216218
)
217219

218-
# 只有当所有工具调用都明确标记为complete时才返回tools_called=True
220+
# Only return tools_called=True if all tool calls are complete
219221
return ExtractedToolCallInformation(
220222
tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
221223
)
@@ -237,16 +239,16 @@ def extract_tool_calls_streaming(
237239

238240
if self.tool_call_start_token_id not in current_token_ids:
239241
return DeltaMessage(content=delta_text)
240-
# 忽略空chunk
242+
# Skip empty chunks
241243
if len(delta_text.strip()) == 0:
242244
return None
243245

244246
try:
245247
delta = None
246-
# 使用buffer累积delta_text内容
248+
# Use buffer to accumulate delta_text content
247249
self.buffer += delta_text
248250

249-
# 处理增量中的新tool_call开始
251+
# Process the buffer content
250252
if "<tool_call>" in delta_text:
251253
self.current_tool_id = (
252254
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
@@ -256,7 +258,7 @@ def extract_tool_calls_streaming(
256258
self.streamed_args_for_tool.append("")
257259
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
258260

259-
# 1. 尝试解析name字段
261+
# 1. Try to parse the name field
260262
if not self.current_tool_name_sent and '"name"' in self.buffer:
261263
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
262264
if name_match:
@@ -272,32 +274,31 @@ def extract_tool_calls_streaming(
272274
)
273275
]
274276
)
275-
# 删除已处理的name部分
277+
# Delete the processed name part from the buffer
276278
self.buffer = self.buffer[name_match.end() :]
277279
self.current_tool_name_sent = True
278280
return delta
279-
# 2. 尝试解析arguments字段
281+
# 2. Processing arguments field
280282
if '"arguments"' in self.buffer:
281283
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
282284
if args_match:
283285
args_content = args_match.group(1)
284286
try:
285-
# 检查是否到达arguments结尾(括号完全匹配)
287+
# Check if arguments field is complete by bracket matching
286288
if "}}" in args_content:
287-
# 逐个字符检查括号匹配状态
288289
matched_pos = -1
289290
for i, ch in enumerate(delta_text):
290291
if ch == "{":
291292
self.bracket_counts["total_l"] += 1
292293
elif ch == "}":
293294
self.bracket_counts["total_r"] += 1
294295

295-
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配
296+
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:
296297
matched_pos = i
297298
break
298299

299300
if matched_pos >= 0:
300-
# 找到匹配点,清理buffer并返回
301+
# Clean up bracket counts for next tool call
301302
truncate_text = delta_text[: matched_pos + 1]
302303
delta = DeltaMessage(
303304
tool_calls=[
@@ -312,10 +313,10 @@ def extract_tool_calls_streaming(
312313
self.buffer = self.buffer[args_match.end() :]
313314
return delta
314315
else:
315-
# 没有完全匹配,继续累积
316+
# No complete match yet
316317
return None
317318
else:
318-
# 增量返回当前可解析的部分
319+
# Return partial arguments
319320
for ch in delta_text:
320321
if ch == "{":
321322
self.bracket_counts["total_l"] += 1
@@ -337,7 +338,6 @@ def extract_tool_calls_streaming(
337338
end_pos = self.buffer.find("</tool_call>")
338339
self.buffer = self.buffer[end_pos + len("</tool_call>") :]
339340

340-
# 完成当前工具调用处理
341341
self.streamed_args_for_tool.append("")
342342

343343
return delta
Lines changed: 31 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,3 @@
1-
"""
2-
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3-
#
4-
# Licensed under the Apache License, Version 2.0 (the "License"
5-
# you may not use this file except in compliance with the License.
6-
# You may obtain a copy of the License at
7-
#
8-
# http://www.apache.org/licenses/LICENSE-2.0
9-
#
10-
# Unless required by applicable law or agreed to in writing, software
11-
# distributed under the License is distributed on an "AS IS" BASIS,
12-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
# See the License for the specific language governing permissions and
14-
# limitations under the License.
15-
"""
16-
171
from collections.abc import Sequence
182
from typing import Tuple, Union
193

@@ -26,10 +10,10 @@ class ErnieX1ReasoningParser(ReasoningParser):
2610
"""
2711
Reasoning parser for ernie_x1 model with stricter boundary checking.
2812
29-
This implementation follows the user's proposed approach:
30-
1. For thinking content: waits for \n then checks for </think> tag
31-
2. For response content: checks for <response> tag first, then waits for \n
32-
3. Handles newlines in content more precisely
13+
Unified rules:
14+
- Do not strip newline before </think>
15+
- Do not strip newline after <response>
16+
- Do not strip newline before </response>
3317
"""
3418

3519
def __init__(self, tokenizer):
@@ -48,9 +32,6 @@ def __init__(self, tokenizer):
4832
raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
4933
self.tool_call_start_token_id = self.vocab.get("<tool_call>")
5034

51-
def is_reasoning_end(self, input_ids: list[int]) -> bool:
52-
return self.tool_call_start_token_id in input_ids
53-
5435
def extract_reasoning_content_streaming(
5536
self,
5637
previous_text: str,
@@ -60,102 +41,63 @@ def extract_reasoning_content_streaming(
6041
current_token_ids: Sequence[int],
6142
delta_token_ids: Sequence[int],
6243
) -> Union[DeltaMessage, None]:
63-
"""
64-
根据用户需求实现的流式解析方法:
65-
1. 初始内容都视为思考内容,返回delta_text,""
66-
2. 当遇到\n时检查后续是否是</think>
67-
3. 如果直接遇到</think>也结束思考
68-
4. 思考结束后检查是<response>还是<tool_call>
69-
5. 对于<response>内容,处理各种边界条件
70-
"""
44+
# Ignore the single </think> token
7145
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
7246
return None
73-
# 思考阶段处理
47+
48+
# --- Thinking stage handling ---
7449
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
75-
# 如果遇到\n,暂时不返回,等待下一个delta_text
76-
if delta_text == "\n":
77-
return None
78-
# 如果前一个是\n且当前是</think>,结束思考
79-
elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
80-
return None
81-
# 如果直接遇到</think>也结束思考
82-
elif delta_text.startswith(self.think_end_token):
50+
# If delta is </think>, stop thinking, do not return
51+
if delta_text.startswith(self.think_end_token):
8352
return None
84-
# 否则继续返回思考内容
53+
# Otherwise, return thinking content (keep \n as-is)
8554
return DeltaMessage(reasoning_content=delta_text)
8655

87-
# 思考结束后检查是tool_call还是response
56+
# --- After thinking ends, check tool_call or response ---
8857
remaining_text = previous_text + delta_text
8958
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
90-
after_think = after_think.lstrip("\n") # 跳过think后的换行
59+
after_think = after_think.lstrip("\n")
9160

92-
# 处理tool_call情况
61+
# Handle tool_call case: skip it
9362
if after_think.startswith(self.tool_call_start_token):
9463
return None
9564

96-
# 处理response情况
97-
if after_think.startswith(self.response_start_token):
98-
# 遇到<response>标签时不立即返回
99-
if delta_text == self.response_start_token:
65+
# Handle response case
66+
if after_think.startswith(self.response_start_token) and self.response_end_token not in after_think:
67+
# Do not return when <response> tag itself appears
68+
if delta_text == self.response_start_token or delta_text == self.response_end_token:
10069
return None
101-
# 遇到<response>后的换行符也不立即返回
102-
elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
103-
return None
104-
# 处理回复内容中的换行符
105-
if delta_text == "\n":
106-
return None
107-
# 如果前一个是\n且当前是</response>,结束回复
108-
elif previous_text.endswith("\n") and delta_text == self.response_end_token:
109-
return None
110-
# 如果直接遇到</response>也结束回复
111-
elif delta_text == self.response_end_token:
112-
return None
113-
# 其他情况返回实际内容
114-
else:
115-
return DeltaMessage(content=delta_text)
70+
return DeltaMessage(content=delta_text)
11671

117-
# 默认情况不返回内容
72+
# Default case: return nothing
11873
return None
11974

12075
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
121-
"""
122-
Batch version of the enhanced parser.
123-
Modified to preserve newlines in both reasoning and response content,
124-
only removing the single newline before closing tags.
125-
"""
12676
reasoning_content = ""
12777
response_content = ""
12878

12979
think_end_pos = model_output.find(self.think_end_token)
13080
if think_end_pos != -1:
131-
# Extract thinking content - only remove the last newline before </think>
13281
reasoning_content = model_output[:think_end_pos]
133-
if think_end_pos > 0 and reasoning_content[-1] == "\n":
134-
reasoning_content = reasoning_content[:-1]
13582

13683
remaining = model_output[think_end_pos + len(self.think_end_token) :]
13784

138-
# Skip newlines after </think>
139-
remaining = remaining.lstrip("\n")
85+
# find <response> or <tool>
86+
response_pos = remaining.find(self.response_start_token)
87+
tool_pos = remaining.find(self.tool_call_start_token)
14088

141-
# Check for response or tool_call
142-
if remaining.startswith(self.response_start_token):
143-
response_pos = len(self.response_start_token)
144-
remaining = remaining[response_pos:].lstrip("\n")
145-
response_end_pos = remaining.find(self.response_end_token)
89+
# <response> first
90+
if response_pos != -1 and (tool_pos == -1 or response_pos < tool_pos):
91+
# The content after the response_start position
92+
remaining_response = remaining[response_pos + len(self.response_start_token) :]
93+
response_end_pos = remaining_response.find(self.response_end_token)
14694
if response_end_pos != -1:
147-
# Only strip the last newline before </response>, not all
148-
if response_end_pos > 0 and remaining[response_end_pos - 1] == "\n":
149-
response_content = remaining[: response_end_pos - 1]
150-
else:
151-
response_content = remaining[:response_end_pos]
95+
response_content = remaining_response[:response_end_pos]
15296
else:
153-
# If no </response> found, return the rest as response content
154-
response_content = remaining
155-
elif remaining.startswith(self.tool_call_start_token):
156-
pass # No response content
97+
response_content = remaining_response
98+
# The content after the response_start position is tool_call
15799
else:
158-
# No thinking content found, return the whole input as reasoning
159100
reasoning_content = model_output
160101
response_content = ""
102+
161103
return reasoning_content, response_content

0 commit comments

Comments
 (0)