
Commit b7d0864

checkpoint
1 parent 6f1af73 commit b7d0864

File tree: 1 file changed (+88 / -22 lines)


llama_cpp/llama_chat_template.py

Lines changed: 88 additions & 22 deletions
@@ -423,48 +423,114 @@ def _handle_streaming_tool_calls(
         accumulated_text += text
         stop_reason = chunk["choices"][0]["finish_reason"]
         if stop_reason == "stop:<tool_call>":
+            accumulated_text += "<tool_call>"
             # Found tool call, switch to grammar mode
             print("[DEBUG TOOLS] Found tool call, switching to grammar mode", file=sys.stderr)
-
-            # Use the helper function to detect tool and get tool name
+
+            # First generate the tool name with grammar
+            function_names = " | ".join([f'''"functions.{t["function"]["name"]}"''' for t in tools]) if tools else ""
+            tool_call_gbnf = (
+                'root ::= functions\n'  # We already have <tool_call>, just need the function name
+                f"functions ::= {function_names}\n"
+            )
+
             try:
-                tool, tool_name_completion_text = _detect_tool_call_and_get_tool(
-                    accumulated_text, tools, prompt, llama, llama.verbose
+                # Generate the tool call with grammar
+                name_grammar = llama_grammar.LlamaGrammar.from_string(
+                    tool_call_gbnf,
+                    verbose=llama.verbose
                 )
-
+
+                # Generate the tool name (non-streaming for simpler handling)
+                name_completion = llama.create_completion(
+                    prompt=prompt + llama.tokenize((accumulated_text + "\n").encode("utf-8"), add_bos=False, special=True),
+                    grammar=name_grammar,
+                    stream=False,
+                    stop=[":"],  # Stop at the colon after function name
+                    **{k: v for k, v in base_completion_kwargs.items() if k != "stream" and k != "grammar"}
+                )
+                name_text = name_completion["choices"][0]["text"]
+                # Convert to chat completion chunk and yield
+                yield {
+                    "id": "chat" + name_completion["id"],
+                    "object": "chat.completion.chunk",
+                    "created": name_completion["created"],
+                    "model": name_completion["model"],
+                    "choices": [{
+                        "index": 0,
+                        "delta": {
+                            "role": "assistant",
+                            "content": None,
+                            "tool_calls": [{
+                                "index": 0,
+                                "id": "call_0_" + name_text.split(".")[-1] + "_" + name_completion["id"],
+                                "type": "function",
+                                "function": {
+                                    "name": name_text.split(".")[-1],
+                                    "arguments": ""
+                                }
+                            }]
+                        },
+                        "finish_reason": None
+                    }]
+                }
+                accumulated_text += name_text + ":"  # Add the colon back since we stopped before it
+
+                # Get the selected tool from the accumulated text
+                tool_name = accumulated_text.split("\n")[-1].split("functions.")[-1].split(":")[0]
+                tool = next((t for t in tools if t["function"]["name"] == tool_name), None)
+
                 if tool:
                     # Get tool parameters grammar
                     tool_grammar = _grammar_for_tool_parameters(tool, verbose=llama.verbose)
-
-                    # Continue generation with tool grammar
-                    new_prompt = prompt + accumulated_text + tool_name_completion_text + "\n"
-                    for tool_chunk in llama.create_completion(
-                        prompt=new_prompt,
+
+                    # Stream the tool parameters
+                    for param_chunk in llama.create_completion(
+                        prompt=prompt + llama.tokenize((accumulated_text + "\n").encode("utf-8"), add_bos=False, special=True),
                         grammar=tool_grammar,
                         stream=True,
                         stop=["</tool_call>"],
-                        **base_completion_kwargs
+                        **{k: v for k, v in base_completion_kwargs.items() if k != "stream" and k != "grammar"}
                     ):
-                        yield from _convert_text_completion_chunks_to_chat(iter([tool_chunk]))
-
+                        # Convert to chat completion chunk and yield
+                        yield {
+                            "id": "chat" + param_chunk["id"],
+                            "object": "chat.completion.chunk",
+                            "created": param_chunk["created"],
+                            "model": param_chunk["model"],
+                            "choices": [{
+                                "index": 0,
+                                "delta": {
+                                    "tool_calls": [{
+                                        "index": 0,
+                                        "function": {
+                                            "arguments": param_chunk["choices"][0]["text"]
+                                        }
+                                    }]
+                                },
+                                "finish_reason": None
+                            }]
+                        }
+                        accumulated_text += param_chunk["choices"][0]["text"]
+
                     # After tool call, continue normal streaming
-                    for remaining_chunk in llama.create_completion(
-                        prompt=new_prompt,
+                    for chunk in llama.create_completion(
+                        prompt=prompt + accumulated_text,
                         stream=True,
-                        **base_completion_kwargs
+                        **{k: v for k, v in base_completion_kwargs.items() if k != "stream"}
                     ):
-                        yield from _convert_text_completion_chunks_to_chat(iter([remaining_chunk]))
-
+                        yield from _convert_text_completion_chunks_to_chat(iter([chunk]))
+
             except Exception as e:
                 if llama.verbose:
                     print(f"[DEBUG] Failed to stream tool call: {e}", file=sys.stderr)
                 # Fall back to regular streaming without grammar
-                for fallback_chunk in llama.create_completion(
-                    prompt=prompt + accumulated_text,
+                for chunk in llama.create_completion(
+                    prompt=prompt + llama.tokenize(accumulated_text.encode("utf-8"), add_bos=False, special=True),
                     stream=True,
-                    **base_completion_kwargs
+                    **{k: v for k, v in base_completion_kwargs.items() if k != "stream"}
                 ):
-                    yield from _convert_text_completion_chunks_to_chat(iter([fallback_chunk]))
+                    yield from _convert_text_completion_chunks_to_chat(iter([chunk]))
         else:
             # Keep streaming normally until we find a tool call
             yield from _convert_text_completion_chunks_to_chat(iter([chunk]))
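
For reference, the tool_call_gbnf string built at the top of this hunk restricts the next completion to exactly one of the declared tool names. A minimal sketch of what it expands to, assuming a hypothetical two-tool list (get_weather and search_web are placeholders, not part of this commit):

    # Illustrative only: rebuilds the tool-name grammar the same way the hunk does,
    # using a made-up tools list.
    tools = [
        {"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {}}}},
        {"type": "function", "function": {"name": "search_web", "parameters": {"type": "object", "properties": {}}}},
    ]

    function_names = " | ".join([f'''"functions.{t["function"]["name"]}"''' for t in tools]) if tools else ""
    tool_call_gbnf = (
        'root ::= functions\n'
        f"functions ::= {function_names}\n"
    )

    print(tool_call_gbnf)
    # root ::= functions
    # functions ::= "functions.get_weather" | "functions.search_web"

Splitting generation into two grammar-constrained passes keeps each grammar small: the first pass (with stop=[":"]) only selects the function name, and the per-tool grammar from _grammar_for_tool_parameters then constrains only that function's arguments.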
