Commit e98dd36

checkpoint

1 parent 6fb8d6b

File tree

1 file changed: +47 −92 lines changed

llama_cpp/llama_chat_template.py

Lines changed: 47 additions & 92 deletions
@@ -263,9 +263,6 @@ def _detect_tool_call_and_get_tool(
         if not response_text.rstrip().endswith("<tool_call>"):
             return None, None
 
-        if verbose:
-            print("[DEBUG] Found tool call tag, switching to tool grammar", file=sys.stderr)
-
         try:
             name_grammar = _grammar_for_tool_name(tools, verbose=verbose)
 
@@ -283,14 +280,9 @@ def _detect_tool_call_and_get_tool(
         tool_name = tool_name_completion["choices"][0]["text"].split("\n")[-1][len("functions."):-1]
         tool = next((t for t in tools if t["function"]["name"] == tool_name), None)
 
-        if tool is None and verbose:
-            print(f"[DEBUG] Tool {tool_name} not found", file=sys.stderr)
-
         return tool, tool_name_completion["choices"][0]["text"]
 
     except Exception as e:
-        if verbose:
-            print(f"[DEBUG] Failed to create tool call grammar: {e}", file=sys.stderr)
         return None, None
 
 
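Note on the slicing above: it assumes the name grammar forces output of the form `functions.<name>:`, so the prefix and trailing colon can be stripped positionally. A quick standalone sketch (the completion text is fabricated for illustration):

# Standalone sketch: how the grammar-constrained "functions.<name>:" output is sliced.
completion_text = "functions.get_weather:"
tool_name = completion_text.split("\n")[-1][len("functions."):-1]
print(tool_name)  # -> get_weather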
@@ -306,19 +298,13 @@ def _parse_auto_tool_choice_response(
     """
     response_text = completion_result["choices"][0]["text"]
 
-    if verbose:
-        print(f"[DEBUG] Auto tool choice triggered. Response text: {response_text[:200]}...", file=sys.stderr)
-        print("[DEBUG] Looking for tool_call tags...", file=sys.stderr)
-
     # Parse the response similar to how template handles <think></think> and <tool_call></tool_call>
     message_content = response_text
     tool_call_json = None
 
     # Look for <tool_call> tags in the response
     tool_call_start_idx = response_text.find("<tool_call>")
     if tool_call_start_idx >= 0 and "</tool_call>" in response_text:
-        if verbose:
-            print("[DEBUG] Found tool_call tags, attempting to parse", file=sys.stderr)
         try:
             # Extract content between <tool_call> tags (like template does)
             tool_call_start = tool_call_start_idx + len("<tool_call>")
@@ -335,9 +321,6 @@ def _parse_auto_tool_choice_response(
         if tool_call_start >= 0 and tool_call_end > tool_call_start:
             # Switch to grammar strict mode for tool call content
             tool_call_content = response_text[tool_call_start:tool_call_end].strip()
-            if verbose:
-                print(f"[DEBUG] Extracted tool_call content: {tool_call_content}", file=sys.stderr)
-                print("[DEBUG] Switching to grammar strict mode for tool call", file=sys.stderr)
 
             parsed_json = json.loads(tool_call_content)
 
@@ -346,30 +329,17 @@ def _parse_auto_tool_choice_response(
                 "name" in parsed_json and
                 "input" in parsed_json):
 
-                if verbose:
-                    print(f"[DEBUG] Valid tool call found: {parsed_json.get('name')}", file=sys.stderr)
                 tool_call_json = parsed_json
                 # Extract message content before the <tool_call> tag
                 message_content = response_text[:response_text.find("<tool_call>")].strip()
 
         except json.JSONDecodeError as e:
             # Not valid JSON, treat as pure message
-            if verbose:
-                print(f"[DEBUG] Tool call JSON parsing failed: {e}. JSON text: {tool_call_content[:200]}...", file=sys.stderr)
             pass
     else:
-        if verbose:
-            has_start = "<tool_call>" in response_text
-            has_end = "</tool_call>" in response_text
-            print(f"[DEBUG] Tool call tags not found. Has start tag: {has_start}, Has end tag: {has_end}", file=sys.stderr)
-
-    if verbose:
-        print(f"[DEBUG] Final tool_call_json: {tool_call_json is not None}", file=sys.stderr)
-
+        pass
     # If we found a valid tool call, build the response
     if tool_call_json:
-        if verbose:
-            print(f"[DEBUG] Building tool call response for {tool_call_json['name']}", file=sys.stderr)
         tool_name = tool_call_json["name"]
         tool_input = tool_call_json["input"]
         tool_id = "call_0_" + tool_name + "_" + completion_result["id"]
@@ -405,40 +375,42 @@ def _parse_auto_tool_choice_response(
 
 
 def _handle_streaming_tool_calls(
-    completion_or_chunks: Iterator[llama_types.CreateCompletionStreamResponse],
     tools: List[llama_types.ChatCompletionTool],
     prompt: str,
     llama: llama.Llama,
     base_completion_kwargs: Dict[str, Any],
+    stopping_criteria: Optional[llama.StoppingCriteriaList] = None,
+    grammar: Optional[llama.LlamaGrammar] = None,
+    tool_call_index: int = 0,
 ) -> Iterator[llama_types.ChatCompletionChunk]:
     """Handle streaming completions with tool call detection and grammar switching.
 
+    Args:
+        tool_call_index: Index for this tool call (for multiple tool calls)
+
     Yields:
         Chat completion chunks as they become available
     """
-    accumulated_text = ""
+    # Generate text until we hit a tool call or end
+    completion_chunks = llama.create_completion(
+        prompt=prompt,
+        stream=True,
+        stop=["<tool_call>"],  # Stop at tool call if we find one
+        stopping_criteria=stopping_criteria,
+        grammar=grammar,
+        **{k: v for k, v in base_completion_kwargs.items() if k != "stream"}
+    )
 
-    for chunk in cast(Iterator[llama_types.CreateCompletionStreamResponse], completion_or_chunks):
+    accumulated_text = ""
+    for chunk in completion_chunks:
         text = chunk["choices"][0]["text"]
         accumulated_text += text
         stop_reason = chunk["choices"][0]["finish_reason"]
 
-        # Debug: Print accumulated text when we get a stop reason
-        if stop_reason and llama.verbose:
-            print(f"[DEBUG] Stop reason: {stop_reason}, Accumulated text: '{accumulated_text}'", file=sys.stderr)
-
-        # Check if we hit a tool call stop or if tool call is in the accumulated text
-        # Also handle case where complete tool call is already generated
-        if (stop_reason == "stop:<tool_call>" or
-            (stop_reason and "<tool_call>" in accumulated_text and
-             (stop_reason == "stop" or "</tool_call>" not in accumulated_text))):
+        # Check if we hit a tool call
+        if stop_reason == "stop:<tool_call>":
 
-            if stop_reason == "stop:<tool_call>":
-                accumulated_text += "<tool_call>"
-
-            # Found tool call, switch to grammar mode
-            print(f"[DEBUG TOOLS] Found tool call, switching to grammar mode. Stop reason: {stop_reason}", file=sys.stderr)
-            print(f"[DEBUG TOOLS] Accumulated text: '{accumulated_text}'", file=sys.stderr)
+            accumulated_text += "<tool_call>"
 
             # First generate the tool name with grammar
             # Since we already have <tool_call>, we just need the function name and colon
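Detection now relies solely on the `stop:<tool_call>` finish reason rather than scanning the accumulated text. A toy sketch of that loop, with fabricated chunks standing in for `llama.create_completion(stream=True, ...)` output (chunk shape mirrors CreateCompletionStreamResponse; the finish-reason convention is the one this diff assumes):

# Toy sketch of the stop-string detection loop.
from typing import Any, Dict, Iterator

def fake_chunks() -> Iterator[Dict[str, Any]]:
    yield {"choices": [{"text": "Let me check the weather. ", "finish_reason": None}]}
    # "stop:<tool_call>" names the stop string that ended the generation
    yield {"choices": [{"text": "", "finish_reason": "stop:<tool_call>"}]}

accumulated_text = ""
for chunk in fake_chunks():
    accumulated_text += chunk["choices"][0]["text"]
    if chunk["choices"][0]["finish_reason"] == "stop:<tool_call>":
        # The stop string itself is consumed by the stop mechanism, so
        # re-append it before switching to grammar-constrained generation.
        accumulated_text += "<tool_call>"
print(accumulated_text)  # -> Let me check the weather. <tool_call>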
@@ -479,8 +451,8 @@ def _handle_streaming_tool_calls(
                         "role": "assistant",
                         "content": None,
                         "tool_calls": [{
-                            "index": 0,
-                            "id": "call_0_" + tool_name + "_" + name_completion["id"],
+                            "index": tool_call_index,
+                            "id": f"call_{tool_call_index}_{tool_name}_{name_completion['id']}",
                             "type": "function",
                             "function": {
                                 "name": tool_name,
@@ -495,18 +467,11 @@ def _handle_streaming_tool_calls(
 
             # Get the selected tool from the name_text (remove newline, functions. prefix, and colon)
             tool = next((t for t in tools if t["function"]["name"] == tool_name), None)
-
-            print(f"[DEBUG] Generated name_text: '{name_text}'", file=sys.stderr)
-            print(f"[DEBUG] Extracted tool_name: '{tool_name}'", file=sys.stderr)
-            print(f"[DEBUG] Found tool: {tool is not None}", file=sys.stderr)
-
+
             if tool:
                 # Get tool parameters grammar
                 tool_grammar = _grammar_for_tool_parameters(tool, verbose=llama.verbose)
-
-                print(f"[DEBUG] Starting parameter generation for tool: {tool['function']['name']}", file=sys.stderr)
-                print(f"[DEBUG] Accumulated text: '{accumulated_text}'", file=sys.stderr)
-
+
                 # Create prompt for parameter generation (include function name and colon, then newline for JSON)
                 param_prompt = prompt + llama.tokenize((accumulated_text + "\n").encode("utf-8"), add_bos=False, special=True)
 
@@ -516,13 +481,10 @@ def _handle_streaming_tool_calls(
                     prompt=param_prompt,
                     grammar=tool_grammar,
                     stream=True,
-                    stop=["}"],
+                    # stop=["}"],
                     **{k: v for k, v in base_completion_kwargs.items() if k != "stream" and k != "grammar"}
                 ):
                     param_text = param_chunk["choices"][0]["text"]
-                    if param_text:
-                        print(f"[DEBUG] Parameter chunk: '{param_text}'", file=sys.stderr)
-
                     # Convert to chat completion chunk and yield
                     yield {
                         "id": "chat" + name_completion["id"],
@@ -535,8 +497,8 @@ def _handle_streaming_tool_calls(
                                 "role": "assistant",
                                 "content": None,
                                 "tool_calls": [{
-                                    "index": 0,
-                                    "id": "call_0_" + tool_name + "_" + name_completion["id"],
+                                    "index": tool_call_index,
+                                    "id": f"call_{tool_call_index}_{tool_name}_{name_completion['id']}",
                                     "type": "function",
                                     "function": {
                                         "name": tool_name,
@@ -548,19 +510,19 @@ def _handle_streaming_tool_calls(
                         }]
                     }
                     accumulated_text += param_text
-
-                # After tool call, continue normal streaming from where we left off
-                continue_prompt = prompt + llama.tokenize(accumulated_text.encode("utf-8"), add_bos=False, special=True)
-                for chunk in llama.create_completion(
-                    prompt=continue_prompt,
-                    stream=True,
-                    **{k: v for k, v in base_completion_kwargs.items() if k != "stream"}
-                ):
-                    yield from _convert_text_completion_chunks_to_chat(iter([chunk]))
-
+
+                # After completing the tool call parameters, continue with more completions
+                # Recursively handle the next completion by starting a new generation
+                yield from _handle_streaming_tool_calls(
+                    tools,
+                    prompt + llama.tokenize((accumulated_text + "\n</tool_call>\n").encode("utf-8"), add_bos=False, special=True),
+                    llama,
+                    base_completion_kwargs,
+                    stopping_criteria=stopping_criteria,
+                    grammar=grammar,
+                    tool_call_index=tool_call_index + 1  # Increment index for potential next tool call
+                )
         except Exception as e:
-            if llama.verbose:
-                print(f"[DEBUG] Failed to stream tool call: {e}", file=sys.stderr)
             # Fall back to regular streaming without grammar
             fallback_prompt = prompt + llama.tokenize(accumulated_text.encode("utf-8"), add_bos=False, special=True)
             for chunk in llama.create_completion(
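The recursion replaces the old flat continuation: each invocation streams until a `<tool_call>` boundary, emits the grammar-constrained call, then re-enters itself with the prompt extended past `</tool_call>` and `tool_call_index` bumped, so an arbitrary number of sequential tool calls fall out of one code path. A minimal structural sketch of the pattern (segments are invented stand-ins for model output; no model involved):

# Minimal sketch of the recursive streaming pattern: emit one segment,
# then recurse for whatever comes after it.
from typing import Iterator, List, Tuple

def stream_with_tool_calls(segments: List[Tuple[str, bool]],
                           index: int = 0) -> Iterator[str]:
    if index >= len(segments):
        return
    text, is_tool_call = segments[index]
    if is_tool_call:
        yield f"tool_call[{index}]: {text}"  # grammar-constrained in the real code
    else:
        yield text
    yield from stream_with_tool_calls(segments, index + 1)  # recurse

script = [("Checking the weather. ", False),
          ('{"name": "get_weather", "input": {"city": "Paris"}}', True)]
print(list(stream_with_tool_calls(script)))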
@@ -726,19 +688,14 @@ def _handle_streaming_completion(
 
     # Handle auto tool choice for streaming
    if tool_choice == "auto" and tools is not None and len(tools) > 0:
-        # First generate normally until we hit a tool call tag
-        completion_chunks = llama.create_completion(
-            prompt=prompt,
-            stream=True,
-            stop="<tool_call>",
-            stopping_criteria=stopping_criteria,
-            grammar=grammar,
-            **base_completion_kwargs,
-        )
-
-        # Use helper function to handle streaming with tool call detection
+        # Start the recursive completion process that handles both text and tool calls
         yield from _handle_streaming_tool_calls(
-            completion_chunks, tools, prompt, llama, base_completion_kwargs
+            tools,
+            prompt,
+            llama,
+            base_completion_kwargs,
+            stopping_criteria=stopping_criteria,
+            grammar=grammar
         )
         return
 
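From the caller's side none of this changes the public surface: chunks still arrive through the usual streaming chat API. A hedged usage sketch, assuming the upstream llama-cpp-python entry point applies to this fork (model path and tool schema are hypothetical):

# Hedged usage sketch: consuming streamed tool-call chunks.
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # hypothetical path
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
for chunk in llm.create_chat_completion(
    messages=[{"role": "user", "content": "Weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
):
    delta = chunk["choices"][0]["delta"]
    if delta.get("tool_calls"):
        print(delta["tool_calls"])        # incremental tool-call deltas
    elif delta.get("content"):
        print(delta["content"], end="")   # plain assistant text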
@@ -1494,8 +1451,6 @@ def _grammar_for_tool_parameters(
             verbose=verbose
         )
     except Exception as e:
-        if verbose:
-            print(f"[DEBUG] Failed to parse function parameters as JSON schema: {e}", file=sys.stderr)
         return llama_grammar.LlamaGrammar.from_string(
             llama_grammar.JSON_GBNF,
             verbose=verbose
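This fallback keeps generation constrained even when a tool's parameter schema cannot be converted: any valid JSON is still forced. A sketch of the two-tier pattern using grammar helpers that exist in upstream llama-cpp-python (`LlamaGrammar.from_json_schema` and the generic `JSON_GBNF`):

# Sketch of the schema-grammar-with-generic-JSON-fallback pattern,
# assuming upstream llama_cpp.llama_grammar helpers.
import json
from llama_cpp import llama_grammar

def grammar_for_schema(schema: dict, verbose: bool = False):
    try:
        # Constrain output to the exact parameter schema when possible
        return llama_grammar.LlamaGrammar.from_json_schema(
            json.dumps(schema), verbose=verbose
        )
    except Exception:
        # Otherwise fall back to "any valid JSON"
        return llama_grammar.LlamaGrammar.from_string(
            llama_grammar.JSON_GBNF, verbose=verbose
        )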