Skip to content

Commit 671a4dc

Browse files
committed
add x1 parser
1 parent 234ef92 commit 671a4dc

File tree

1 file changed

+94
-41
lines changed

1 file changed

+94
-41
lines changed

fastdeploy/reasoning/ernie_x1_reasoning_parsers.py

Lines changed: 94 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,62 @@ class ErnieX1ReasoningParser(ReasoningParser):
3434

3535
def __init__(self, tokenizer):
    """Initialize the ERNIE X1 reasoning parser.

    Registers the special tokens that delimit thinking, response and
    tool-call sections, resolves their vocabulary ids, and builds a
    token-id -> status-name mapping used by ``get_model_status``.

    Args:
        tokenizer: The model tokenizer (forwarded to ``ReasoningParser``).

    Raises:
        ValueError: If no tokenizer was provided.
        RuntimeError: If any required special token is absent from the
            tokenizer vocabulary.
    """
    super().__init__(tokenizer)

    # All special tokens this parser must be able to recognize.
    token_definitions = {
        "think_start_token": "<think>",
        "think_end_token": "</think>",
        "response_start_token": "<response>",
        "response_end_token": "</response>",
        "tool_call_start_token": "<tool_call>",
        "tool_call_end_token": "</tool_call>",
    }

    if not self.model_tokenizer:
        raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")

    missing_tokens = []
    for name, token_value in token_definitions.items():
        # Expose both the token string (self.<name>) and its id (self.<name>_id).
        setattr(self, name, token_value)
        token_id = self.vocab.get(token_value)
        setattr(self, f"{name}_id", token_id)
        if token_id is None:
            # BUG FIX: names already end in "_token", so appending " token"
            # produced e.g. "think start token token". Just humanize the name.
            missing_tokens.append(name.replace("_", " "))

    if missing_tokens:
        raise RuntimeError(
            f"Could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
        )

    # Maps the id of the last-seen special token to the parser status name.
    self.token_status_mapping = {
        self.think_start_token_id: "think_start",
        self.think_end_token_id: "think_end",
        self.response_start_token_id: "response_start",
        self.response_end_token_id: "response_end",
        self.tool_call_start_token_id: "tool_call_start",
        self.tool_call_end_token_id: "tool_call_end",
    }
72+
73+
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """Return the id of the last special token in ``prompt_token_ids``.

    Scans the prompt from the end; returns -1 when no special token occurs.

    Args:
        prompt_token_ids: Token ids of the prompt.

    Returns:
        The last special token id found, or -1 if none is present.
    """
    # Hoist the candidate set out of the loop (the original rebuilt a
    # six-element list on every iteration) and iterate in reverse directly.
    special_ids = (
        self.think_end_token_id,
        self.think_start_token_id,
        self.response_start_token_id,
        self.response_end_token_id,
        self.tool_call_start_token_id,
        self.tool_call_end_token_id,
    )
    for token_id in reversed(prompt_token_ids):
        if token_id in special_ids:
            return token_id
    return -1
85+
86+
def get_model_status(self, prompt_token_ids: list[int]) -> str:
    """Infer the parser status from the last special token in the prompt.

    Args:
        prompt_token_ids: Token ids of the prompt.

    Returns:
        A status name such as "think_start" or "tool_call_end"; defaults to
        "response_start" when no special token is present.
    """
    last_special = self.find_last_special_token(prompt_token_ids)
    # The explicit -1 check in the original was redundant: -1 is never a
    # key of token_status_mapping, so .get() already yields the default.
    return self.token_status_mapping.get(last_special, "response_start")
5093

5194
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Return True once the generated ids contain the <tool_call> start token."""
    tool_call_id = self.tool_call_start_token_id
    return any(token == tool_call_id for token in input_ids)
@@ -117,45 +160,55 @@ def extract_reasoning_content_streaming(
117160
# 默认情况不返回内容
118161
return None
119162

120-
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
163+
def strip_last_newline(self, content: str, end_pos: int) -> str:
    """Return ``content[:end_pos]``, dropping a single newline sitting
    immediately before ``end_pos`` (i.e. just before a closing tag)."""
    if end_pos > 0 and content[end_pos - 1] == "\n":
        return content[: end_pos - 1]
    return content[:end_pos]
165+
166+
def extract_reasoning_content(
    self, model_output: str, request: ChatCompletionRequest, model_status: str
) -> Tuple[str, str]:
    """
    Batch (non-streaming) extraction of reasoning and response content.

    Splits ``model_output`` into ``(reasoning_content, response_content)``
    based on ``model_status``, which records where generation started
    relative to the special tokens (see ``get_model_status``). Newlines
    inside the content are preserved; only the single newline directly
    before a closing tag is removed.
    """
    reasoning_content = ""
    response_content = ""

    if model_status == "think_start":
        # Generation began inside <think>: everything before </think> is reasoning.
        think_end_pos = model_output.find(self.think_end_token)
        if think_end_pos != -1:
            # Extract reasoning content (dropping one trailing newline).
            reasoning_content = self.strip_last_newline(model_output, think_end_pos)
            remaining = model_output[think_end_pos + len(self.think_end_token) :].lstrip("\n")

            # After </think> the model emits either a <response> or a <tool_call>.
            if remaining.startswith(self.response_start_token):
                response_start_pos = len(self.response_start_token)
                response_content = self._extract_response_content(remaining[response_start_pos:])
            elif remaining.startswith(self.tool_call_start_token):
                pass  # Tool calls carry no response content.
            # NOTE(review): text after </think> that starts with neither tag
            # is silently discarded — confirm this is intended.
        else:
            # No think_end_token found, treat entire output as reasoning content
            reasoning_content = model_output

    elif model_status == "think_end":
        # Generation began right after </think>: only a <response> may follow.
        remaining = model_output.lstrip("\n")
        if remaining.startswith(self.response_start_token):
            response_start_pos = len(self.response_start_token)
            response_content = self._extract_response_content(remaining[response_start_pos:])

    elif model_status == "response_start":
        # Generation began inside <response>: keep everything, removing every
        # </response> tag. NOTE(review): unlike the branches above this strips
        # ALL occurrences and does not trim the preceding newline — verify the
        # asymmetry is deliberate.
        response_content = model_output.replace(self.response_end_token, "")

    return reasoning_content, response_content
205+
206+
def _extract_response_content(self, remaining: str) -> str:
207+
"""
208+
Extracts response content, ensuring that the last newline before
209+
the </response> tag is removed.
210+
"""
211+
response_end_pos = remaining.find(self.response_end_token)
212+
if response_end_pos != -1:
213+
return self.strip_last_newline(remaining, response_end_pos)
214+
return remaining

0 commit comments

Comments
 (0)