
Commit 2d607f1

Author: ochafik (committed)
Update test-chat-handler.cpp
1 parent: b565ab2

File tree: 1 file changed (+131 / -87 lines)


tests/test-chat-handler.cpp

Lines changed: 131 additions & 87 deletions
@@ -298,34 +298,6 @@ const json tools = {special_function_tool, python_tool};
 // json::array({special_function_call}));
 // }
 
-static void test_format_detection() {
-    common_chat_params no_tools_params;
-    no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
-
-    common_chat_params tools_params = no_tools_params;
-    tools_params.tools = json::array();
-
-    auto describe = [](const std::string & template_file, const common_chat_params & params) {
-        const common_chat_template tmpl(read_file(template_file), "<s>", "</s>");
-        auto data = common_chat_init(tmpl, params);
-        return data.format;
-    };
-
-    assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", tools_params));
-    assert_equals(std::string("functionary v3.2 tool calls"), describe("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", tools_params));
-    assert_equals(std::string("firefunction v2 tool calls"), describe("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja", tools_params));
-    assert_equals(std::string("llama 3.1 tool calls"), describe("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", tools_params));
-    assert_equals(std::string("llama 3.2 tool calls"), describe("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", tools_params));
-    assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", tools_params));
-    assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", tools_params));
-    assert_equals(std::string("hermes 2 pro tool calls"), describe("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", tools_params));
-    assert_equals(std::string("mistral nemo tool calls"), describe("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", tools_params));
-    assert_equals(std::string("deepseek r1 tool calls"), describe("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", tools_params));
-    assert_equals(std::string("generic tool calls"), describe("tests/chat/templates/google-gemma-7b-it.jinja", tools_params));
-    assert_equals(std::string("content-only"), describe("tests/chat/templates/google-gemma-7b-it.jinja", no_tools_params));
-    // assert_equals(std::string("command_r_plus tool calls"), describe("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja_, tools_params));
-}
-
 static std::string get_message_prompt_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
     fprintf(stderr, "Template source: %s\n", tmpl.source().c_str());
     fprintf(stderr, "Delta message: %s\n", delta_message.dump(2).c_str());
@@ -363,20 +335,23 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
     return delta;
 }
 
-static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false) {
+static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message, const json & tools = {}, bool skip_grammar_test = false) {
     // auto tool_call_style = common_tool_call_style_detect(tmpl);
     common_chat_msg expected_msg {
         "assistant",
         "",
         {},
     };
-    for (const auto & tc : tool_calling_message.at("tool_calls")) {
-        const auto & arguments = tc.at("function").at("arguments");
-        expected_msg.tool_calls.push_back({
-            tc.at("function").at("name").get<std::string>(),
-            arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-            tc.contains("id") ? tc.at("id").get<std::string>() : "",
-        });
+    auto has_tool_calls = test_message.contains("tool_calls");
+    if (has_tool_calls) {
+        for (const auto & tc : test_message.at("tool_calls")) {
+            const auto & arguments = tc.at("function").at("arguments");
+            expected_msg.tool_calls.push_back({
+                tc.at("function").at("name").get<std::string>(),
+                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
+                tc.contains("id") ? tc.at("id").get<std::string>() : "",
+            });
+        }
     }
 
     // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
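
The comment above describes the delta trick used by get_message_prompt_delta: render the prompt with only the user message and add_generation_prompt=true, render it again with the assistant message appended and add_generation_prompt=false, then keep just the suffix and trim the end-of-turn token. A minimal sketch of that suffix-and-trim step, using plain strings instead of the real common_chat_template class (strip_prompt_prefix is a hypothetical helper, not part of the file):

#include <stdexcept>
#include <string>
#include <vector>

static std::string strip_prompt_prefix(const std::string & prefix,
                                       const std::string & full,
                                       const std::vector<std::string> & end_tokens) {
    // The full prompt must begin with the user-only prefix; the remainder is
    // the assistant-turn delta that the grammar/parser tests operate on.
    if (full.compare(0, prefix.size(), prefix) != 0) {
        throw std::runtime_error("full prompt does not start with the user-only prefix");
    }
    std::string delta = full.substr(prefix.size());

    // Drop a trailing end-of-turn token (e.g. "<|im_end|>") so the delta
    // matches what the model itself would be asked to generate.
    for (const auto & tok : end_tokens) {
        if (delta.size() >= tok.size() &&
            delta.compare(delta.size() - tok.size(), tok.size(), tok) == 0) {
            delta.resize(delta.size() - tok.size());
            break;
        }
    }
    return delta;
}
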
@@ -386,36 +361,45 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
         {"content", "Hello, world!"}
     };
 
-    common_chat_params params;
-    params.parallel_tool_calls = true;
-    params.messages = json {user_message, tool_calling_message};
-    params.tools = tools;
-    auto chat_data = common_chat_init(tmpl, params);
-    fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get<std::string>().c_str());
-    auto grammar = build_grammar(chat_data.grammar);
-    if (!grammar) {
-        throw std::runtime_error("Failed to build grammar");
-    }
-
-    if (!skip_grammar_test) {
-        auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
-        std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
-
-        const auto msg = chat_data.parser->parse_final(full_delta);
-        assert_msg_equals(expected_msg, msg);
-
-        auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
-            {"role", "assistant"},
-            {"content", {}},
-            {"tool_calls", tool_calling_message.at("tool_calls")}
-        }, tools);
-        if (!match_string(content_less_delta, grammar.get())) {
-            throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
+    for (const auto & tool_choice : json({"auto", "required"})) {
+        common_chat_params params;
+        params.tool_choice = tool_choice;
+        params.parallel_tool_calls = true;
+        params.messages = json {user_message, test_message};
+        params.tools = tools;
+        auto chat_data = common_chat_init(tmpl, params);
+        // fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get<std::string>().c_str());
+        if (has_tool_calls) {
+            auto grammar = build_grammar(chat_data.grammar);
+            if (!grammar) {
+                throw std::runtime_error("Failed to build grammar");
+            }
+
+            if (!skip_grammar_test) {
+                auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, test_message, tools);
+                std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
+
+                const auto msg = chat_data.parser->parse_final(full_delta);
+                assert_msg_equals(expected_msg, msg);
+
+                auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
+                    {"role", "assistant"},
+                    {"content", {}},
+                    {"tool_calls", test_message.at("tool_calls")}
+                }, tools);
+                if (!match_string(content_less_delta, grammar.get())) {
+                    throw std::runtime_error("Failed to match content-less delta against grammar:\n\nContent-less delta: " + content_less_delta + "\n\nGrammar: " + chat_data.grammar);
+                }
+            }
         }
     }
 }
 
 static void test_grammars() {
+    auto text_message = json {
+        {"role", "assistant"},
+        {"content", "Hello, world!"},
+    };
     auto tool_call_message = json {
         {"role", "assistant"},
         {"content", {}},
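
One detail worth calling out in the new test_template() body above: json({"auto", "required"}) builds a two-element JSON array, so the loop runs the whole prompt/grammar round trip once per tool_choice mode. A self-contained sketch of just that iteration pattern, assuming the nlohmann::json library behind the json alias used throughout the file:

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // json({"auto", "required"}) is an array of two JSON strings; iterating it
    // yields each tool_choice value in turn, as in test_template().
    for (const auto & tool_choice : json({"auto", "required"})) {
        std::cout << "tool_choice = " << tool_choice.get<std::string>() << "\n";
    }
    return 0;
}
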
@@ -444,68 +428,128 @@ static void test_grammars() {
         }}}
     };
 
+
+    common_chat_params no_tools_params;
+    no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
+
+    common_chat_params tools_params = no_tools_params;
+    tools_params.tools = json::array();
+
+    auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
+        auto data = common_chat_init(tmpl, params);
+        return data.format;
+    };
+
+    {
+        const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+        std::vector<std::string> end_tokens { "<end_of_turn>" };
+
+        assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
+        assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
+    }
+    {
+        const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
+        std::vector<std::string> end_tokens { "<|end|>" };
+
+        assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message_with_id, tools);
+    }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "</s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
+        std::vector<std::string> end_tokens { "</s>" };
+
+        assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message_with_id, tools, /* skip_grammar_test= */ true);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
-        // assert_equals(tmpl.requires_object_arguments_, true);
-        test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
-        test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
+        std::vector<std::string> end_tokens { "<|im_end|>" };
+
+        assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
+        test_template(tmpl, end_tokens, python_tool_call_message, tools);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
+        std::vector<std::string> end_tokens { "<|im_end|>" };
+
+        assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
-    }
-    {
-        const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+        std::vector<std::string> end_tokens { "<|im_end|>" };
+
+        assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
+        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
+        test_template(tmpl, end_tokens, python_tool_call_message, tools);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
+        std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(std::string("functionary v3.2 tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|eot_id|>" }, tool_call_message, tools);
-    }
-    {
-        const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<end_of_turn>" }, tool_call_message_with_id, tools);
-    }
-    {
-        const common_chat_template tmpl(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|end|>" }, tool_call_message_with_id, tools);
+        std::vector<std::string> end_tokens { "<|eot_id|>" };
+
+        assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
     }
     {
         const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
-        test_template(tmpl, { "<|end▁of▁sentence|>" }, tool_call_message, tools);
+        std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
+
+        assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
+        test_template(tmpl, end_tokens, text_message, tools);
+        test_template(tmpl, end_tokens, tool_call_message, tools);
     }
 }
 
 int main() {
-    test_format_detection();
     // test_parsing();
     test_grammars();
 
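For readers without the rest of the file at hand, the assert_equals calls above compare the detected chat format string against an expected value and fail the test on mismatch. A plausible minimal re-implementation of such a helper (hypothetical; the test file ships its own version):

#include <sstream>
#include <stdexcept>
#include <string>

template <class T>
static void assert_equals(const T & expected, const T & actual) {
    if (expected != actual) {
        // Report both values so a failing template-format check is easy to read.
        std::ostringstream ss;
        ss << "Expected: " << expected << "\nActual: " << actual;
        throw std::runtime_error(ss.str());
    }
}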