@@ -298,34 +298,6 @@ const json tools = {special_function_tool, python_tool};
298298// json::array({special_function_call}));
299299// }
300300
// test_format_detection: for each chat template file, checks the `format` string reported by
// common_chat_init. Every line of this function carries the diff's `-` marker, and its call in
// main() is marked `-` as well, so this change deletes it.
301- static void test_format_detection () {
// Baseline request: a single user message, `tools` left unset.
302- common_chat_params no_tools_params;
303- no_tools_params.messages = {{{" role" , " user" }, {" content" , " Hey" }}};
304-
// Same request, but with `tools` set to an empty JSON array.
305- common_chat_params tools_params = no_tools_params;
306- tools_params.tools = json::array ();
307-
// describe: reads the template file, builds a common_chat_template from it, runs
// common_chat_init with `params`, and returns the resulting `format` field.
308- auto describe = [](const std::string & template_file, const common_chat_params & params) {
309- const common_chat_template tmpl (read_file (template_file), " <s>" , " </s>" );
310- auto data = common_chat_init (tmpl, params);
311- return data.format ;
312- };
313-
// Expected format string per template; all but the last active check use tools_params.
314- assert_equals (std::string (" functionary v3.1 llama 3.1 tool calls" ), describe (" tests/chat/templates/meetkai-functionary-medium-v3.1.jinja" , tools_params));
315- assert_equals (std::string (" functionary v3.2 tool calls" ), describe (" tests/chat/templates/meetkai-functionary-medium-v3.2.jinja" , tools_params));
316- assert_equals (std::string (" firefunction v2 tool calls" ), describe (" tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja" , tools_params));
317- assert_equals (std::string (" llama 3.1 tool calls" ), describe (" tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja" , tools_params));
318- assert_equals (std::string (" llama 3.2 tool calls" ), describe (" tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja" , tools_params));
319- assert_equals (std::string (" hermes 2 pro tool calls" ), describe (" tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja" , tools_params));
320- assert_equals (std::string (" hermes 2 pro tool calls" ), describe (" tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja" , tools_params));
321- assert_equals (std::string (" hermes 2 pro tool calls" ), describe (" tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja" , tools_params));
322- assert_equals (std::string (" mistral nemo tool calls" ), describe (" tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja" , tools_params));
323- assert_equals (std::string (" deepseek r1 tool calls" ), describe (" tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja" , tools_params));
// The gemma template is checked twice: with the tools array and without it.
324- assert_equals (std::string (" generic tool calls" ), describe (" tests/chat/templates/google-gemma-7b-it.jinja" , tools_params));
325- assert_equals (std::string (" content-only" ), describe (" tests/chat/templates/google-gemma-7b-it.jinja" , no_tools_params));
// Disabled check for the Command R+ template (left commented out).
326- // assert_equals(std::string("command_r_plus tool calls"), describe("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", tools_params));
327- }
328-
// get_message_prompt_delta: returns a string `delta` derived from the template, the end tokens,
// a user message, a follow-up ("delta") message and the tools.
// NOTE(review): only the opening lines (stderr logging of the template source and of the delta
// message) and the final `return delta;` are visible here; the hunk header in between elides the
// rest, so how `delta` is computed cannot be confirmed from this view.
329301static std::string get_message_prompt_delta (const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
330302 fprintf (stderr, " Template source: %s\n " , tmpl.source ().c_str ());
331303 fprintf (stderr, " Delta message: %s\n " , delta_message.dump (2 ).c_str ());
@@ -363,20 +335,23 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
363335 return delta;
364336}
365337
// test_template: exercises one chat template against one assistant message.
//   tmpl              - template under test.
//   end_tokens        - forwarded to get_message_prompt_delta.
//   test_message      - assistant message; may or may not contain "tool_calls" (the old
//                       signature named it tool_calling_message and required the key).
//   tools             - tool definitions; the new signature defaults this to {}.
//   skip_grammar_test - when true, skips the delta parsing / grammar matching steps.
// Throws std::runtime_error when the grammar cannot be built or the content-less delta does not
// match it. Lines marked `-` are the old body, lines marked `+` the new one.
366- static void test_template (const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message , const json & tools, bool skip_grammar_test = false ) {
338+ static void test_template (const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & test_message , const json & tools = {} , bool skip_grammar_test = false ) {
367339 // auto tool_call_style = common_tool_call_style_detect(tmpl);
// Message the parser is expected to produce: assistant role, empty content, no tool calls yet.
368340 common_chat_msg expected_msg {
369341 " assistant" ,
370342 " " ,
371343 {},
372344 };
// Old version: unconditionally copied every entry of "tool_calls" into expected_msg
// (.at() is used, so the key had to be present).
373- for (const auto & tc : tool_calling_message.at (" tool_calls" )) {
374- const auto & arguments = tc.at (" function" ).at (" arguments" );
375- expected_msg.tool_calls .push_back ({
376- tc.at (" function" ).at (" name" ).get <std::string>(),
377- arguments.is_string () ? arguments.get <std::string>() : arguments.dump (),
378- tc.contains (" id" ) ? tc.at (" id" ).get <std::string>() : " " ,
379- });
// New version: the same copy, but only when the message actually has "tool_calls".
// Arguments are kept verbatim if already a string, otherwise serialised with dump();
// a missing "id" becomes the fallback string literal.
345+ auto has_tool_calls = test_message.contains (" tool_calls" );
346+ if (has_tool_calls) {
347+ for (const auto & tc : test_message.at (" tool_calls" )) {
348+ const auto & arguments = tc.at (" function" ).at (" arguments" );
349+ expected_msg.tool_calls .push_back ({
350+ tc.at (" function" ).at (" name" ).get <std::string>(),
351+ arguments.is_string () ? arguments.get <std::string>() : arguments.dump (),
352+ tc.contains (" id" ) ? tc.at (" id" ).get <std::string>() : " " ,
353+ });
354+ }
380355 }
381356
382357 // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
@@ -386,36 +361,45 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
386361 {" content" , " Hello, world!" }
387362 };
388363
// Old version: a single pass with no tool_choice set; the grammar was always built.
389- common_chat_params params;
390- params.parallel_tool_calls = true ;
391- params.messages = json {user_message, tool_calling_message};
392- params.tools = tools;
393- auto chat_data = common_chat_init (tmpl, params);
394- fprintf (stderr, " PROMPT: %s\n " , chat_data.prompt .get <std::string>().c_str ());
395- auto grammar = build_grammar (chat_data.grammar );
396- if (!grammar) {
397- throw std::runtime_error (" Failed to build grammar" );
398- }
399-
400- if (!skip_grammar_test) {
401- auto full_delta = get_message_prompt_delta (tmpl, end_tokens, user_message, tool_calling_message, tools);
402- std::cout << " Full delta:\n ```\n " << full_delta << " \n ```" << std::endl;
403-
404- const auto msg = chat_data.parser ->parse_final (full_delta);
405- assert_msg_equals (expected_msg, msg);
406-
407- auto content_less_delta = get_message_prompt_delta (tmpl, end_tokens, user_message, {
408- {" role" , " assistant" },
409- {" content" , {}},
410- {" tool_calls" , tool_calling_message.at (" tool_calls" )}
411- }, tools);
412- if (!match_string (content_less_delta, grammar.get ())) {
413- throw std::runtime_error (" Failed to match content-less delta against grammar:\n\n Content-less delta: " + content_less_delta + " \n\n Grammar: " + chat_data.grammar );
// New version: the whole check runs once per tool_choice value in the JSON list below.
// NOTE(review): when the message has no "tool_calls", nothing after common_chat_init runs, so a
// plain-text message only verifies that initialisation does not throw.
// NOTE(review): the arguments passed to get_message_prompt_delta do not depend on tool_choice,
// so both iterations request the same deltas; confirm this is intended.
364+ for (const auto & tool_choice : json ({" auto" , " required" })) {
365+ common_chat_params params;
366+ params.tool_choice = tool_choice;
367+ params.parallel_tool_calls = true ;
368+ params.messages = json {user_message, test_message};
369+ params.tools = tools;
370+ auto chat_data = common_chat_init (tmpl, params);
371+ // fprintf(stderr, "PROMPT: %s\n", chat_data.prompt.get<std::string>().c_str());
372+ if (has_tool_calls) {
// The grammar is built (and must build) even when skip_grammar_test is true.
373+ auto grammar = build_grammar (chat_data.grammar );
374+ if (!grammar) {
375+ throw std::runtime_error (" Failed to build grammar" );
376+ }
377+
378+ if (!skip_grammar_test) {
// Step 1: parse the delta of the full message and compare with expected_msg.
379+ auto full_delta = get_message_prompt_delta (tmpl, end_tokens, user_message, test_message, tools);
380+ std::cout << " Full delta:\n ```\n " << full_delta << " \n ```" << std::endl;
381+
382+ const auto msg = chat_data.parser ->parse_final (full_delta);
383+ assert_msg_equals (expected_msg, msg);
384+
// Step 2: rebuild the delta from a copy of the message whose "content" is {} and require
// it to match the grammar.
385+ auto content_less_delta = get_message_prompt_delta (tmpl, end_tokens, user_message, {
386+ {" role" , " assistant" },
387+ {" content" , {}},
388+ {" tool_calls" , test_message.at (" tool_calls" )}
389+ }, tools);
390+ if (!match_string (content_less_delta, grammar.get ())) {
391+ throw std::runtime_error (" Failed to match content-less delta against grammar:\n\n Content-less delta: " + content_less_delta + " \n\n Grammar: " + chat_data.grammar );
392+ }
393+ }
414394 }
415395 }
416396}
417397
// test_grammars: runs test_template over a list of chat template files. In the new (`+`) lines
// each template block also asserts the format string returned by common_chat_init and runs
// test_template with a plain-text assistant message in addition to the tool-call message(s).
// Lines marked `-` are the old per-template calls being replaced.
418398static void test_grammars () {
// Plain assistant reply with no "tool_calls" key (added by this change).
399+ auto text_message = json {
400+ {" role" , " assistant" },
401+ {" content" , " Hello, world!" },
402+ };
419403 auto tool_call_message = json {
420404 {" role" , " assistant" },
421405 {" content" , {}},
@@ -444,68 +428,128 @@ static void test_grammars() {
444428 }}}
445429 };
446430
431+
// Requests used only for the format assertions: one with `tools` left unset, one with an
// empty tools array.
432+ common_chat_params no_tools_params;
433+ no_tools_params.messages = {{{" role" , " user" }, {" content" , " Hey" }}};
434+
435+ common_chat_params tools_params = no_tools_params;
436+ tools_params.tools = json::array ();
437+
// describe: runs common_chat_init on an already-loaded template and returns its `format` field.
438+ auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
439+ auto data = common_chat_init (tmpl, params);
440+ return data.format ;
441+ };
442+
// Each braced block below follows the same pattern: load one template, declare its end tokens,
// assert the detected format, then run test_template for the text message and the tool-call
// message(s). Only the first block also asserts the format for no_tools_params.
443+ {
444+ const common_chat_template tmpl (read_file (" tests/chat/templates/google-gemma-2-2b-it.jinja" ), " <s>" , " </s>" );
445+ std::vector<std::string> end_tokens { " <end_of_turn>" };
446+
447+ assert_equals (std::string (" generic tool calls" ), describe (tmpl, tools_params));
448+ assert_equals (std::string (" content-only" ), describe (tmpl, no_tools_params));
449+ test_template (tmpl, end_tokens, text_message, tools);
450+ test_template (tmpl, end_tokens, tool_call_message_with_id, tools);
451+ }
452+ {
453+ const common_chat_template tmpl (read_file (" tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja" ), " <s>" , " </s>" );
454+ std::vector<std::string> end_tokens { " <|end|>" };
455+
456+ assert_equals (std::string (" generic tool calls" ), describe (tmpl, tools_params));
457+ test_template (tmpl, end_tokens, text_message, tools);
458+ test_template (tmpl, end_tokens, tool_call_message_with_id, tools);
459+ }
447460 {
448461 const common_chat_template tmpl (read_file (" tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja" ), " <s>" , " </s>" );
449- test_template (tmpl, { " </s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true );
462+ std::vector<std::string> end_tokens { " </s>" };
463+
464+ assert_equals (std::string (" mistral nemo tool calls" ), describe (tmpl, tools_params));
465+ test_template (tmpl, end_tokens, text_message, tools);
// skip_grammar_test stays true for this template, as in the removed call above.
466+ test_template (tmpl, end_tokens, tool_call_message_with_id, tools, /* skip_grammar_test= */ true );
450467 }
451468 {
452469 const common_chat_template tmpl (read_file (" tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja" ), " <s>" , " </s>" );
453- // assert_equals(tmpl.requires_object_arguments_, true);
454- test_template (tmpl, { " <|im_end|>" }, tool_call_message, tools);
455- test_template (tmpl, { " <|im_end|>" }, python_tool_call_message, tools);
470+ std::vector<std::string> end_tokens { " <|im_end|>" };
471+
472+ assert_equals (std::string (" hermes 2 pro tool calls" ), describe (tmpl, tools_params));
473+ test_template (tmpl, end_tokens, text_message, tools);
474+ test_template (tmpl, end_tokens, tool_call_message, tools);
475+ test_template (tmpl, end_tokens, python_tool_call_message, tools);
456476 }
457477 {
458478 const common_chat_template tmpl (read_file (" tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja" ), " <s>" , " </s>" );
459- test_template (tmpl, { " <|im_end|>" }, tool_call_message, tools);
479+ std::vector<std::string> end_tokens { " <|im_end|>" };
480+
481+ assert_equals (std::string (" hermes 2 pro tool calls" ), describe (tmpl, tools_params));
482+ test_template (tmpl, end_tokens, text_message, tools);
483+ test_template (tmpl, end_tokens, tool_call_message, tools);
460484 }
461485 {
462486 const common_chat_template tmpl (read_file (" tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja" ), " <s>" , " </s>" );
463- test_template (tmpl, { " <|im_end|>" }, tool_call_message, tools) ;
464- }
465- {
466- const common_chat_template tmpl ( read_file ( " tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja " ), " <s> " , " </s> " );
467- test_template (tmpl, { " <|eom_id|> " , " <|eot_id|> " } , tool_call_message, tools);
487+ std::vector<std::string> end_tokens { " <|im_end|>" };
488+
489+ assert_equals ( std::string ( " hermes 2 pro tool calls " ), describe (tmpl, tools_params));
490+ test_template (tmpl, end_tokens, text_message, tools );
491+ test_template (tmpl, end_tokens , tool_call_message, tools);
468492 }
469493 {
470494 const common_chat_template tmpl (read_file (" tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja" ), " <s>" , " </s>" );
471- test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, python_tool_call_message, tools);
495+ std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
496+
497+ assert_equals (std::string (" llama 3.1 tool calls" ), describe (tmpl, tools_params));
498+ test_template (tmpl, end_tokens, text_message, tools);
499+ test_template (tmpl, end_tokens, tool_call_message, tools);
500+ test_template (tmpl, end_tokens, python_tool_call_message, tools);
472501 }
473502 {
474503 const common_chat_template tmpl (read_file (" tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja" ), " <s>" , " </s>" );
475- test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, tool_call_message, tools);
504+ std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
505+
506+ assert_equals (std::string (" llama 3.2 tool calls" ), describe (tmpl, tools_params));
507+ test_template (tmpl, end_tokens, text_message, tools);
508+ test_template (tmpl, end_tokens, tool_call_message, tools);
476509 }
477510 {
478511 const common_chat_template tmpl (read_file (" tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja" ), " <s>" , " </s>" );
479- test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, tool_call_message, tools);
512+ std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
513+
// NOTE(review): the Llama 3.3 template is expected to report the 3.1 format string; presumably
// intentional, confirm.
514+ assert_equals (std::string (" llama 3.1 tool calls" ), describe (tmpl, tools_params));
515+ test_template (tmpl, end_tokens, text_message, tools);
516+ test_template (tmpl, end_tokens, tool_call_message, tools);
480517 }
481518 {
482519 const common_chat_template tmpl (read_file (" tests/chat/templates/meetkai-functionary-medium-v3.1.jinja" ), " <s>" , " </s>" );
483- test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, tool_call_message, tools);
520+ std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
521+
522+ assert_equals (std::string (" functionary v3.1 llama 3.1 tool calls" ), describe (tmpl, tools_params));
523+ test_template (tmpl, end_tokens, text_message, tools);
524+ test_template (tmpl, end_tokens, tool_call_message, tools);
484525 }
485526 {
486527 const common_chat_template tmpl (read_file (" tests/chat/templates/meetkai-functionary-medium-v3.2.jinja" ), " <s>" , " </s>" );
487- test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, tool_call_message, tools);
528+ std::vector<std::string> end_tokens { " <|eom_id|>" , " <|eot_id|>" };
529+
530+ assert_equals (std::string (" functionary v3.2 tool calls" ), describe (tmpl, tools_params));
531+ test_template (tmpl, end_tokens, text_message, tools);
532+ test_template (tmpl, end_tokens, tool_call_message, tools);
488533 }
489534 {
490535 const common_chat_template tmpl (read_file (" tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja" ), " <s>" , " </s>" );
491- test_template (tmpl, { " <|eot_id|>" }, tool_call_message, tools);
492- }
493- {
494- const common_chat_template tmpl (read_file (" tests/chat/templates/google-gemma-2-2b-it.jinja" ), " <s>" , " </s>" );
495- test_template (tmpl, { " <end_of_turn>" }, tool_call_message_with_id, tools);
496- }
497- {
498- const common_chat_template tmpl (read_file (" tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja" ), " <s>" , " </s>" );
499- test_template (tmpl, { " <|end|>" }, tool_call_message_with_id, tools);
536+ std::vector<std::string> end_tokens { " <|eot_id|>" };
537+
538+ assert_equals (std::string (" firefunction v2 tool calls" ), describe (tmpl, tools_params));
539+ test_template (tmpl, end_tokens, text_message, tools);
540+ test_template (tmpl, end_tokens, tool_call_message, tools);
500541 }
501542 {
502543 const common_chat_template tmpl (read_file (" tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja" ), " <s>" , " </s>" );
503- test_template (tmpl, { " <|end▁of▁sentence|>" }, tool_call_message, tools);
544+ std::vector<std::string> end_tokens { " <|end▁of▁sentence|>" };
545+
546+ assert_equals (std::string (" deepseek r1 tool calls" ), describe (tmpl, tools_params));
547+ test_template (tmpl, end_tokens, text_message, tools);
548+ test_template (tmpl, end_tokens, tool_call_message, tools);
504549 }
505550}
506551
507552int main () {
508- test_format_detection ();
509553 // test_parsing();
510554 test_grammars ();
511555
0 commit comments