
Commit 6fb8d6b ("checkpoint")
1 parent: b7d0864

1 file changed: llama_cpp/llama_chat_template.py (63 additions, 25 deletions)
@@ -422,15 +422,29 @@ def _handle_streaming_tool_calls(
             text = chunk["choices"][0]["text"]
             accumulated_text += text
             stop_reason = chunk["choices"][0]["finish_reason"]
-            if stop_reason == "stop:<tool_call>":
-                accumulated_text += "<tool_call>"
+
+            # Debug: Print accumulated text when we get a stop reason
+            if stop_reason and llama.verbose:
+                print(f"[DEBUG] Stop reason: {stop_reason}, Accumulated text: '{accumulated_text}'", file=sys.stderr)
+
+            # Check if we hit a tool call stop or if tool call is in the accumulated text
+            # Also handle case where complete tool call is already generated
+            if (stop_reason == "stop:<tool_call>" or
+                    (stop_reason and "<tool_call>" in accumulated_text and
+                     (stop_reason == "stop" or "</tool_call>" not in accumulated_text))):
+
+                if stop_reason == "stop:<tool_call>":
+                    accumulated_text += "<tool_call>"
+
                 # Found tool call, switch to grammar mode
-                print("[DEBUG TOOLS] Found tool call, switching to grammar mode", file=sys.stderr)
+                print(f"[DEBUG TOOLS] Found tool call, switching to grammar mode. Stop reason: {stop_reason}", file=sys.stderr)
+                print(f"[DEBUG TOOLS] Accumulated text: '{accumulated_text}'", file=sys.stderr)

                 # First generate the tool name with grammar
-                function_names = " | ".join([f'''"functions.{t["function"]["name"]}"''' for t in tools]) if tools else ""
+                # Since we already have <tool_call>, we just need the function name and colon
+                function_names = " | ".join([f'''"functions.{t["function"]["name"]}:"''' for t in tools]) if tools else ""
                 tool_call_gbnf = (
-                    'root ::= functions\n'  # We already have <tool_call>, just need the function name
+                    'root ::= "\\n" functions\n'  # We already have <tool_call>, add newline then function name
                     f"functions ::= {function_names}\n"
                 )

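Note on the grammar change in this hunk: the sketch below is not part of the commit; it just mirrors the new construction for a made-up one-function tools list, to show the GBNF string it produces (leading newline, then one quoted "functions.<name>:" alternative per tool).

# Illustrative only: same expressions as the new code, hypothetical tools list.
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]

function_names = " | ".join(
    [f'''"functions.{t["function"]["name"]}:"''' for t in tools]
) if tools else ""
tool_call_gbnf = (
    'root ::= "\\n" functions\n'
    f"functions ::= {function_names}\n"
)
print(tool_call_gbnf)
# root ::= "\n" functions
# functions ::= "functions.get_weather:"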
@@ -442,14 +456,17 @@ def _handle_streaming_tool_calls(
                 )

                 # Generate the tool name (non-streaming for simpler handling)
+                # Create a new prompt that includes the accumulated text
+                combined_prompt = prompt + llama.tokenize((accumulated_text + "\n").encode("utf-8"), add_bos=False, special=True)
                 name_completion = llama.create_completion(
-                    prompt=prompt + llama.tokenize((accumulated_text + "\n").encode("utf-8"), add_bos=False, special=True),
+                    prompt=combined_prompt,
                     grammar=name_grammar,
                     stream=False,
-                    stop=[":"],  # Stop at the colon after function name
+                    stop=[],  # Grammar will handle the format including colon
                     **{k: v for k, v in base_completion_kwargs.items() if k != "stream" and k != "grammar"}
                 )
                 name_text = name_completion["choices"][0]["text"]
+                tool_name = name_text.split(".")[-1].rstrip(":")
                 # Convert to chat completion chunk and yield
                 yield {
                     "id": "chat" + name_completion["id"],
@@ -463,59 +480,79 @@ def _handle_streaming_tool_calls(
                             "content": None,
                             "tool_calls": [{
                                 "index": 0,
-                                "id": "call_0_" + name_text.split(".")[-1] + "_" + name_completion["id"],
+                                "id": "call_0_" + tool_name + "_" + name_completion["id"],
                                 "type": "function",
                                 "function": {
-                                    "name": name_text.split(".")[-1],
+                                    "name": tool_name,
                                     "arguments": ""
                                 }
                             }]
                         },
                         "finish_reason": None
                     }]
                 }
-                accumulated_text += name_text + ":"  # Add the colon back since we stopped before it
-
-                # Get the selected tool from the accumulated text
-                tool_name = accumulated_text.split("\n")[-1].split("functions.")[-1].split(":")[0]
+                accumulated_text += name_text  # name_text already includes the colon from grammar
+
+                # Get the selected tool from the name_text (remove newline, functions. prefix, and colon)
                 tool = next((t for t in tools if t["function"]["name"] == tool_name), None)
-
+
+                print(f"[DEBUG] Generated name_text: '{name_text}'", file=sys.stderr)
+                print(f"[DEBUG] Extracted tool_name: '{tool_name}'", file=sys.stderr)
+                print(f"[DEBUG] Found tool: {tool is not None}", file=sys.stderr)
+
                 if tool:
                     # Get tool parameters grammar
                     tool_grammar = _grammar_for_tool_parameters(tool, verbose=llama.verbose)
-
+
+                    print(f"[DEBUG] Starting parameter generation for tool: {tool['function']['name']}", file=sys.stderr)
+                    print(f"[DEBUG] Accumulated text: '{accumulated_text}'", file=sys.stderr)
+
+                    # Create prompt for parameter generation (include function name and colon, then newline for JSON)
+                    param_prompt = prompt + llama.tokenize((accumulated_text + "\n").encode("utf-8"), add_bos=False, special=True)
+
+                    param_text = ""
                     # Stream the tool parameters
                     for param_chunk in llama.create_completion(
-                        prompt=prompt + llama.tokenize((accumulated_text + "\n").encode("utf-8"), add_bos=False, special=True),
+                        prompt=param_prompt,
                         grammar=tool_grammar,
                         stream=True,
-                        stop=["</tool_call>"],
+                        stop=["}"],
                         **{k: v for k, v in base_completion_kwargs.items() if k != "stream" and k != "grammar"}
                     ):
+                        param_text = param_chunk["choices"][0]["text"]
+                        if param_text:
+                            print(f"[DEBUG] Parameter chunk: '{param_text}'", file=sys.stderr)
+
                         # Convert to chat completion chunk and yield
                         yield {
-                            "id": "chat" + param_chunk["id"],
+                            "id": "chat" + name_completion["id"],
                             "object": "chat.completion.chunk",
-                            "created": param_chunk["created"],
-                            "model": param_chunk["model"],
+                            "created": name_completion["created"],
+                            "model": name_completion["model"],
                             "choices": [{
                                 "index": 0,
                                 "delta": {
+                                    "role": "assistant",
+                                    "content": None,
                                     "tool_calls": [{
                                         "index": 0,
+                                        "id": "call_0_" + tool_name + "_" + name_completion["id"],
+                                        "type": "function",
                                         "function": {
-                                            "arguments": param_chunk["choices"][0]["text"]
+                                            "name": tool_name,
+                                            "arguments": param_text
                                         }
                                     }]
                                 },
                                 "finish_reason": None
                             }]
                         }
-                        accumulated_text += param_chunk["choices"][0]["text"]
+                        accumulated_text += param_text

-                    # After tool call, continue normal streaming
+                    # After tool call, continue normal streaming from where we left off
+                    continue_prompt = prompt + llama.tokenize(accumulated_text.encode("utf-8"), add_bos=False, special=True)
                     for chunk in llama.create_completion(
-                        prompt=prompt + accumulated_text,
+                        prompt=continue_prompt,
                         stream=True,
                         **{k: v for k, v in base_completion_kwargs.items() if k != "stream"}
                     ):
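As a usage note, the sketch below is not part of the commit: it shows how a client could reassemble the tool call from the chat.completion.chunk deltas this code yields, concatenating the streamed arguments field. The chunk values are made up and most fields are omitted.

# Hypothetical chunks shaped like the ones yielded above.
chunks = [
    {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"name": "get_weather", "arguments": ""}}]}}]},
    {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"name": "get_weather", "arguments": '{"city": '}}]}}]},
    {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"name": "get_weather", "arguments": '"Paris"}'}}]}}]},
]

tool_name, arguments = None, ""
for chunk in chunks:
    for call in chunk["choices"][0]["delta"].get("tool_calls", []):
        tool_name = call["function"].get("name", tool_name)
        arguments += call["function"].get("arguments", "")

print(tool_name, arguments)  # get_weather {"city": "Paris"}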
@@ -525,8 +562,9 @@ def _handle_streaming_tool_calls(
         if llama.verbose:
             print(f"[DEBUG] Failed to stream tool call: {e}", file=sys.stderr)
         # Fall back to regular streaming without grammar
+        fallback_prompt = prompt + llama.tokenize(accumulated_text.encode("utf-8"), add_bos=False, special=True)
         for chunk in llama.create_completion(
-            prompt=prompt + llama.tokenize(accumulated_text.encode("utf-8"), add_bos=False, special=True),
+            prompt=fallback_prompt,
             stream=True,
             **{k: v for k, v in base_completion_kwargs.items() if k != "stream"}
         ):
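A final note on the prompt handling touched by the continuation and fallback changes: both prompts are rebuilt with llama.tokenize() because the prompt being extended is used as a list of token ids elsewhere in this function. The sketch below is a stand-in, not the commit's code or the library's API (fake_tokenize and the sample text are invented), and only illustrates the type mismatch the rewrite avoids.

# Stand-in for llama.tokenize(); real code gets token ids from the model's tokenizer.
def fake_tokenize(data: bytes) -> list:
    return list(data)

prompt_tokens = fake_tokenize(b"<chat prompt>")
accumulated_text = '<tool_call>\nfunctions.get_weather:{"city": "Paris"}'

# New behaviour: tokenize first, then concatenate two lists of token ids.
continue_prompt = prompt_tokens + fake_tokenize(accumulated_text.encode("utf-8"))

# The old `prompt + accumulated_text` mixed a token list with a str instead.
assert all(isinstance(t, int) for t in continue_prompt)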
