/*
    Tests chat handling, including grammar generation and parsing for tool calling, for various templates.

    Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
    e.g. given Minja (https://github.com/google/minja) checked out in parent dir:

        cmake -B build && cmake --build build --parallel && \
            ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null

*/
111#include " chat.hpp"
212#include " chat-template.hpp"
313#include " llama-grammar.h"
@@ -44,7 +54,7 @@ static void assert_equals(const T & expected, const T & actual) {
4454}
4555
4656static std::string read_file (const std::string &path) {
47- std::cout << " # Reading: " << path << std::endl << std::flush;
57+ std::cerr << " # Reading: " << path << std::endl << std::flush;
4858 std::ifstream fs (path, std::ios_base::binary);
4959 if (!fs.is_open ()) {
5060 fs = std::ifstream (" ../" + path, std::ios_base::binary);
@@ -168,13 +178,15 @@ struct delta_data {
168178 common_chat_parser parser;
169179};
170180
171- static delta_data init_delta (const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
181+ static delta_data init_delta (const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools, const json & tool_choice ) {
172182 common_chat_inputs inputs;
173183 inputs.parallel_tool_calls = true ;
174184 inputs.messages = json::array ();
175185 inputs.messages .push_back (user_message);
176186 inputs.tools = tools;
187+ inputs.tool_choice = tool_choice;
177188 auto params_prefix = common_chat_params_init (tmpl, inputs);
189+
178190 inputs.messages .push_back (delta_message);
179191 inputs.add_generation_prompt = false ;
180192 auto params_full = common_chat_params_init (tmpl, inputs);
@@ -220,7 +232,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
220232 };
221233
222234 for (const auto & tool_choice : json ({" auto" , " required" })) {
223- auto data = init_delta (tmpl, end_tokens, user_message, test_message, tools);
235+ auto data = init_delta (tmpl, end_tokens, user_message, test_message, tools, tool_choice );
224236 if (!expected_delta.empty ()) {
225237 assert_equals (expected_delta, data.delta );
226238 }
@@ -248,6 +260,10 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
248260 }
249261}
250262
263+ static std::string describe (const common_chat_template & tmpl, const common_chat_inputs & params) {
264+ return common_chat_params_init (tmpl, params).format ;
265+ }
266+
251267static void test_template_output_parsers () {
252268 auto text_message = json {
253269 {" role" , " assistant" },
@@ -295,29 +311,25 @@ static void test_template_output_parsers() {
295311 };
296312
297313
298- common_chat_inputs no_tools_params ;
299- no_tools_params .messages = {{{" role" , " user" }, {" content" , " Hey" }}};
314+ common_chat_inputs inputs_no_tools ;
315+ inputs_no_tools .messages = {{{" role" , " user" }, {" content" , " Hey" }}};
300316
301- common_chat_inputs tools_params = no_tools_params;
302- tools_params.tools = json::array ();
303- tools_params.tools .push_back (special_function_tool);
304-
305- auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
306- return common_chat_params_init (tmpl, params).format ;
307- };
317+ common_chat_inputs inputs_tools = inputs_no_tools;
318+ inputs_tools.tools = json::array ();
319+ inputs_tools.tools .push_back (special_function_tool);
308320
309321 {
310322 const common_chat_template tmpl (read_file (
311323 " models/templates/google-gemma-2-2b-it.jinja" ), " <s>" , " </s>" );
312324 std::vector<std::string> end_tokens { " <end_of_turn>" };
313325
314- assert_equals (std::string (" content-only" ), describe (tmpl, no_tools_params) );
315- assert_equals (std::string (" generic tool calls" ), describe (tmpl, tools_params) );
316- assert_equals (std::string (" generic tool calls" ), describe (common_chat_template (read_file (
317- " models/templates/microsoft-Phi-3.5-mini-instruct.jinja" ), " <s>" , " </s>" ), tools_params) );
326+ assert_equals (std::string (" content-only" ), common_chat_params_init (tmpl, inputs_no_tools). format );
327+ assert_equals (std::string (" generic tool calls" ), common_chat_params_init (tmpl, inputs_tools). format );
328+ assert_equals (std::string (" generic tool calls" ), common_chat_params_init (common_chat_template (read_file (
329+ " models/templates/microsoft-Phi-3.5-mini-instruct.jinja" ), " <s>" , " </s>" ), inputs_tools). format );
318330
319331 // Generic tool calls doesn't generate / parse content-only messages symmetrically.
320- assert_msg_equals (msg_from_json (text_message), common_chat_params_init (tmpl, tools_params ).parser (
332+ assert_msg_equals (msg_from_json (text_message), common_chat_params_init (tmpl, inputs_tools ).parser (
321333 " {\n "
322334 " \" response\" : \" Hello, world!\"\n "
323335 " }" ));
@@ -339,7 +351,7 @@ static void test_template_output_parsers() {
339351 " models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja" ), " <s>" , " </s>" );
340352 std::vector<std::string> end_tokens { " </s>" };
341353
342- assert_equals (std::string (" mistral nemo tool calls" ), describe (tmpl, tools_params) );
354+ assert_equals (std::string (" mistral nemo tool calls" ), common_chat_params_init (tmpl, inputs_tools). format );
343355
344356 test_template (tmpl, end_tokens, text_message, tools, " Hello, world!" , /* skip_grammar_test= */ true );
345357 test_template (tmpl, end_tokens, tool_call_message_with_id, tools,
@@ -351,11 +363,11 @@ static void test_template_output_parsers() {
351363 " models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja" ), " <s>" , " </s>" );
352364 std::vector<std::string> end_tokens { " <|im_end|>" };
353365
354- assert_equals (std::string (" hermes 2 pro tool calls" ), describe (tmpl, tools_params) );
355- assert_equals (std::string (" hermes 2 pro tool calls" ), describe (common_chat_template (read_file (
356- " models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja" ), " <s>" , " </s>" ), tools_params) );
357- assert_equals (std::string (" hermes 2 pro tool calls" ), describe (common_chat_template (read_file (
358- " models/templates/Qwen-Qwen2.5-7B-Instruct.jinja" ), " <s>" , " </s>" ), tools_params) );
366+ assert_equals (std::string (" hermes 2 pro tool calls" ), common_chat_params_init (tmpl, inputs_tools). format );
367+ assert_equals (std::string (" hermes 2 pro tool calls" ), common_chat_params_init (common_chat_template (read_file (
368+ " models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja" ), " <s>" , " </s>" ), inputs_tools). format );
369+ assert_equals (std::string (" hermes 2 pro tool calls" ), common_chat_params_init (common_chat_template (read_file (
370+ " models/templates/Qwen-Qwen2.5-7B-Instruct.jinja" ), " <s>" , " </s>" ), inputs_tools). format );
359371
360372 test_template (tmpl, end_tokens, text_message, tools, " Hello, world!" , /* skip_grammar_test= */ true );
361373 test_template (tmpl, end_tokens, tool_call_message, tools,
@@ -372,9 +384,9 @@ static void test_template_output_parsers() {
372384 " models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja" ), " <s>" , " </s>" );
373385 std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
374386
375- assert_equals (std::string (" llama 3.x tool calls (w/ builtin tools)" ), describe (tmpl, tools_params) );
376- assert_equals (std::string (" llama 3.x tool calls (w/ builtin tools)" ), describe (common_chat_template (read_file (
377- " models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja" ), " <s>" , " </s>" ), tools_params) );
387+ assert_equals (std::string (" llama 3.x tool calls (w/ builtin tools)" ), common_chat_params_init (tmpl, inputs_tools). format );
388+ assert_equals (std::string (" llama 3.x tool calls (w/ builtin tools)" ), common_chat_params_init (common_chat_template (read_file (
389+ " models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja" ), " <s>" , " </s>" ), inputs_tools). format );
378390
379391 // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
380392 test_template (tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -389,7 +401,7 @@ static void test_template_output_parsers() {
389401 " models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja" ), " <s>" , " </s>" );
390402 std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
391403
392- assert_equals (std::string (" llama 3.x tool calls" ), describe (tmpl, tools_params) );
404+ assert_equals (std::string (" llama 3.x tool calls" ), common_chat_params_init (tmpl, inputs_tools). format );
393405
394406 test_template (tmpl, end_tokens, text_message, tools,
395407 " Hello, world!" , /* skip_grammar_test= */ true );
@@ -401,7 +413,7 @@ static void test_template_output_parsers() {
401413 " models/templates/meetkai-functionary-medium-v3.1.jinja" ), " <s>" , " </s>" );
402414 std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
403415
404- assert_equals (std::string (" functionary v3.1 llama 3.1 tool calls" ), describe (tmpl, tools_params) );
416+ assert_equals (std::string (" functionary v3.1 llama 3.1 tool calls" ), common_chat_params_init (tmpl, inputs_tools). format );
405417
406418 test_template (tmpl, end_tokens, text_message, tools,
407419 " Hello, world!" , /* skip_grammar_test= */ true );
@@ -413,8 +425,8 @@ static void test_template_output_parsers() {
413425 " models/templates/meetkai-functionary-medium-v3.2.jinja" ), " <s>" , " </s>" );
414426 std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
415427
416- assert_equals (std::string (" functionary v3.2 content-only" ), describe (tmpl, no_tools_params) );
417- assert_equals (std::string (" functionary v3.2 tool calls" ), describe (tmpl, tools_params) );
428+ assert_equals (std::string (" functionary v3.2 content-only" ), common_chat_params_init (tmpl, inputs_no_tools). format );
429+ assert_equals (std::string (" functionary v3.2 tool calls" ), common_chat_params_init (tmpl, inputs_tools). format );
418430
419431 test_template (tmpl, end_tokens, text_message, tools,
420432 " all\n "
@@ -428,7 +440,7 @@ static void test_template_output_parsers() {
428440 " models/templates/fireworks-ai-llama-3-firefunction-v2.jinja" ), " <s>" , " </s>" );
429441 std::vector<std::string> end_tokens { " <|eot_id|>" };
430442
431- assert_equals (std::string (" firefunction v2 tool calls" ), describe (tmpl, tools_params) );
443+ assert_equals (std::string (" firefunction v2 tool calls" ), common_chat_params_init (tmpl, inputs_tools). format );
432444
433445 test_template (tmpl, end_tokens, text_message, tools,
434446 " Hello, world!" , /* skip_grammar_test= */ true );
@@ -440,7 +452,7 @@ static void test_template_output_parsers() {
440452 " models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja" ), " <s>" , " </s>" );
441453 std::vector<std::string> end_tokens { " <|end▁of▁sentence|>" };
442454
443- assert_equals (std::string (" deepseek r1 tool calls" ), describe (tmpl, tools_params) );
455+ assert_equals (std::string (" deepseek r1 tool calls" ), common_chat_params_init (tmpl, inputs_tools). format );
444456
445457 test_template (tmpl, end_tokens, text_message, tools,
446458 " Hello, world!" , /* skip_grammar_test= */ true );
@@ -452,9 +464,33 @@ static void test_template_output_parsers() {
452464 }
453465}
454466
455- int main () {
456- test_template_output_parsers ();
457-
458- std::cout << " \n [tool-call] All tests passed!" << std::endl;
467+ int main (int argc, char **argv) {
468+ #ifndef _WIN32
469+ if (argc > 1 ) {
470+ common_chat_inputs inputs;
471+ inputs.messages = {{{" role" , " user" }, {" content" , " Hey" }}};
472+ inputs.tools = json::array ({special_function_tool});
473+
474+ std::cout << " | Template | Format |\n " ;
475+ std::cout << " |----------|--------|\n " ;
476+
477+ for (int i = 1 ; i < argc; i++) {
478+ std::string path = argv[i];
479+ if (path.rfind (" .jinja" ) != path.size () - 6 ) {
480+ std::cerr << " Skipping non-jinja file: " << path << std::endl;
481+ continue ;
482+ }
483+ common_chat_template tmpl (read_file (path), " " , " " );
484+ auto parts = string_split (path, " /" );
485+ auto name = parts[parts.size () - 1 ];
486+ std::cout << " | " << name << " | " << common_chat_params_init (tmpl, inputs).format << " |\n " ;
487+ }
488+ }
489+ else
490+ #endif
491+ {
492+ test_template_output_parsers ();
493+ std::cout << " \n [chat] All tests passed!" << std::endl;
494+ }
459495 return 0 ;
460496}
0 commit comments