Skip to content

Commit 95291a9

Browse files
authored
rosie fixes: add format normalization for tools and tool call streaming fixes (LostRuins#1842)
1 parent 5125c0b commit 95291a9

File tree

1 file changed

+114
-10
lines changed

1 file changed

+114
-10
lines changed

koboldcpp.py

Lines changed: 114 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2362,7 +2362,7 @@ def is_ipv6_supported():
23622362
except Exception:
23632363
return False
23642364

2365-
def format_jinja(messages,tools):
2365+
def format_jinja(messages, tools):
23662366
try:
23672367
def strftime_now(format='%Y-%m-%d %H:%M:%S'):
23682368
return datetime.now().strftime(format)
@@ -2374,7 +2374,11 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
23742374
jinja_env.globals['strftime_now'] = strftime_now
23752375
jinja_env.filters["tojson"] = tojson
23762376
jinja_compiled_template = jinja_env.from_string(cached_chat_template)
2377-
text = jinja_compiled_template.render(messages=messages, tools=tools, add_generation_prompt=True, bos_token="", eos_token="")
2377+
text = None
2378+
if tools and len(tools)>0:
2379+
text = jinja_compiled_template.render(messages=messages, tools=tools, add_generation_prompt=True, bos_token="", eos_token="")
2380+
else:
2381+
text = jinja_compiled_template.render(messages=messages, add_generation_prompt=True, bos_token="", eos_token="")
23782382
return text if text else None
23792383
except Exception as e:
23802384
print(f"Jinja formatting failed: {e}")
@@ -2392,6 +2396,31 @@ def remove_outer_tags(inputstr):
23922396
return stripped # If no match, return original string
23932397
except Exception:
23942398
return stripped
2399+
2400+
def normalize_tool_call(obj):
    """Normalize various model-emitted tool call shapes to the OpenAI format.

    Returns a dict of the form
    ``{"type": "function", "function": {"name": ..., "arguments": ...}}``.
    Inputs already in OpenAI format, and inputs that cannot be recognized,
    are returned unchanged so downstream code can decide what to do.
    """
    # Robustness: the parsed-JSON source upstream may yield non-dict entries
    # (strings, lists, numbers). Pass them through untouched instead of
    # crashing on key access — the previous code assumed a dict.
    if not isinstance(obj, dict):
        return obj
    if "type" in obj and "function" in obj:  # Already in OpenAI format
        return obj
    # Flat format: {"name": ..., "arguments"/"parameters": ...}
    if "name" in obj and ("arguments" in obj or "parameters" in obj):
        args = obj.get("arguments", obj.get("parameters", {}))
        return {
            "type": "function",
            "function": {
                "name": obj["name"],
                "arguments": args
            }
        }
    # Nested format missing the "type" tag: {"function": {"name": ..., ...}}
    if "function" in obj and isinstance(obj["function"], dict):
        func = obj["function"]
        if "name" in func:
            return {
                "type": "function",
                "function": {
                    "name": func["name"],
                    # accept either key for the argument payload
                    "arguments": func.get("arguments", func.get("parameters", {}))
                }
            }
    return obj
23952424

23962425
# Used to parse json for openai tool calls
23972426
def extract_json_from_string(input_string):
@@ -3059,6 +3088,7 @@ def run_blocking(): # api format 1=basic,2=kai,3=oai,4=oai-chat
30593088
if using_openai_tools:
30603089
tool_calls = extract_json_from_string(recvtxt)
30613090
if tool_calls and len(tool_calls)>0:
3091+
tool_calls = [normalize_tool_call(obj) for obj in tool_calls]
30623092
for tc in tool_calls:
30633093
tcarg = tc.get("function",{}).get("arguments",None)
30643094
tc["id"] = f"call_{random.randint(10000, 99999)}"
@@ -3094,8 +3124,8 @@ def run_blocking(): # api format 1=basic,2=kai,3=oai,4=oai-chat
30943124
print(f"Generate: Error while generating: {e}")
30953125

30963126
async def send_oai_sse_event(self, data):
    """Write one OpenAI-style server-sent event to the response stream.

    A "[DONE]" sentinel, even when padded with surrounding whitespace, is
    emitted in canonical stripped form; any other payload is sent verbatim.
    Every event gets the required blank-line terminator and an immediate flush.
    """
    stripped = data.strip() if data else data
    payload = stripped if stripped == "[DONE]" else data
    self.wfile.write(f'data: {payload}\n\n'.encode())
    self.wfile.flush()
@@ -4346,18 +4376,92 @@ def do_POST(self):
43464376
self.send_header("cache-control", "no-cache")
43474377
self.send_header("connection", "keep-alive")
43484378
self.end_headers(content_type='text/event-stream')
4379+
4380+
content_text = None
43494381
toolsdata_res = []
43504382
try:
43514383
toolsdata_res = gendat['choices'][0]['message']['tool_calls']
43524384
if toolsdata_res and len(toolsdata_res)>0:
4353-
toolsdata_res[0]["index"] = 0 # need to add an index for OWUI
4385+
toolsdata_res[0]["index"] = 0 # need to add an index for OWUI
43544386
except Exception:
43554387
toolsdata_res = []
4356-
toolsdata_p1 = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":int(time.time()),"model":friendlymodelname,"choices":[{"index":0,"finish_reason":None,"delta":{'role':'assistant','content':None, "tool_calls":toolsdata_res}}]})
4357-
toolsdata_p2 = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":int(time.time()),"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"tool_calls","delta":{}}]})
4358-
self.wfile.write(f'data: {toolsdata_p1}\n\n'.encode())
4359-
self.wfile.write(f'data: {toolsdata_p2}\n\n'.encode())
4360-
self.wfile.write('data: [DONE]'.encode())
4388+
try:
4389+
content_text = gendat['choices'][0]['message'].get('content', None)
4390+
except Exception:
4391+
content_text = None
4392+
4393+
# Send role chunk first
4394+
chunk_role = json.dumps({
4395+
"id": "koboldcpp",
4396+
"object": "chat.completion.chunk",
4397+
"created": int(time.time()),
4398+
"model": friendlymodelname,
4399+
"choices": [{"index": 0, "finish_reason": None, "delta": {"role": "assistant"}}]
4400+
})
4401+
self.wfile.write(f"data: {chunk_role}\n\n".encode())
4402+
self.wfile.flush()
4403+
4404+
# Send content if present
4405+
if content_text:
4406+
chunk_content = json.dumps({
4407+
"id": "koboldcpp",
4408+
"object": "chat.completion.chunk",
4409+
"created": int(time.time()),
4410+
"model": friendlymodelname,
4411+
"choices": [{"index": 0, "finish_reason": None, "delta": {"content": content_text}}]
4412+
})
4413+
self.wfile.write(f"data: {chunk_content}\n\n".encode())
4414+
self.wfile.flush()
4415+
4416+
# Send tool calls incrementally in OpenAI format
4417+
if toolsdata_res and len(toolsdata_res) > 0:
4418+
for idx, tool_call in enumerate(toolsdata_res):
4419+
tc_meta = {
4420+
"index": idx,
4421+
"id": tool_call.get("id", f"call_{idx}"),
4422+
"type": "function",
4423+
"function": {
4424+
"name": tool_call.get("function", {}).get("name", ""),
4425+
"arguments": ""
4426+
}
4427+
}
4428+
chunk_meta = json.dumps({
4429+
"id": "koboldcpp",
4430+
"object": "chat.completion.chunk",
4431+
"created": int(time.time()),
4432+
"model": friendlymodelname,
4433+
"choices": [{"index": 0, "finish_reason": None, "delta": {"tool_calls": [tc_meta]}}]
4434+
})
4435+
self.wfile.write(f"data: {chunk_meta}\n\n".encode())
4436+
self.wfile.flush()
4437+
4438+
args_str = tool_call.get("function", {}).get("arguments", "{}")
4439+
if isinstance(args_str, dict):
4440+
args_str = json.dumps(args_str)
4441+
tc_args = {
4442+
"index": idx,
4443+
"function": {"arguments": args_str}
4444+
}
4445+
chunk_args = json.dumps({
4446+
"id": "koboldcpp",
4447+
"object": "chat.completion.chunk",
4448+
"created": int(time.time()),
4449+
"model": friendlymodelname,
4450+
"choices": [{"index": 0, "finish_reason": None, "delta": {"tool_calls": [tc_args]}}]
4451+
})
4452+
self.wfile.write(f"data: {chunk_args}\n\n".encode())
4453+
self.wfile.flush()
4454+
4455+
# Final chunk
4456+
chunk_final = json.dumps({
4457+
"id": "koboldcpp",
4458+
"object": "chat.completion.chunk",
4459+
"created": int(time.time()),
4460+
"model": friendlymodelname,
4461+
"choices": [{"index": 0, "finish_reason": "tool_calls", "delta": {}}]
4462+
})
4463+
self.wfile.write(f"data: {chunk_final}\n\n".encode())
4464+
self.wfile.write("data: [DONE]\n\n".encode())
43614465
self.wfile.flush()
43624466
self.close_connection = True
43634467
except Exception as ex:

0 commit comments

Comments
 (0)