@@ -34,19 +34,62 @@ class ErnieX1ReasoningParser(ReasoningParser):
34
34
35
35
def __init__(self, tokenizer):
    """Initialize the parser and resolve every special-token id from the vocabulary.

    Args:
        tokenizer: Tokenizer forwarded to the base ``ReasoningParser``; its
            vocabulary must contain all <think>/<response>/<tool_call> markers.

    Raises:
        ValueError: If no model tokenizer is available after base init.
        RuntimeError: If any special token is missing from the vocabulary;
            all missing tokens are reported in a single message.
    """
    super().__init__(tokenizer)

    # All structural markers this parser recognizes, keyed by attribute name.
    token_definitions = {
        "think_start_token": "<think>",
        "think_end_token": "</think>",
        "response_start_token": "<response>",
        "response_end_token": "</response>",
        "tool_call_start_token": "<tool_call>",
        "tool_call_end_token": "</tool_call>",
    }

    if not self.model_tokenizer:
        raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")

    missing_tokens = []
    for name, token_value in token_definitions.items():
        setattr(self, name, token_value)
        token_id = self.vocab.get(token_value)
        setattr(self, f"{name}_id", token_id)
        if token_id is None:
            # BUG FIX: the previous message appended " token" after a name that
            # already ends in "token", yielding e.g. "think start token token".
            missing_tokens.append(name.replace("_", " "))

    if missing_tokens:
        raise RuntimeError(
            f"Could not find the following token ids in tokenizer vocabulary: {', '.join(missing_tokens)}"
        )

    # Map each special-token id to the generation status it implies when it is
    # the last special token seen in the prompt.
    self.token_status_mapping = {
        self.think_start_token_id: "think_start",
        self.think_end_token_id: "think_end",
        self.response_start_token_id: "response_start",
        self.response_end_token_id: "response_end",
        self.tool_call_start_token_id: "tool_call_start",
        self.tool_call_end_token_id: "tool_call_end",
    }
72
+
73
def find_last_special_token(self, prompt_token_ids: list[int]) -> int:
    """Return the last special-token id occurring in ``prompt_token_ids``.

    Scans from the end so the most recent structural marker wins.

    Args:
        prompt_token_ids: Token ids of the prompt, in order.

    Returns:
        The id of the last special token found, or -1 if none is present.
    """
    # PERF: hoist the membership collection out of the loop (the original
    # rebuilt a six-element list on every iteration) and use a set for O(1)
    # lookups.
    special_ids = {
        self.think_end_token_id,
        self.think_start_token_id,
        self.response_start_token_id,
        self.response_end_token_id,
        self.tool_call_start_token_id,
        self.tool_call_end_token_id,
    }
    for token_id in reversed(prompt_token_ids):
        if token_id in special_ids:
            return token_id
    return -1
85
+
86
def get_model_status(self, prompt_token_ids: list[int]):
    """Classify the current generation phase from the prompt's trailing marker.

    Looks up the last special token in the prompt and maps it to a status
    string; defaults to "response_start" when no special token is present
    (or when the token has no mapping entry).
    """
    marker = self.find_last_special_token(prompt_token_ids)
    return "response_start" if marker == -1 else self.token_status_mapping.get(marker, "response_start")
50
93
51
94
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Reasoning is over once a <tool_call> start token has been emitted."""
    return any(token_id == self.tool_call_start_token_id for token_id in input_ids)
@@ -117,45 +160,55 @@ def extract_reasoning_content_streaming(
117
160
# Default case: return no content
118
161
return None
119
162
120
- def extract_reasoning_content (self , model_output : str , request : ChatCompletionRequest ) -> Tuple [str , str ]:
163
def strip_last_newline(self, content: str, end_pos: int) -> str:
    """Return ``content[:end_pos]``, dropping one "\\n" directly before end_pos.

    Only a single trailing newline is removed; earlier newlines are kept.
    """
    if end_pos > 0 and content[end_pos - 1] == "\n":
        return content[: end_pos - 1]
    return content[:end_pos]
165
+
166
def extract_reasoning_content(
    self, model_output: str, request: "ChatCompletionRequest", model_status: str
) -> Tuple[str, str]:
    """Split a complete model output into (reasoning_content, response_content).

    Newlines inside either section are preserved; only the single newline
    directly before a closing tag is stripped. ``model_status`` (derived from
    the prompt's last special token) selects how the output is interpreted.
    """
    reasoning_content = ""
    response_content = ""
    response_prefix_len = len(self.response_start_token)

    if model_status == "think_start":
        think_end_pos = model_output.find(self.think_end_token)
        if think_end_pos == -1:
            # No </think> at all: treat the entire output as reasoning.
            reasoning_content = model_output
        else:
            reasoning_content = self.strip_last_newline(model_output, think_end_pos)
            tail = model_output[think_end_pos + len(self.think_end_token) :].lstrip("\n")
            if tail.startswith(self.response_start_token):
                response_content = self._extract_response_content(tail[response_prefix_len:])
            # A <tool_call> tail (or anything else) produces no response content.
    elif model_status == "think_end":
        tail = model_output.lstrip("\n")
        if tail.startswith(self.response_start_token):
            response_content = self._extract_response_content(tail[response_prefix_len:])
    elif model_status == "response_start":
        # Already inside a response: keep everything except closing tags.
        response_content = model_output.replace(self.response_end_token, "")

    return reasoning_content, response_content
205
+
206
+ def _extract_response_content (self , remaining : str ) -> str :
207
+ """
208
+ Extracts response content, ensuring that the last newline before
209
+ the </response> tag is removed.
210
+ """
211
+ response_end_pos = remaining .find (self .response_end_token )
212
+ if response_end_pos != - 1 :
213
+ return self .strip_last_newline (remaining , response_end_pos )
214
+ return remaining
0 commit comments