Skip to content

Commit 1ed1980

Browse files
committed
Add test case for granite
1 parent 13a8ecc commit 1ed1980

File tree

4 files changed

+68
-3
lines changed

4 files changed

+68
-3
lines changed

common/chat-parser.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,15 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::
5555
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
5656
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
5757
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
58-
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
58+
std::string arguments = "";
59+
if (tool_call.contains("arguments")) {
60+
if (tool_call.at("arguments").is_object()) {
61+
arguments = tool_call.at("arguments").dump();
62+
} else {
63+
arguments = tool_call.at("arguments");
64+
}
65+
}
66+
5967
return add_tool_call(name, id, arguments);
6068
}
6169

common/chat.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,8 @@ const char * common_chat_format_name(common_chat_format format) {
592592
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
593593
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
594594
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
595+
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
596+
595597
default:
596598
throw std::runtime_error("Unknown chat format");
597599
}
@@ -602,6 +604,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
602604
case COMMON_REASONING_FORMAT_NONE: return "none";
603605
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
604606
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
607+
case COMMON_REASONING_FORMAT_GRANITE: return "granite";
605608
default:
606609
throw std::runtime_error("Unknown reasoning format");
607610
}
@@ -1709,6 +1712,7 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
17091712
};
17101713

17111714
data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
1715+
data.format = COMMON_CHAT_FORMAT_GRANITE;
17121716

17131717
if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
17141718
if (!inputs.enable_thinking) {
@@ -1719,7 +1723,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
17191723
}
17201724

17211725
if (!inputs.tools.is_null()) {
1722-
data.format = COMMON_CHAT_FORMAT_GRANITE;
17231726
// Granite uses <|tool_call|> followed by JSON list
17241727
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
17251728
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -1763,7 +1766,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
17631766
};
17641767
});
17651768
} else {
1766-
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
17671769
// Handle thinking tags for non-tool responses
17681770
if (data.thinking_forced_open && inputs.enable_thinking) {
17691771
data.grammar_lazy = false;
@@ -1948,6 +1950,7 @@ static common_chat_params common_chat_templates_apply_legacy(
19481950
int alloc_size = 0;
19491951
std::vector<llama_chat_message> chat;
19501952
std::vector<std::string> contents;
1953+
19511954
for (const auto & msg : inputs.messages) {
19521955
auto content = msg.content;
19531956
for (const auto & part : msg.content_parts) {

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ enum common_reasoning_format {
231231
COMMON_REASONING_FORMAT_NONE,
232232
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
233233
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
234+
COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
234235
};
235236

236237
struct common_params {

tests/test-chat.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,59 @@ static void test_template_output_parsers() {
13431343
"{\"arg1\": 1}\n"
13441344
"```<|tool▁call▁end|><|tool▁calls▁end|>");
13451345
}
1346+
{
1347+
auto tmpls = read_templates("models/templates/ibm-granite-granite-2.2-2B-Instruct.jinja");
1348+
std::vector<std::string> end_tokens{ "<|end_of_text|>" };
1349+
1350+
assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
1351+
1352+
assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
1353+
1354+
// Test parsing regular content
1355+
assert_msg_equals(message_assist,
1356+
common_chat_parse(
1357+
"Hello, world!\nWhat's up?",
1358+
/* is_partial= */ false,
1359+
{COMMON_CHAT_FORMAT_GRANITE}));
1360+
1361+
// Test parsing content with thinking
1362+
assert_msg_equals(message_assist_thoughts,
1363+
common_chat_parse(
1364+
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
1365+
/* is_partial= */ false,
1366+
{
1367+
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
1368+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_GRANITE,
1369+
}));
1370+
1371+
// Test parsing tool calls
1372+
assert_msg_equals(message_assist_call,
1373+
common_chat_parse(
1374+
"<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]",
1375+
/* is_partial= */ false,
1376+
{COMMON_CHAT_FORMAT_GRANITE}));
1377+
1378+
// Test template generation for regular content
1379+
test_templates(tmpls.get(), end_tokens, message_assist, tools,
1380+
"Hello, world!\nWhat's up?",
1381+
/* expect_grammar_triggered= */ false);
1382+
1383+
// Test template generation for tool calls
1384+
test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
1385+
"{\n"
1386+
" \"tool_calls\": [\n"
1387+
" {\n"
1388+
" \"name\": \"special_function\",\n"
1389+
" \"arguments\": {\n"
1390+
" \"arg1\": 1\n"
1391+
" },\n"
1392+
" \"id\": \"123456789\"\n"
1393+
" }\n"
1394+
" ]\n"
1395+
"}",
1396+
/* expect_grammar_triggered= */ false
1397+
);
1398+
}
13461399
}
13471400

13481401
static void test_msg_diffs_compute() {

0 commit comments

Comments
 (0)