Skip to content

Commit 748dfcc

Browse files
committed
massively improved tool calling
1 parent c4df151 commit 748dfcc

File tree

1 file changed

+94
-35
lines changed

1 file changed

+94
-35
lines changed

koboldcpp.py

Lines changed: 94 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,8 @@ def extract_json_from_string(input_string):
19821982
parsed_json = None
19831983
try: # First check if model exported perfect json
19841984
parsed_json = json.loads(input_string)
1985+
if not isinstance(parsed_json, list):
1986+
parsed_json = [parsed_json]
19851987
return parsed_json
19861988
except Exception:
19871989
pass
@@ -1997,6 +1999,8 @@ def extract_json_from_string(input_string):
19971999
for potential_json in potential_jsons:
19982000
try:
19992001
parsed_json = json.loads(potential_json)
2002+
if not isinstance(parsed_json, list):
2003+
parsed_json = [parsed_json]
20002004
return parsed_json
20012005
except Exception:
20022006
continue
@@ -2039,6 +2043,35 @@ def parse_last_logprobs(lastlogprobs):
20392043
logprobsdict['content'].append(lp_content_item)
20402044
return logprobsdict
20412045

2046+
def extract_tool_info_from_tool_array(chosen_tool, tools_array):
2047+
found_function = ""
2048+
found_tooljson = None
2049+
try:
2050+
if isinstance(chosen_tool, str):
2051+
found_function = chosen_tool
2052+
elif isinstance(chosen_tool, dict): #if we can match the tool name, we must use that tool, remove all other tools
2053+
found_function = chosen_tool.get('function').get('name')
2054+
#if we find the function in tools, remove all other tools except the one matching the function name
2055+
for tool in tools_array:
2056+
if found_function and tool.get('type') == "function" and tool.get('function').get('name').lower() == found_function.lower():
2057+
found_tooljson = tool
2058+
break
2059+
except Exception:
2060+
# In case of any issues, just revert back to no specified function
2061+
print("Tools parsing not valid - discarded")
2062+
pass
2063+
return found_tooljson
2064+
2065+
def extract_all_names_from_tool_array(tools_array):
2066+
toolnames = []
2067+
for tool in tools_array:
2068+
try:
2069+
if tool.get('type') == "function" and tool.get('function').get('name'):
2070+
toolnames.append(tool.get('function').get('name'))
2071+
except Exception:
2072+
pass
2073+
return toolnames
2074+
20422075
def transform_genparams(genparams, api_format):
20432076
global chatcompl_adapter, maxctx
20442077

@@ -2120,32 +2153,6 @@ def transform_genparams(genparams, api_format):
21202153
ws ::= | " " | "\n" [ \t]{0,20}
21212154
"""
21222155

2123-
# tools handling
2124-
tools_array = genparams.get('tools', [])
2125-
chosen_tool = genparams.get('tool_choice', "auto")
2126-
tool_json_formatting_instruction = "\nUse this style of JSON object formatting to give your answer if you think the user is asking you to perform an action: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": "insert the name of the function you want to call", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
2127-
if tools_array and len(tools_array) > 0 and chosen_tool is not None:
2128-
try:
2129-
specified_function = ""
2130-
if isinstance(chosen_tool, str):
2131-
specified_function = chosen_tool
2132-
elif isinstance(chosen_tool, dict): #if we can match the tool name, we must use that tool, remove all other tools
2133-
specified_function = chosen_tool.get('function').get('name')
2134-
located_tooljson = None
2135-
#if we find the function in tools, remove all other tools except the one matching the function name
2136-
for tool in tools_array:
2137-
if specified_function and tool.get('type') == "function" and tool.get('function').get('name') == specified_function:
2138-
located_tooljson = tool
2139-
break
2140-
if located_tooljson:
2141-
tools_array = []
2142-
tools_array.append(located_tooljson)
2143-
tool_json_formatting_instruction = f"\nThe user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
2144-
except Exception:
2145-
# In case of any issues, just revert back to no specified function
2146-
print("Tools parsing not valid - discarded")
2147-
pass
2148-
21492156
# handle structured outputs
21502157
respformat = genparams.get('response_format', None)
21512158
if respformat:
@@ -2191,9 +2198,11 @@ def transform_genparams(genparams, api_format):
21912198
messages_string += "\n(Attached Image)\n"
21922199
# If last message, add any tools calls after message content and before message end token if any
21932200
if message['role'] == "user" and message_index == len(messages_array):
2194-
# Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
2201+
# tools handling: Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
2202+
tools_array = genparams.get('tools', [])
2203+
chosen_tool = genparams.get('tool_choice', "auto")
2204+
# first handle auto mode, determine whether a tool is needed
21952205
if tools_array and len(tools_array) > 0 and chosen_tool is not None and chosen_tool!="none":
2196-
#if auto mode, determine whether a tool is needed
21972206
tools_string = json.dumps(tools_array, indent=0)
21982207
should_use_tools = True
21992208
user_end = assistant_message_start
@@ -2218,15 +2227,64 @@ def transform_genparams(genparams, api_format):
22182227
print(f"\nRelevant tool is listed: {temp_poll_result['text']} ({should_use_tools})")
22192228

22202229
if should_use_tools:
2221-
messages_string += tools_string
2222-
messages_string += tool_json_formatting_instruction
2230+
#first, try and extract a specific tool if selected
2231+
used_tool_json = extract_tool_info_from_tool_array(chosen_tool, tools_array)
2232+
if used_tool_json: #already found the tool we want, remove all others
2233+
pass
2234+
elif len(tools_array)==1:
2235+
used_tool_json = tools_array[0]
2236+
else: # we have to find the tool we want the old fashioned way
2237+
toolnames = extract_all_names_from_tool_array(tools_array)
2238+
if len(toolnames) == 1:
2239+
used_tool_json = extract_tool_info_from_tool_array(toolnames[0], tools_array)
2240+
else:
2241+
pollgrammar = ""
2242+
for name in toolnames:
2243+
pollgrammar += ("" if pollgrammar=="" else " | ")
2244+
pollgrammar += "\"" + name + "\""
2245+
pollgrammar = r'root ::= ' + pollgrammar
2246+
decide_tool_prompt = "Which of the listed tools should be used? Pick exactly one. (Reply directly with the selected tool's name):"
2247+
temp_poll = {
2248+
"prompt": f"{messages_string}\n\nTool List:\n{tools_string}\n\n{decide_tool_prompt}{user_end}",
2249+
"max_length":8,
2250+
"temperature":0.1,
2251+
"top_k":1,
2252+
"rep_pen":1,
2253+
"ban_eos_token":False,
2254+
"grammar":pollgrammar
2255+
}
2256+
temp_poll_result = generate(genparams=temp_poll)
2257+
if temp_poll_result:
2258+
raw = temp_poll_result['text'].lower()
2259+
for name in toolnames:
2260+
if name.lower() in raw:
2261+
used_tool_json = extract_tool_info_from_tool_array(name, tools_array)
2262+
if not args.quiet:
2263+
print(f"\nAttempting to use tool: {name}")
2264+
break
2265+
2266+
if used_tool_json:
2267+
toolparamjson = None
2268+
toolname = None
2269+
# Set temperature low automatically if function calling
2270+
genparams["temperature"] = 0.1
2271+
genparams["using_openai_tools"] = True
2272+
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
2273+
genparams["grammar"] = jsongrammar
2274+
try:
2275+
toolname = used_tool_json.get('function').get('name')
2276+
toolparamjson = used_tool_json.get('function').get('parameters')
2277+
bettergrammarjson = {"type":"array","items":{"type":"object","properties":{"id":{"type":"string","enum":["call_001"]},"type":{"type":"string","enum":["function"]},"function":{"type":"object","properties":{"name":{"type":"string"},"arguments":{}},"required":["name","arguments"],"additionalProperties":False}},"required":["id","type","function"],"additionalProperties":False}}
2278+
bettergrammarjson["items"]["properties"]["function"]["properties"]["arguments"] = toolparamjson
2279+
decoded = convert_json_to_gbnf(bettergrammarjson)
2280+
if decoded:
2281+
genparams["grammar"] = decoded
2282+
except Exception:
2283+
pass
2284+
tool_json_formatting_instruction = f"\nPlease use the provided schema to fill the parameters to create a function call for {toolname}, in the following format: " + json.dumps([{"id": "call_001", "type": "function", "function": {"name": f"{toolname}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
2285+
messages_string += f"\n\nJSON Schema:\n{used_tool_json}\n\n{tool_json_formatting_instruction}{user_end}"
22232286

2224-
# Set temperature low automatically if function calling
2225-
genparams["temperature"] = 0.1
2226-
genparams["using_openai_tools"] = True
22272287

2228-
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
2229-
genparams["grammar"] = jsongrammar
22302288
if message['role'] == "system":
22312289
messages_string += system_message_end
22322290
elif message['role'] == "user":
@@ -2480,6 +2538,7 @@ def run_blocking(): # api format 1=basic,2=kai,3=oai,4=oai-chat
24802538
if tool_calls and len(tool_calls)>0:
24812539
for tc in tool_calls:
24822540
tcarg = tc.get("function",{}).get("arguments",None)
2541+
tc["id"] = f"call_{random.randint(10000, 99999)}"
24832542
if tcarg and not isinstance(tcarg, str):
24842543
tc["function"]["arguments"] = json.dumps(tcarg)
24852544
recvtxt = None

0 commit comments

Comments
 (0)