Skip to content

Commit 3db4da5

Browse files
authored
chat : support Granite model reasoning and tool call (#14864)
1 parent 476aa3f commit 3db4da5

File tree

6 files changed

+252
-1
lines changed

6 files changed

+252
-1
lines changed

common/chat-parser.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,15 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::
5555
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
5656
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
5757
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
58-
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
58+
std::string arguments = "";
59+
if (tool_call.contains("arguments")) {
60+
if (tool_call.at("arguments").is_object()) {
61+
arguments = tool_call.at("arguments").dump();
62+
} else {
63+
arguments = tool_call.at("arguments");
64+
}
65+
}
66+
5967
return add_tool_call(name, id, arguments);
6068
}
6169

common/chat.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
606606
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
607607
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
608608
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
609+
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
609610
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
610611
default:
611612
throw std::runtime_error("Unknown chat format");
@@ -618,6 +619,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
618619
case COMMON_REASONING_FORMAT_AUTO: return "auto";
619620
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
620621
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
622+
case COMMON_REASONING_FORMAT_GRANITE: return "granite";
621623
default:
622624
throw std::runtime_error("Unknown reasoning format");
623625
}
@@ -1734,6 +1736,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
17341736
builder.add_content(builder.consume_rest());
17351737
}
17361738

1739+
// Builds the chat parameters (prompt, format, grammar, grammar triggers, preserved tokens)
// for Granite chat templates.
//
// tmpl   - chat template used to render the prompt.
// inputs - template inputs; this function reads enable_thinking, tools and tool_choice.
//
// Returns the populated common_chat_params, with format set to COMMON_CHAT_FORMAT_GRANITE.
static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
1740+
common_chat_params data;
1741+
1742+
// Pass thinking context for Granite template
1743+
json additional_context = {
1744+
{"thinking", inputs.enable_thinking},
1745+
};
1746+
1747+
data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
1748+
data.format = COMMON_CHAT_FORMAT_GRANITE;
1749+
1750+
// If the rendered prompt already ends with an opening <think> tag, either close it
// immediately (thinking disabled) or record that generation starts inside an open thinking block.
if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
1751+
if (!inputs.enable_thinking) {
1752+
data.prompt += "</think>";
1753+
} else {
1754+
data.thinking_forced_open = true;
1755+
}
1756+
}
1757+
1758+
if (!inputs.tools.is_null()) {
1759+
// Granite uses <|tool_call|> followed by JSON list
1760+
// The grammar is marked lazy unless the caller requires a tool call.
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1761+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1762+
std::vector<std::string> tool_rules;
1763+
foreach_function(inputs.tools, [&](const json & tool) {
1764+
const auto & function = tool.at("function");
1765+
std::string name = function.at("name");
1766+
auto parameters = function.at("parameters");
1767+
builder.resolve_refs(parameters);
1768+
// One rule per tool: a JSON object whose "name" is fixed to this tool's name and whose
// "arguments" follow the tool's parameter schema; both keys are required.
tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name +
1769+
"-args", {
1770+
{"type", "object"},
1771+
{"properties", {
1772+
{"name", {{"const", name}}},
1773+
{"arguments", parameters},
1774+
}},
1775+
{"required", json::array({"name", "arguments"})},
1776+
})));
1777+
});
1778+
1779+
// tool_call: any one of the per-tool rules.
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
1780+
// tool_list: a bracketed, comma-separated list of one or more tool calls.
auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");
1781+
1782+
// With thinking forced open, the root first requires the closing </think> and a
// <response>...</response> section (containing no '<') before the tool-call marker.
if (data.thinking_forced_open) {
1783+
builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
1784+
} else {
1785+
builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
1786+
}
1787+
1788+
// Register the tool-call marker as a word trigger for the grammar.
data.grammar_triggers.push_back({
1789+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
1790+
"<|tool_call|>"
1791+
});
1792+
1793+
data.preserved_tokens = {
1794+
"<think>",
1795+
"</think>",
1796+
"<response>",
1797+
"</response>",
1798+
"<|tool_call|>",
1799+
};
1800+
});
1801+
} else {
1802+
// Handle thinking tags for non-tool responses
1803+
// NOTE(review): thinking_forced_open is only set above when enable_thinking is true, so the
// second operand looks redundant — confirm the member defaults to false.
if (data.thinking_forced_open && inputs.enable_thinking) {
1804+
// Non-lazy grammar: the output must close the thinking block and wrap the answer
// in <response>...</response>.
data.grammar_lazy = false;
1805+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1806+
builder.add_rule("root", "\"</think>\" space \"<response>\" space .* \"</response>\" space");
1807+
});
1808+
data.preserved_tokens = {
1809+
"<think>",
1810+
"</think>",
1811+
"<response>",
1812+
"</response>",
1813+
};
1814+
}
1815+
}
1816+
1817+
return data;
1818+
}
1819+
1820+
// Parses a Granite-format assistant message held by `builder`.
//
// Steps, in order:
//   1. An optional <think>...</think> section is handed to try_parse_reasoning().
//   2. If a <response>...</response> section is found, the text between the tags is added as content.
//   3. When tool-call parsing is disabled, whatever input remains is added as content and parsing stops.
//   4. Otherwise the "<|tool_call|>" marker is searched for. The JSON following it is registered as
//      tool calls when it is an array that add_tool_calls() accepts; if not, the marker plus the
//      re-serialized JSON is emitted as content. Without a marker, the remaining input becomes content.
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
    builder.try_parse_reasoning("<think>", "</think>");

    // Non-greedy match across newlines; capture group 1 is the text between the tags.
    static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
    auto response = builder.try_find_regex(response_regex);
    if (response) {
        builder.add_content(builder.str(response->groups[1]));
        builder.move_to(response->groups[0].end);
    }

    if (!builder.syntax().parse_tool_calls) {
        builder.add_content(builder.consume_rest());
        return;
    }

    static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
    auto marker = builder.try_find_regex(tool_call_regex);
    if (!marker) {
        // No tool-call marker: everything left is plain content.
        builder.add_content(builder.consume_rest());
        return;
    }
    builder.move_to(marker->groups[0].end);

    // A JSON array of tool calls is expected right after the marker.
    auto payload = builder.consume_json();
    const bool registered = payload.json.is_array() && builder.add_tool_calls(payload.json);
    if (!registered) {
        // Not an array, or rejected by add_tool_calls(): keep the text as ordinary content.
        builder.add_content("<|tool_call|>" + payload.json.dump());
    }
}
1856+
17371857
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
17381858
common_chat_params data;
17391859
data.prompt = apply(tmpl, inputs);
@@ -1805,6 +1925,11 @@ static common_chat_params common_chat_templates_apply_jinja(
18051925
return common_chat_params_init_command_r7b(tmpl, params);
18061926
}
18071927

1928+
// Granite (IBM) - detects thinking / tools support
1929+
if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
1930+
return common_chat_params_init_granite(tmpl, params);
1931+
}
1932+
18081933
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
18091934
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
18101935
return common_chat_params_init_hermes_2_pro(tmpl, params);
@@ -1865,6 +1990,7 @@ static common_chat_params common_chat_templates_apply_legacy(
18651990
int alloc_size = 0;
18661991
std::vector<llama_chat_message> chat;
18671992
std::vector<std::string> contents;
1993+
18681994
for (const auto & msg : inputs.messages) {
18691995
auto content = msg.content;
18701996
for (const auto & part : msg.content_parts) {
@@ -1966,6 +2092,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
19662092
case COMMON_CHAT_FORMAT_COMMAND_R7B:
19672093
common_chat_parse_command_r7b(builder);
19682094
break;
2095+
case COMMON_CHAT_FORMAT_GRANITE:
2096+
common_chat_parse_granite(builder);
2097+
break;
19692098
case COMMON_CHAT_FORMAT_GPT_OSS:
19702099
common_chat_parse_gpt_oss(builder);
19712100
break;

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ enum common_chat_format {
109109
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
110110
COMMON_CHAT_FORMAT_HERMES_2_PRO,
111111
COMMON_CHAT_FORMAT_COMMAND_R7B,
112+
COMMON_CHAT_FORMAT_GRANITE,
112113
COMMON_CHAT_FORMAT_GPT_OSS,
113114

114115
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ enum common_reasoning_format {
239239
COMMON_REASONING_FORMAT_AUTO,
240240
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
241241
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
242+
COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
242243
};
243244

244245
struct common_params {

models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
{# Alias tools -> available_tools #}
2+
{%- if tools and not available_tools -%}
3+
{%- set available_tools = tools -%}
4+
{%- endif -%}
5+
{%- if messages[0]['role'] == 'system' %}
6+
{%- set system_message = messages[0]['content'] %}
7+
{%- set loop_messages = messages[1:] %}
8+
{%- else %}
9+
{%- set system_message = "Knowledge Cutoff Date: April 2024. Today's Date: " + strftime_now('%B %d, %Y') + ". You are Granite, developed by IBM." %}
10+
{%- if available_tools and documents %}
11+
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request. Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
12+
{%- elif available_tools %}
13+
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
14+
{%- elif documents %}
15+
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
16+
{%- elif thinking %}
17+
{%- set system_message = system_message + " You are a helpful AI assistant.
18+
Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between <think></think> and write your response between <response></response> for each user query." %}
19+
{%- else %}
20+
{%- set system_message = system_message + " You are a helpful AI assistant." %}
21+
{%- endif %}
22+
{%- if 'citations' in controls and documents %}
23+
{%- set system_message = system_message + '
24+
Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
25+
{%- endif %}
26+
{%- if 'hallucinations' in controls and documents %}
27+
{%- set system_message = system_message + '
28+
Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %}
29+
{%- endif %}
30+
{%- set loop_messages = messages %}
31+
{%- endif %}
32+
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
33+
' }}
34+
{%- if available_tools %}
35+
{{- '<|start_of_role|>available_tools<|end_of_role|>' }}
36+
{{- available_tools | tojson(indent=4) }}
37+
{{- '<|end_of_text|>
38+
' }}
39+
{%- endif %}
40+
{%- if documents %}
41+
{%- for document in documents %}
42+
{{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|>
43+
' }}
44+
{{- document['text'] }}
45+
{{- '<|end_of_text|>
46+
' }}
47+
{%- endfor %}
48+
{%- endif %}
49+
{%- for message in loop_messages %}
50+
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
51+
' }}
52+
{%- if loop.last and add_generation_prompt %}
53+
{{- '<|start_of_role|>assistant' }}
54+
{%- if controls %}
55+
{{- ' ' + controls | tojson()}}
56+
{%- endif %}
57+
{{- '<|end_of_role|>' }}
58+
{%- endif %}
59+
{%- endfor %}

tests/test-chat.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,59 @@ static void test_template_output_parsers() {
13861386
"{\"arg1\": 1}\n"
13871387
"```<|tool▁call▁end|><|tool▁calls▁end|>");
13881388
}
1389+
{
1390+
auto tmpls = read_templates("models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja");
1391+
std::vector<std::string> end_tokens{ "<|end_of_text|>" };
1392+
1393+
assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
1394+
1395+
assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
1396+
1397+
// Test parsing regular content
1398+
assert_msg_equals(message_assist,
1399+
common_chat_parse(
1400+
"Hello, world!\nWhat's up?",
1401+
/* is_partial= */ false,
1402+
{COMMON_CHAT_FORMAT_GRANITE}));
1403+
1404+
// Test parsing content with thinking
1405+
assert_msg_equals(message_assist_thoughts,
1406+
common_chat_parse(
1407+
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
1408+
/* is_partial= */ false,
1409+
{
1410+
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
1411+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_GRANITE,
1412+
}));
1413+
1414+
// Test parsing tool calls
1415+
assert_msg_equals(message_assist_call,
1416+
common_chat_parse(
1417+
"<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]",
1418+
/* is_partial= */ false,
1419+
{COMMON_CHAT_FORMAT_GRANITE}));
1420+
1421+
// Test template generation for regular content
1422+
test_templates(tmpls.get(), end_tokens, message_assist, tools,
1423+
"Hello, world!\nWhat's up?",
1424+
/* expect_grammar_triggered= */ false);
1425+
1426+
// Test template generation for tool calls
1427+
test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
1428+
"{\n"
1429+
" \"tool_calls\": [\n"
1430+
" {\n"
1431+
" \"name\": \"special_function\",\n"
1432+
" \"arguments\": {\n"
1433+
" \"arg1\": 1\n"
1434+
" },\n"
1435+
" \"id\": \"123456789\"\n"
1436+
" }\n"
1437+
" ]\n"
1438+
"}",
1439+
/* expect_grammar_triggered= */ false
1440+
);
1441+
}
13891442
}
13901443

13911444
static void test_msg_diffs_compute() {

0 commit comments

Comments
 (0)