Skip to content

Commit 49c2b41

Browse files
author
Judd
committed
add tool calling script for Qwen2.5
1 parent 64f0fae commit 49c2b41

File tree

1 file changed

+105
-0
lines changed

1 file changed

+105
-0
lines changed

scripts/tool_qwen2.5.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import json
2+
from typing import Literal
3+
import sys
4+
from datetime import datetime
5+
6+
from binding import PATH_BINDS
7+
8+
import tool_definition
9+
from tool_definition import dispatch_tool
10+
11+
from tool_mistral import get_tools
12+
13+
FN_CALL_TEMPLATE = """system
14+
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
15+
16+
Current Date: {date_string}
17+
18+
# Tools
19+
20+
You may call one or more functions to assist with the user query.
21+
22+
You are provided with function signatures within <tools></tools> XML tags:
23+
<tools>
24+
{tools_json}
25+
</tools>
26+
27+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
28+
<tool_call>
29+
{{"name": <function-name>, "arguments": <args-json-object>}}
30+
</tool_call>"""
31+
32+
def build_system_prompt(functions: list[dict]):
33+
tool_desc_template = FN_CALL_TEMPLATE
34+
tools_json = '\n\n'.join([json.dumps(f, ensure_ascii=False) for f in functions])
35+
tool_system = tool_desc_template.format(date_string=datetime.now().strftime('%Y-%m-%d'), tools_json=tools_json)
36+
return tool_system
37+
38+
import chatllm, sys, re
39+
from chatllm import ChatLLM, LLMChatChunk
40+
41+
def call_function(c: dict) -> str:
42+
try:
43+
observations = dispatch_tool(c['name'], c['arguments'], c['id'] if 'id' in c else None)
44+
return observations.text
45+
except Exception as e:
46+
print(f"error occurs: {e}")
47+
return "failed to call the function"
48+
49+
TOOL_CALL_START = "<tool_call>"
50+
TOOL_CALL_CLOSE = "</tool_call>"
51+
52+
TOOL_RESULT_START = "<tool_response>"
53+
TOOL_RESULT_CLOSE = "</tool_response>"
54+
55+
class ToolChatLLM(ChatLLM):
56+
chunk_acc = ''
57+
tool_calls = []
58+
59+
def callback_print(self, s: str) -> None:
60+
if self.chunk_acc is None:
61+
self.chunk_acc = ''
62+
63+
if self.chunk_acc == '':
64+
if TOOL_CALL_START.startswith(s):
65+
self.chunk_acc = s
66+
else:
67+
super().callback_print(s)
68+
69+
return
70+
71+
self.chunk_acc = self.chunk_acc + s
72+
73+
if len(self.chunk_acc) <= len(TOOL_CALL_START): return
74+
75+
if not self.chunk_acc.startswith(TOOL_CALL_START):
76+
super().callback_print(self.chunk_acc)
77+
self.chunk_acc = ''
78+
79+
close = self.chunk_acc.find(TOOL_CALL_CLOSE)
80+
if close > 0:
81+
self.tool_calls.append(self.chunk_acc[len(TOOL_CALL_START):close])
82+
s = self.chunk_acc[close + len(TOOL_CALL_CLOSE):]
83+
if len(s) > 0: super().callback_print(s)
84+
self.chunk_acc = ''
85+
86+
def callback_end(self) -> None:
87+
for t in self.tool_calls:
88+
self.call_tool(t)
89+
90+
self.chunk_acc = ''
91+
super().callback_end()
92+
self.tool_calls = []
93+
94+
def call_tool(self, s: str) -> None:
95+
s = s.strip()
96+
tc = tool_definition.json_decode_ignore_extra(s)
97+
if not isinstance(tc, dict): return
98+
if not 'name' in tc: return
99+
100+
print(f"[Use Tool]: {tc['name']}")
101+
rsp = call_function(tc)
102+
self.tool_input(TOOL_RESULT_START + rsp + TOOL_RESULT_CLOSE)
103+
104+
if __name__ == '__main__':
105+
chatllm.demo_simple(sys.argv[1:] + ['-s', build_system_prompt(get_tools())], ToolChatLLM, lib_path=PATH_BINDS)

0 commit comments

Comments
 (0)