Skip to content

Commit 3e8b84b

Browse files
committed
added support for structured output in chat completions
1 parent e8b3aea commit 3e8b84b

File tree

1 file changed

+57
-28
lines changed

1 file changed

+57
-28
lines changed

koboldcpp.py

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,14 @@ def generate(genparams, stream_flag=False):
12671267
ban_eos_token = genparams.get('ban_eos_token', False)
12681268
stream_sse = stream_flag
12691269
grammar = genparams.get('grammar', '')
1270+
# translate the grammar to GBNF if it's JSON
1271+
try:
1272+
grammarjson = json.loads(grammar)
1273+
decoded = convert_json_to_gbnf(grammarjson)
1274+
if decoded:
1275+
grammar = decoded
1276+
except Exception:
1277+
pass
12701278
grammar_retain_state = genparams.get('grammar_retain_state', False)
12711279
genkey = genparams.get('genkey', '')
12721280
trimstop = genparams.get('trim_stop', True)
@@ -2051,6 +2059,32 @@ def transform_genparams(genparams, api_format):
20512059
tools_message_start = adapter_obj.get("tools_start", "")
20522060
tools_message_end = adapter_obj.get("tools_end", "")
20532061
images_added = []
2062+
jsongrammar = r"""
2063+
root ::= arr
2064+
value ::= object | array | string | number | ("true" | "false" | "null") ws
2065+
arr ::=
2066+
"[\n" ws (
2067+
value
2068+
(",\n" ws value)*
2069+
)? "]"
2070+
object ::=
2071+
"{" ws (
2072+
string ":" ws value
2073+
("," ws string ":" ws value)*
2074+
)? "}" ws
2075+
array ::=
2076+
"[" ws (
2077+
value
2078+
("," ws value)*
2079+
)? "]" ws
2080+
string ::=
2081+
"\"" (
2082+
[^"\\\x7F\x00-\x1F] |
2083+
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4})
2084+
)* "\"" ws
2085+
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
2086+
ws ::= | " " | "\n" [ \t]{0,20}
2087+
"""
20542088

20552089
# tools handling
20562090
tools_array = genparams.get('tools', [])
@@ -2075,6 +2109,24 @@ def transform_genparams(genparams, api_format):
20752109
tool_json_formatting_instruction = f"\nThe user is asking you to use the style of this JSON object formatting to complete the parameters for the specific function named {specified_function} in the following format: " + json.dumps([{"id": "insert an id for the response", "type": "function", "function": {"name": f"{specified_function}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
20762110
except Exception:
20772111
# In case of any issues, just revert back to no specified function
2112+
print("Tools parsing not valid - discarded")
2113+
pass
2114+
2115+
# handle structured outputs
2116+
respformat = genparams.get('response_format', None)
2117+
if respformat:
2118+
try:
2119+
rt = respformat.get('type')
2120+
if rt.lower() == "json_schema":
2121+
schema = respformat.get('json_schema').get('schema')
2122+
decoded = convert_json_to_gbnf(schema)
2123+
if decoded:
2124+
genparams["grammar"] = decoded
2125+
elif rt.lower() == "json_object":
2126+
genparams["grammar"] = jsongrammar
2127+
except Exception:
2128+
# In case of any issues, just do normal gen
2129+
print("Structured Output not valid - discarded")
20782130
pass
20792131

20802132
message_index = 0
@@ -2115,13 +2167,15 @@ def transform_genparams(genparams, api_format):
21152167
# if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows
21162168
custom_tools_prompt = adapter_obj.get("custom_tools_prompt", "Can the user query be answered by a listed tool? (One word response: yes or no):")
21172169
# note: message string already contains the instruct start tag!
2170+
pollgrammar = r'root ::= "yes" | "no" | "Yes" | "No" | "YES" | "NO"'
21182171
temp_poll = {
21192172
"prompt": f"{messages_string}\n\nTool List:\n{tools_string}\n\n{custom_tools_prompt}{user_end}",
2120-
"max_length":6,
2173+
"max_length":4,
21212174
"temperature":0.1,
21222175
"top_k":1,
21232176
"rep_pen":1,
2124-
"ban_eos_token":False
2177+
"ban_eos_token":False,
2178+
"grammar":pollgrammar
21252179
}
21262180
temp_poll_result = generate(genparams=temp_poll)
21272181
if temp_poll_result and "yes" not in temp_poll_result['text'].lower():
@@ -2138,32 +2192,7 @@ def transform_genparams(genparams, api_format):
21382192
genparams["using_openai_tools"] = True
21392193

21402194
# Set grammar to llamacpp example grammar to force json response (see https://github.com/ggerganov/llama.cpp/blob/master/grammars/json_arr.gbnf)
2141-
genparams["grammar"] = r"""
2142-
root ::= arr
2143-
value ::= object | array | string | number | ("true" | "false" | "null") ws
2144-
arr ::=
2145-
"[\n" ws (
2146-
value
2147-
(",\n" ws value)*
2148-
)? "]"
2149-
object ::=
2150-
"{" ws (
2151-
string ":" ws value
2152-
("," ws string ":" ws value)*
2153-
)? "}" ws
2154-
array ::=
2155-
"[" ws (
2156-
value
2157-
("," ws value)*
2158-
)? "]" ws
2159-
string ::=
2160-
"\"" (
2161-
[^"\\\x7F\x00-\x1F] |
2162-
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4})
2163-
)* "\"" ws
2164-
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws
2165-
ws ::= | " " | "\n" [ \t]{0,20}
2166-
"""
2195+
genparams["grammar"] = jsongrammar
21672196
if message['role'] == "system":
21682197
messages_string += system_message_end
21692198
elif message['role'] == "user":

0 commit comments

Comments
 (0)