Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 65 additions & 14 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2324,6 +2324,16 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
data.format = COMMON_CHAT_FORMAT_GRANITE;

const auto & src = tmpl.source();
const bool has_pipe_tool_call = src.find("<|tool_call|>") != std::string::npos;
const bool has_plain_tool_call = src.find("<tool_call>") != std::string::npos;
const bool has_plain_tool_call_close = src.find("</tool_call>") != std::string::npos;
const bool use_plain_tool_call = !has_pipe_tool_call && has_plain_tool_call;

const std::string tool_call_tag = use_plain_tool_call ? "<tool_call>" : "<|tool_call|>";
const std::string tool_call_close_tag =
use_plain_tool_call && has_plain_tool_call_close ? "</tool_call>" : "";

if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
Expand All @@ -2333,9 +2343,23 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
}

if (!inputs.tools.is_null()) {
// Granite uses <|tool_call|> followed by JSON list
// Granite uses a sentinel tag followed by a JSON list of tool calls
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
const auto quote_literal = [](const std::string & literal) {
std::string escaped;
escaped.reserve(literal.size() * 2 + 2);
escaped.push_back('"');
for (const char ch : literal) {
if (ch == '"' || ch == '\\') {
escaped.push_back('\\');
}
escaped.push_back(ch);
}
escaped.push_back('"');
return escaped;
};

std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
Expand All @@ -2356,24 +2380,37 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");

const auto tool_call_literal = quote_literal(tool_call_tag);
const auto tool_call_close_literal =
tool_call_close_tag.empty() ? std::string{} : quote_literal(tool_call_close_tag);
const auto optional_close_segment = tool_call_close_literal.empty()
? std::string{}
: " (space " + tool_call_close_literal + ")?";

if (data.thinking_forced_open) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
builder.add_rule(
"root",
"\"</think>\" space \"<response>\" space [^<]* \"</response>\" space " +
tool_call_literal + " space " + tool_list + optional_close_segment);
} else {
builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
builder.add_rule("root", tool_call_literal + " space " + tool_list + optional_close_segment);
}

data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
"<|tool_call|>"
tool_call_tag
});

data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
"<|tool_call|>",
tool_call_tag,
};
if (!tool_call_close_tag.empty()) {
data.preserved_tokens.push_back(tool_call_close_tag);
}
});
} else {
// Handle thinking tags for non-tool responses
Expand Down Expand Up @@ -2426,17 +2463,30 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
}

// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
static const common_regex tool_call_regex_legacy(regex_escape("<|tool_call|>"));
static const common_regex tool_call_regex_plain(regex_escape("<tool_call>"));

// Expect JSON array of tool calls
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
const auto try_parse_tool_calls = [&](const common_regex & regex, const std::string & close_tag) {
if (auto res = builder.try_find_regex(regex)) {
builder.move_to(res->groups[0].end);

// Expect JSON array of tool calls
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
if (!close_tag.empty()) {
builder.try_consume_literal(close_tag);
}
}
return true;
}
} else {
return false;
};

if (!try_parse_tool_calls(tool_call_regex_legacy, "") &&
!try_parse_tool_calls(tool_call_regex_plain, "</tool_call>")) {
builder.add_content(builder.consume_rest());
}
}
Expand Down Expand Up @@ -2719,7 +2769,8 @@ static common_chat_params common_chat_templates_apply_jinja(
}

// Granite (IBM) - detects thinking / tools support
if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
if ((src.find("elif thinking") != std::string::npos || src.find("tools_system_message_prefix") != std::string::npos) &&
(src.find("<|tool_call|>") != std::string::npos || src.find("<tool_call>") != std::string::npos)) {
return common_chat_params_init_granite(tmpl, params);
}

Expand Down
Loading