diff --git a/common/chat.cpp b/common/chat.cpp index 8587140e1ff0a..f95779963827c 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -2324,6 +2324,16 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context); data.format = COMMON_CHAT_FORMAT_GRANITE; + const auto & src = tmpl.source(); + const bool has_pipe_tool_call = src.find("<|tool_call|>") != std::string::npos; + const bool has_plain_tool_call = src.find("") != std::string::npos; + const bool has_plain_tool_call_close = src.find("") != std::string::npos; + const bool use_plain_tool_call = !has_pipe_tool_call && has_plain_tool_call; + + const std::string tool_call_tag = use_plain_tool_call ? "" : "<|tool_call|>"; + const std::string tool_call_close_tag = + use_plain_tool_call && has_plain_tool_call_close ? "" : ""; + if (string_ends_with(data.prompt, "\n") || string_ends_with(data.prompt, "")) { if (!inputs.enable_thinking) { data.prompt += ""; @@ -2333,9 +2343,23 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp } if (!inputs.tools.is_null()) { - // Granite uses <|tool_call|> followed by JSON list + // Granite uses a sentinel tag followed by a JSON list of tool calls data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { + const auto quote_literal = [](const std::string & literal) { + std::string escaped; + escaped.reserve(literal.size() * 2 + 2); + escaped.push_back('"'); + for (const char ch : literal) { + if (ch == '"' || ch == '\\') { + escaped.push_back('\\'); + } + escaped.push_back(ch); + } + escaped.push_back('"'); + return escaped; + }; + std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); @@ -2356,15 +2380,25 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")); auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\""); + const auto tool_call_literal = quote_literal(tool_call_tag); + const auto tool_call_close_literal = + tool_call_close_tag.empty() ? std::string{} : quote_literal(tool_call_close_tag); + const auto optional_close_segment = tool_call_close_literal.empty() + ? std::string{} + : " (space " + tool_call_close_literal + ")?"; + if (data.thinking_forced_open) { - builder.add_rule("root", "\"\" space \"\" space [^<]* \"\" space \"<|tool_call|>\" space " + tool_list); + builder.add_rule( + "root", + "\"\" space \"\" space [^<]* \"\" space " + + tool_call_literal + " space " + tool_list + optional_close_segment); } else { - builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list); + builder.add_rule("root", tool_call_literal + " space " + tool_list + optional_close_segment); } data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - "<|tool_call|>" + tool_call_tag }); data.preserved_tokens = { @@ -2372,8 +2406,11 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp "", "", "", - "<|tool_call|>", + tool_call_tag, }; + if (!tool_call_close_tag.empty()) { + data.preserved_tokens.push_back(tool_call_close_tag); + } }); } else { // Handle thinking tags for non-tool responses @@ -2426,17 +2463,30 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) { } // Look for tool calls - static const common_regex tool_call_regex(regex_escape("<|tool_call|>")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); + static const common_regex tool_call_regex_legacy(regex_escape("<|tool_call|>")); + static const common_regex tool_call_regex_plain(regex_escape("")); - // Expect JSON array of tool calls - if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { - if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); + const auto try_parse_tool_calls = [&](const common_regex & regex, const std::string & close_tag) { + if (auto res = builder.try_find_regex(regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { + if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + if (!close_tag.empty()) { + builder.try_consume_literal(close_tag); + } } + return true; } - } else { + return false; + }; + + if (!try_parse_tool_calls(tool_call_regex_legacy, "") && + !try_parse_tool_calls(tool_call_regex_plain, "")) { builder.add_content(builder.consume_rest()); } } @@ -2719,7 +2769,8 @@ static common_chat_params common_chat_templates_apply_jinja( } // Granite (IBM) - detects thinking / tools support - if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) { + if ((src.find("elif thinking") != std::string::npos || src.find("tools_system_message_prefix") != std::string::npos) && + (src.find("<|tool_call|>") != std::string::npos || src.find("") != std::string::npos)) { return common_chat_params_init_granite(tmpl, params); }