14
14
15
15
import json
16
16
import re
17
- import uuid
18
17
from collections .abc import Sequence
19
18
from typing import Union
20
19
21
- import partial_json_parser
22
-
23
-
24
- def random_tool_call_id () -> str :
25
- """Generate a random tool call ID"""
26
- return f"chatcmpl-tool-{ str (uuid .uuid4 ().hex )} "
27
-
28
-
20
+ from fastdeploy .entrypoints .chat_utils import random_tool_call_id
29
21
from fastdeploy .entrypoints .openai .protocol import (
30
22
ChatCompletionRequest ,
31
23
DeltaFunctionCall ,
@@ -61,6 +53,8 @@ def __init__(self, tokenizer):
61
53
self .tool_call_start_token : str = "<tool_call>"
62
54
self .tool_call_end_token : str = "</tool_call>"
63
55
56
+ self .tool_call_regex = re .compile (r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)" , re .DOTALL )
57
+
64
58
self .tool_call_start_token_id = self .vocab .get (self .tool_call_start_token )
65
59
self .tool_call_end_token_id = self .vocab .get (self .tool_call_end_token )
66
60
if self .tool_call_start_token_id is None or self .tool_call_end_token_id is None :
@@ -73,7 +67,9 @@ def __init__(self, tokenizer):
73
67
"The model tokenizer must be passed to the ToolCallParser constructor during construction."
74
68
)
75
69
76
- def extract_tool_calls (self , model_output : str , request : ChatCompletionRequest ) -> ExtractedToolCallInformation :
70
+ def extract_tool_calls (
71
+ self , model_output : str , request : ChatCompletionRequest , model_status : str
72
+ ) -> ExtractedToolCallInformation :
77
73
"""
78
74
Extract the tool calls from a complete model response.
79
75
Supports XML-style formats with newlines:
@@ -85,144 +81,31 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest)
85
81
3. Only name and arguments field without content: {"name": "get_weather", "argume
86
82
"""
87
83
84
+ extract_content = model_output
85
+ if model_status == "tool_call_start" :
86
+ extract_content = "<tool_call>" + model_output
88
87
try :
89
- tool_calls = []
90
-
91
- # Check for invalid <response> tags before tool calls
92
- if re .search (r"<response>[\s\S]*?</response>\s*(?=<tool_call>)" , model_output ):
93
- data_processor_logger .error ("Invalid format: <response> tags found before <tool_call>" )
94
- return ExtractedToolCallInformation (tools_called = False , content = model_output )
95
-
96
- function_call_arr = []
97
- remaining_text = model_output
98
-
99
- while True :
100
- # 查找下一个tool_call块
101
- tool_call_pos = remaining_text .find ("<tool_call>" )
102
- if tool_call_pos == - 1 :
103
- break
104
-
105
- # 提取tool_call开始位置后的内容
106
- tool_content_start = tool_call_pos + len ("<tool_call>" )
107
- tool_content_end = remaining_text .find ("</tool_call>" , tool_content_start )
108
-
109
- tool_json = ""
110
- if tool_content_end == - 1 :
111
- # 处理未闭合的tool_call块(截断情况)
112
- tool_json = remaining_text [tool_content_start :].strip ()
113
- remaining_text = "" # 没有更多内容需要处理
114
- else :
115
- # 处理完整的tool_call块
116
- tool_json = remaining_text [tool_content_start :tool_content_end ].strip ()
117
- remaining_text = remaining_text [tool_content_end + len ("</tool_call>" ) :]
118
-
119
- if not tool_json :
120
- continue
121
-
122
- # 处理JSON内容
123
- tool_json = tool_json .strip ()
124
- if not tool_json .startswith ("{" ):
125
- tool_json = "{" + tool_json
126
- if not tool_json .endswith ("}" ):
127
- tool_json = tool_json + "}"
128
-
129
- try :
130
- # 首先尝试标准JSON解析
131
- try :
132
- tool_data = json .loads (tool_json )
133
-
134
- if isinstance (tool_data , dict ) and "name" in tool_data and "arguments" in tool_data :
135
- function_call_arr .append (
136
- {
137
- "name" : tool_data ["name" ],
138
- "arguments" : tool_data ["arguments" ],
139
- "_is_complete" : True , # 明确标记为完整解析
140
- }
141
- )
142
- continue
143
- except json .JSONDecodeError :
144
- pass
145
-
146
- # 标准解析失败时尝试partial_json_parser
147
- from partial_json_parser .core .options import Allow
148
-
149
- try :
150
- tool_data = {}
151
- flags = Allow .ALL & ~ Allow .STR
152
-
153
- # 解析name字段
154
- name_match = re .search (r'"name"\s*:\s*"([^"]*)"' , tool_json )
155
- if name_match :
156
- tool_data ["name" ] = name_match .group (1 )
157
-
158
- # 解析arguments字段
159
- args_match = re .search (r'"arguments"\s*:\s*(\{.*)' , tool_json )
160
- if args_match :
161
- try :
162
- tool_data ["arguments" ] = partial_json_parser .loads (args_match .group (1 ), flags = flags )
163
- except :
164
- tool_data ["arguments" ] = None
165
-
166
- if isinstance (tool_data , dict ):
167
- function_call_arr .append (
168
- {
169
- "name" : tool_data .get ("name" , "" ),
170
- "arguments" : tool_data .get ("arguments" , {}),
171
- "_is_partial" : True , # 标记为部分解析
172
- }
173
- )
174
- except Exception as e :
175
- data_processor_logger .debug (f"Failed to parse tool call: { str (e )} " )
176
- continue
177
- except Exception as e :
178
- data_processor_logger .debug (f"Failed to parse tool call: { str (e )} " )
179
- continue
180
-
181
- if not function_call_arr :
182
- data_processor_logger .error ("No valid tool calls found" )
183
- return ExtractedToolCallInformation (tools_called = False , content = model_output )
184
-
185
- tool_calls = []
186
- all_complete = True # 初始设为True,只要有一个不完整就变为False
187
-
188
- for tool_call in function_call_arr :
189
- # 记录工具调用解析状态
190
- is_complete = tool_call .get ("_is_complete" , False )
191
- is_partial = tool_call .get ("_is_partial" , False )
192
-
193
- # 只要有一个不完整就认为整体不完整
194
- if not is_complete or is_partial :
195
- all_complete = False
196
-
197
- # 处理参数序列化
198
- tool_args = tool_call .get ("arguments" , {})
199
- if not isinstance (tool_args , dict ):
200
- tool_args = {}
201
-
202
- try :
203
- args_str = json .dumps (tool_args , ensure_ascii = False ) if tool_args else "{}"
204
- except :
205
- args_str = "{}"
206
-
207
- tool_calls .append (
208
- ToolCall (
209
- type = "function" ,
210
- id = random_tool_call_id (),
211
- function = FunctionCall (
212
- name = tool_call .get ("name" , "" ),
213
- arguments = args_str ,
214
- ),
215
- )
88
+ if self .tool_call_start_token not in extract_content :
89
+ return ExtractedToolCallInformation (tools_called = False , tool_calls = [], content = model_output )
90
+ function_call_tuples = self .tool_call_regex .findall (extract_content )
91
+
92
+ raw_function_calls = [json .loads (match [0 ] if match [0 ] else match [1 ]) for match in function_call_tuples ]
93
+
94
+ tool_calls = [
95
+ ToolCall (
96
+ type = "function" ,
97
+ function = FunctionCall (
98
+ name = function_call ["name" ],
99
+ # function call args are JSON but as a string
100
+ arguments = json .dumps (function_call ["arguments" ], ensure_ascii = False ),
101
+ ),
216
102
)
217
-
218
- # 只有当所有工具调用都明确标记为complete时才返回tools_called=True
219
- return ExtractedToolCallInformation (
220
- tools_called = all_complete , tool_calls = tool_calls if tool_calls else None , content = ""
221
- )
222
-
223
- except Exception as e :
224
- data_processor_logger .error (f"Error in extracting tool call from response: { str (e )} " )
225
- return ExtractedToolCallInformation (tools_called = False , tool_calls = None , content = model_output )
103
+ for function_call in raw_function_calls
104
+ ]
105
+ return ExtractedToolCallInformation (tools_called = True , tool_calls = tool_calls , content = "" )
106
+ except Exception :
107
+ data_processor_logger .error ("Error in extracting tool call from response." )
108
+ return ExtractedToolCallInformation (tools_called = False , tool_calls = [], content = model_output )
226
109
227
110
def extract_tool_calls_streaming (
228
111
self ,
@@ -233,6 +116,7 @@ def extract_tool_calls_streaming(
233
116
current_token_ids : Sequence [int ],
234
117
delta_token_ids : Sequence [int ],
235
118
request : dict ,
119
+ model_status : str ,
236
120
) -> Union [DeltaMessage , None ]:
237
121
238
122
if self .tool_call_start_token_id not in current_token_ids :
0 commit comments