Skip to content

Commit a87c05f

Browse files
committed
move function call determination to separate method
1 parent cade9f4 commit a87c05f

File tree

1 file changed

+121
-113
lines changed

1 file changed

+121
-113
lines changed

koboldcpp.py

Lines changed: 121 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -2242,12 +2242,111 @@ def extract_all_names_from_tool_array(tools_array):
22422242
pass
22432243
return toolnames
22442244

2245+
#returns the found JSON of the correct tool to use, or None if no tool is suitable
def determine_tool_json_to_use(genparams, curr_ctx, assistant_message_start, is_followup_tool, adapter_obj=None):
    """Decide which (if any) OpenAI-style tool should be called for this request.

    Args:
        genparams: request parameter dict; reads 'tools' (array) and 'tool_choice'
            ("auto", "none", None, or a specific tool selector).
        curr_ctx: the prompt/messages string built so far (already contains the
            instruct start tag).
        assistant_message_start: the assistant-turn start token(s) appended to
            the poll prompts.
        is_followup_tool: True when the last message came from a tool, so the
            yes/no poll asks whether ANOTHER tool call is needed.
        adapter_obj: optional chat-completions adapter dict; if given, its
            'custom_tools_prompt' entry overrides the default yes/no poll prompt
            (restores the pre-refactor adapter behavior; backward-compatible default).

    Returns:
        The JSON object of the selected tool from the tools array, or None if
        no tool should be used.
    """
    # tools handling: Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
    tools_array = genparams.get('tools', [])
    chosen_tool = genparams.get('tool_choice', "auto")
    used_tool_json = None
    # Nothing to do when no tools were supplied or tool use is explicitly disabled.
    if not (tools_array and len(tools_array) > 0 and chosen_tool is not None and chosen_tool != "none"):
        return None

    tools_string = json.dumps(tools_array, indent=0)
    should_use_tools = True
    # first handle auto mode, determine whether a tool is needed
    if chosen_tool == "auto":
        # if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows
        default_tools_prompt = "Can the user query be answered by a listed tool above? (One word response: yes or no):"
        if is_followup_tool:
            default_tools_prompt = "Can the user query be further answered by another listed tool above? (If response is already complete, reply NO) (One word response: yes or no):"
        custom_tools_prompt = adapter_obj.get("custom_tools_prompt", default_tools_prompt) if adapter_obj else default_tools_prompt
        # note: message string already contains the instruct start tag!
        # Constrain the model to a bare yes/no answer via grammar.
        pollgrammar = r'root ::= "yes" | "no" | "Yes" | "No" | "YES" | "NO"'
        temp_poll = {
            "prompt": f"{curr_ctx}\n\nTool List:\n{tools_string}\n\n{custom_tools_prompt}{assistant_message_start}",
            "max_length": 5,
            "temperature": 0.1,
            "top_k": 1,
            "rep_pen": 1,
            "ban_eos_token": False,
            "grammar": pollgrammar
        }
        temp_poll_result = generate(genparams=temp_poll)
        if temp_poll_result and "yes" not in temp_poll_result['text'].lower():
            should_use_tools = False
        # guard against a failed generation before dereferencing the result
        if not args.quiet and temp_poll_result:
            print(f"\nRelevant tool is listed: {temp_poll_result['text']} ({should_use_tools})")

    if should_use_tools:
        #first, try and extract a specific tool if selected
        used_tool_json = extract_tool_info_from_tool_array(chosen_tool, tools_array)
        if used_tool_json: #already found the tool we want, remove all others
            pass
        elif len(tools_array) == 1:
            # only one tool available, no need to poll the model
            used_tool_json = tools_array[0]
        else: # we have to find the tool we want the old fashioned way
            toolnames = extract_all_names_from_tool_array(tools_array)
            if len(toolnames) == 1:
                used_tool_json = extract_tool_info_from_tool_array(toolnames[0], tools_array)
            else:
                # build a grammar that only allows one of the known tool names
                pollgrammar = ""
                for name in toolnames:
                    pollgrammar += ("" if pollgrammar == "" else " | ")
                    pollgrammar += "\"" + name + "\""
                pollgrammar = r'root ::= ' + pollgrammar
                decide_tool_prompt = "Which of the listed tools should be used next? Pick exactly one. (Reply directly with the selected tool's name):"
                temp_poll = {
                    "prompt": f"{curr_ctx}\n\nTool List:\n{tools_string}\n\n{decide_tool_prompt}{assistant_message_start}",
                    "max_length": 16,
                    "temperature": 0.1,
                    "top_k": 1,
                    "rep_pen": 1,
                    "ban_eos_token": False,
                    "grammar": pollgrammar
                }
                temp_poll_result = generate(genparams=temp_poll)
                if temp_poll_result:
                    raw = temp_poll_result['text'].lower()
                    for name in toolnames:
                        if name.lower() in raw:
                            used_tool_json = extract_tool_info_from_tool_array(name, tools_array)
                            if not args.quiet:
                                print(f"\nAttempting to use tool: {name}")
                            break

    return used_tool_json
2315+
2316+
22452317
def transform_genparams(genparams, api_format):
22462318
global chatcompl_adapter, maxctx
22472319

22482320
if api_format < 0: #not text gen, do nothing
22492321
return
22502322

2323+
jsongrammar = r"""
2324+
root ::= arr
2325+
value ::= object | array | string | number | ("true" | "false" | "null") ws
2326+
arr ::=
2327+
"[\n" ws (
2328+
value
2329+
(",\n" ws value)*
2330+
)? "]"
2331+
object ::=
2332+
"{" ws (
2333+
string ":" ws value
2334+
("," ws string ":" ws value)*
2335+
)? "}" ws
2336+
array ::=
2337+
"[" ws (
2338+
value
2339+
("," ws value)*
2340+
)? "]" ws
2341+
string ::=
2342+
"\"" (
2343+
[^"\\\x7F\x00-\x1F] |
2344+
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4})
2345+
)* "\"" ws
2346+
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
2347+
ws ::= | " " | "\n" [ \t]{0,20}
2348+
"""
2349+
22512350
#api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate,6=ollama,7=ollamachat
22522351
#alias all nonstandard alternative names for rep pen.
22532352
rp1 = float(genparams.get('repeat_penalty', 1.0))
@@ -2297,32 +2396,6 @@ def transform_genparams(genparams, api_format):
22972396
tools_message_end = adapter_obj.get("tools_end", "")
22982397
images_added = []
22992398
audio_added = []
2300-
jsongrammar = r"""
2301-
root ::= arr
2302-
value ::= object | array | string | number | ("true" | "false" | "null") ws
2303-
arr ::=
2304-
"[\n" ws (
2305-
value
2306-
(",\n" ws value)*
2307-
)? "]"
2308-
object ::=
2309-
"{" ws (
2310-
string ":" ws value
2311-
("," ws string ":" ws value)*
2312-
)? "}" ws
2313-
array ::=
2314-
"[" ws (
2315-
value
2316-
("," ws value)*
2317-
)? "]" ws
2318-
string ::=
2319-
"\"" (
2320-
[^"\\\x7F\x00-\x1F] |
2321-
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4})
2322-
)* "\"" ws
2323-
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
2324-
ws ::= | " " | "\n" [ \t]{0,20}
2325-
"""
23262399

23272400
# handle structured outputs
23282401
respformat = genparams.get('response_format', None)
@@ -2398,93 +2471,28 @@ def transform_genparams(genparams, api_format):
23982471
messages_string += f"\n(Attached Audio {attachedaudid})\n"
23992472
# If last message, add any tools calls after message content and before message end token if any
24002473
if (message['role'] == "user" or message['role'] == "tool") and message_index == len(messages_array):
2401-
# tools handling: Check if user is passing a openai tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
2402-
tools_array = genparams.get('tools', [])
2403-
chosen_tool = genparams.get('tool_choice', "auto")
2404-
# first handle auto mode, determine whether a tool is needed
2405-
if tools_array and len(tools_array) > 0 and chosen_tool is not None and chosen_tool!="none":
2406-
tools_string = json.dumps(tools_array, indent=0)
2407-
should_use_tools = True
2408-
user_end = assistant_message_start
2409-
if chosen_tool=="auto":
2410-
# if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows
2411-
custom_tools_prompt = adapter_obj.get("custom_tools_prompt", "Can the user query be answered by a listed tool above? (One word response: yes or no):")
2412-
if message['role'] == "tool":
2413-
custom_tools_prompt = adapter_obj.get("custom_tools_prompt", "Can the user query be further answered by another listed tool above? (If response is already complete, reply NO) (One word response: yes or no):")
2414-
# note: message string already contains the instruct start tag!
2415-
pollgrammar = r'root ::= "yes" | "no" | "Yes" | "No" | "YES" | "NO"'
2416-
temp_poll = {
2417-
"prompt": f"{messages_string}\n\nTool List:\n{tools_string}\n\n{custom_tools_prompt}{user_end}",
2418-
"max_length":5,
2419-
"temperature":0.1,
2420-
"top_k":1,
2421-
"rep_pen":1,
2422-
"ban_eos_token":False,
2423-
"grammar":pollgrammar
2424-
}
2425-
temp_poll_result = generate(genparams=temp_poll)
2426-
if temp_poll_result and "yes" not in temp_poll_result['text'].lower():
2427-
should_use_tools = False
2428-
if not args.quiet:
2429-
print(f"\nRelevant tool is listed: {temp_poll_result['text']} ({should_use_tools})")
2430-
2431-
if should_use_tools:
2432-
#first, try and extract a specific tool if selected
2433-
used_tool_json = extract_tool_info_from_tool_array(chosen_tool, tools_array)
2434-
if used_tool_json: #already found the tool we want, remove all others
2435-
pass
2436-
elif len(tools_array)==1:
2437-
used_tool_json = tools_array[0]
2438-
else: # we have to find the tool we want the old fashioned way
2439-
toolnames = extract_all_names_from_tool_array(tools_array)
2440-
if len(toolnames) == 1:
2441-
used_tool_json = extract_tool_info_from_tool_array(toolnames[0], tools_array)
2442-
else:
2443-
pollgrammar = ""
2444-
for name in toolnames:
2445-
pollgrammar += ("" if pollgrammar=="" else " | ")
2446-
pollgrammar += "\"" + name + "\""
2447-
pollgrammar = r'root ::= ' + pollgrammar
2448-
decide_tool_prompt = "Which of the listed tools should be used next? Pick exactly one. (Reply directly with the selected tool's name):"
2449-
temp_poll = {
2450-
"prompt": f"{messages_string}\n\nTool List:\n{tools_string}\n\n{decide_tool_prompt}{user_end}",
2451-
"max_length":16,
2452-
"temperature":0.1,
2453-
"top_k":1,
2454-
"rep_pen":1,
2455-
"ban_eos_token":False,
2456-
"grammar":pollgrammar
2457-
}
2458-
temp_poll_result = generate(genparams=temp_poll)
2459-
if temp_poll_result:
2460-
raw = temp_poll_result['text'].lower()
2461-
for name in toolnames:
2462-
if name.lower() in raw:
2463-
used_tool_json = extract_tool_info_from_tool_array(name, tools_array)
2464-
if not args.quiet:
2465-
print(f"\nAttempting to use tool: {name}")
2466-
break
2467-
2468-
if used_tool_json:
2469-
toolparamjson = None
2470-
toolname = None
2471-
# Set temperature lower automatically if function calling, cannot exceed 0.5
2472-
genparams["temperature"] = (1.0 if genparams.get("temperature", 0.5) > 1.0 else genparams.get("temperature", 0.5))
2473-
genparams["using_openai_tools"] = True
2474-
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
2475-
genparams["grammar"] = jsongrammar
2476-
try:
2477-
toolname = used_tool_json.get('function').get('name')
2478-
toolparamjson = used_tool_json.get('function').get('parameters')
2479-
bettergrammarjson = {"type":"array","items":{"type":"object","properties":{"id":{"type":"string","enum":["call_001"]},"type":{"type":"string","enum":["function"]},"function":{"type":"object","properties":{"name":{"type":"string"},"arguments":{}},"required":["name","arguments"],"additionalProperties":False}},"required":["id","type","function"],"additionalProperties":False}}
2480-
bettergrammarjson["items"]["properties"]["function"]["properties"]["arguments"] = toolparamjson
2481-
decoded = convert_json_to_gbnf(bettergrammarjson)
2482-
if decoded:
2483-
genparams["grammar"] = decoded
2484-
except Exception:
2485-
pass
2486-
tool_json_formatting_instruction = f"\nPlease use the provided schema to fill the parameters to create a function call for {toolname}, in the following format: " + json.dumps([{"id": "call_001", "type": "function", "function": {"name": f"{toolname}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
2487-
messages_string += f"\n\nJSON Schema:\n{used_tool_json}\n\n{tool_json_formatting_instruction}{user_end}"
2474+
used_tool_json = determine_tool_json_to_use(genparams, messages_string, assistant_message_start, (message['role'] == "tool"))
2475+
2476+
if used_tool_json:
2477+
toolparamjson = None
2478+
toolname = None
2479+
# Set temperature lower automatically if function calling, cannot exceed 0.5
2480+
genparams["temperature"] = (1.0 if genparams.get("temperature", 0.5) > 1.0 else genparams.get("temperature", 0.5))
2481+
genparams["using_openai_tools"] = True
2482+
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
2483+
genparams["grammar"] = jsongrammar
2484+
try:
2485+
toolname = used_tool_json.get('function').get('name')
2486+
toolparamjson = used_tool_json.get('function').get('parameters')
2487+
bettergrammarjson = {"type":"array","items":{"type":"object","properties":{"id":{"type":"string","enum":["call_001"]},"type":{"type":"string","enum":["function"]},"function":{"type":"object","properties":{"name":{"type":"string"},"arguments":{}},"required":["name","arguments"],"additionalProperties":False}},"required":["id","type","function"],"additionalProperties":False}}
2488+
bettergrammarjson["items"]["properties"]["function"]["properties"]["arguments"] = toolparamjson
2489+
decoded = convert_json_to_gbnf(bettergrammarjson)
2490+
if decoded:
2491+
genparams["grammar"] = decoded
2492+
except Exception:
2493+
pass
2494+
tool_json_formatting_instruction = f"\nPlease use the provided schema to fill the parameters to create a function call for {toolname}, in the following format: " + json.dumps([{"id": "call_001", "type": "function", "function": {"name": f"{toolname}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
2495+
messages_string += f"\n\nJSON Schema:\n{used_tool_json}\n\n{tool_json_formatting_instruction}{assistant_message_start}"
24882496

24892497

24902498
if message['role'] == "system":

0 commit comments

Comments (0)