14
14
15
15
import json
16
16
import re
17
+ import uuid
17
18
from collections .abc import Sequence
18
19
from typing import Union
19
20
20
- from fastdeploy .entrypoints .chat_utils import random_tool_call_id
21
+ import partial_json_parser
22
+
23
+
24
+ def random_tool_call_id () -> str :
25
+ """Generate a random tool call ID"""
26
+ return f"chatcmpl-tool-{ str (uuid .uuid4 ().hex )} "
27
+
28
+
21
29
from fastdeploy .entrypoints .openai .protocol import (
22
30
ChatCompletionRequest ,
23
31
DeltaFunctionCall ,
@@ -53,8 +61,6 @@ def __init__(self, tokenizer):
53
61
self .tool_call_start_token : str = "<tool_call>"
54
62
self .tool_call_end_token : str = "</tool_call>"
55
63
56
- self .tool_call_regex = re .compile (r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)" , re .DOTALL )
57
-
58
64
self .tool_call_start_token_id = self .vocab .get (self .tool_call_start_token )
59
65
self .tool_call_end_token_id = self .vocab .get (self .tool_call_end_token )
60
66
if self .tool_call_start_token_id is None or self .tool_call_end_token_id is None :
@@ -67,9 +73,7 @@ def __init__(self, tokenizer):
67
73
"The model tokenizer must be passed to the ToolCallParser constructor during construction."
68
74
)
69
75
70
- def extract_tool_calls (
71
- self , model_output : str , request : ChatCompletionRequest , model_status : str
72
- ) -> ExtractedToolCallInformation :
76
+ def extract_tool_calls (self , model_output : str , request : ChatCompletionRequest ) -> ExtractedToolCallInformation :
73
77
"""
74
78
Extract the tool calls from a complete model response.
75
79
Supports XML-style formats with newlines:
@@ -81,31 +85,144 @@ def extract_tool_calls(
81
85
3. Only name and arguments field without content: {"name": "get_weather", "argume
82
86
"""
83
87
84
- extract_content = model_output
85
- if model_status == "tool_call_start" :
86
- extract_content = "<tool_call>" + model_output
87
88
try :
88
- if self .tool_call_start_token not in extract_content :
89
- return ExtractedToolCallInformation (tools_called = False , tool_calls = [], content = model_output )
90
- function_call_tuples = self .tool_call_regex .findall (extract_content )
91
-
92
- raw_function_calls = [json .loads (match [0 ] if match [0 ] else match [1 ]) for match in function_call_tuples ]
93
-
94
- tool_calls = [
95
- ToolCall (
96
- type = "function" ,
97
- function = FunctionCall (
98
- name = function_call ["name" ],
99
- # function call args are JSON but as a string
100
- arguments = json .dumps (function_call ["arguments" ], ensure_ascii = False ),
101
- ),
89
+ tool_calls = []
90
+
91
+ # Check for invalid <response> tags before tool calls
92
+ if re .search (r"<response>[\s\S]*?</response>\s*(?=<tool_call>)" , model_output ):
93
+ data_processor_logger .error ("Invalid format: <response> tags found before <tool_call>" )
94
+ return ExtractedToolCallInformation (tools_called = False , content = model_output )
95
+
96
+ function_call_arr = []
97
+ remaining_text = model_output
98
+
99
+ while True :
100
+ # 查找下一个tool_call块
101
+ tool_call_pos = remaining_text .find ("<tool_call>" )
102
+ if tool_call_pos == - 1 :
103
+ break
104
+
105
+ # 提取tool_call开始位置后的内容
106
+ tool_content_start = tool_call_pos + len ("<tool_call>" )
107
+ tool_content_end = remaining_text .find ("</tool_call>" , tool_content_start )
108
+
109
+ tool_json = ""
110
+ if tool_content_end == - 1 :
111
+ # 处理未闭合的tool_call块(截断情况)
112
+ tool_json = remaining_text [tool_content_start :].strip ()
113
+ remaining_text = "" # 没有更多内容需要处理
114
+ else :
115
+ # 处理完整的tool_call块
116
+ tool_json = remaining_text [tool_content_start :tool_content_end ].strip ()
117
+ remaining_text = remaining_text [tool_content_end + len ("</tool_call>" ) :]
118
+
119
+ if not tool_json :
120
+ continue
121
+
122
+ # 处理JSON内容
123
+ tool_json = tool_json .strip ()
124
+ if not tool_json .startswith ("{" ):
125
+ tool_json = "{" + tool_json
126
+ if not tool_json .endswith ("}" ):
127
+ tool_json = tool_json + "}"
128
+
129
+ try :
130
+ # 首先尝试标准JSON解析
131
+ try :
132
+ tool_data = json .loads (tool_json )
133
+
134
+ if isinstance (tool_data , dict ) and "name" in tool_data and "arguments" in tool_data :
135
+ function_call_arr .append (
136
+ {
137
+ "name" : tool_data ["name" ],
138
+ "arguments" : tool_data ["arguments" ],
139
+ "_is_complete" : True , # 明确标记为完整解析
140
+ }
141
+ )
142
+ continue
143
+ except json .JSONDecodeError :
144
+ pass
145
+
146
+ # 标准解析失败时尝试partial_json_parser
147
+ from partial_json_parser .core .options import Allow
148
+
149
+ try :
150
+ tool_data = {}
151
+ flags = Allow .ALL & ~ Allow .STR
152
+
153
+ # 解析name字段
154
+ name_match = re .search (r'"name"\s*:\s*"([^"]*)"' , tool_json )
155
+ if name_match :
156
+ tool_data ["name" ] = name_match .group (1 )
157
+
158
+ # 解析arguments字段
159
+ args_match = re .search (r'"arguments"\s*:\s*(\{.*)' , tool_json )
160
+ if args_match :
161
+ try :
162
+ tool_data ["arguments" ] = partial_json_parser .loads (args_match .group (1 ), flags = flags )
163
+ except :
164
+ tool_data ["arguments" ] = None
165
+
166
+ if isinstance (tool_data , dict ):
167
+ function_call_arr .append (
168
+ {
169
+ "name" : tool_data .get ("name" , "" ),
170
+ "arguments" : tool_data .get ("arguments" , {}),
171
+ "_is_partial" : True , # 标记为部分解析
172
+ }
173
+ )
174
+ except Exception as e :
175
+ data_processor_logger .debug (f"Failed to parse tool call: { str (e )} " )
176
+ continue
177
+ except Exception as e :
178
+ data_processor_logger .debug (f"Failed to parse tool call: { str (e )} " )
179
+ continue
180
+
181
+ if not function_call_arr :
182
+ data_processor_logger .error ("No valid tool calls found" )
183
+ return ExtractedToolCallInformation (tools_called = False , content = model_output )
184
+
185
+ tool_calls = []
186
+ all_complete = True # 初始设为True,只要有一个不完整就变为False
187
+
188
+ for tool_call in function_call_arr :
189
+ # 记录工具调用解析状态
190
+ is_complete = tool_call .get ("_is_complete" , False )
191
+ is_partial = tool_call .get ("_is_partial" , False )
192
+
193
+ # 只要有一个不完整就认为整体不完整
194
+ if not is_complete or is_partial :
195
+ all_complete = False
196
+
197
+ # 处理参数序列化
198
+ tool_args = tool_call .get ("arguments" , {})
199
+ if not isinstance (tool_args , dict ):
200
+ tool_args = {}
201
+
202
+ try :
203
+ args_str = json .dumps (tool_args , ensure_ascii = False ) if tool_args else "{}"
204
+ except :
205
+ args_str = "{}"
206
+
207
+ tool_calls .append (
208
+ ToolCall (
209
+ type = "function" ,
210
+ id = random_tool_call_id (),
211
+ function = FunctionCall (
212
+ name = tool_call .get ("name" , "" ),
213
+ arguments = args_str ,
214
+ ),
215
+ )
102
216
)
103
- for function_call in raw_function_calls
104
- ]
105
- return ExtractedToolCallInformation (tools_called = True , tool_calls = tool_calls , content = "" )
106
- except Exception :
107
- data_processor_logger .error ("Error in extracting tool call from response." )
108
- return ExtractedToolCallInformation (tools_called = False , tool_calls = [], content = model_output )
217
+
218
+ # 只有当所有工具调用都明确标记为complete时才返回tools_called=True
219
+ return ExtractedToolCallInformation (
220
+ tools_called = all_complete , tool_calls = tool_calls if tool_calls else None , content = ""
221
+ )
222
+
223
+ except Exception as e :
224
+ data_processor_logger .error (f"Error in extracting tool call from response: { str (e )} " )
225
+ return ExtractedToolCallInformation (tools_called = False , tool_calls = None , content = model_output )
109
226
110
227
def extract_tool_calls_streaming (
111
228
self ,
@@ -116,7 +233,6 @@ def extract_tool_calls_streaming(
116
233
current_token_ids : Sequence [int ],
117
234
delta_token_ids : Sequence [int ],
118
235
request : dict ,
119
- model_status : str ,
120
236
) -> Union [DeltaMessage , None ]:
121
237
122
238
if self .tool_call_start_token_id not in current_token_ids :
0 commit comments