@@ -72,32 +72,17 @@ static std::string dump(const json & j) {
7272 return minja::Value (j).dump (-1 , /* to_json= */ true );
7373}
7474
75- static void assert_msg_equals (const common_chat_msg & result, const std::string & expected_content, const json & expected_tool_calls) {
76- assert_equals (expected_content, result.content );
77- auto tool_calls = json::array ();
78- for (const auto & tc : result.tool_calls ) {
79- auto arguments = tc.arguments ;
80- try {
81- arguments = dump (json::parse (arguments));
82- } catch (const std::exception & e) {
83- // ignore
84- }
85- auto tool_call = json {
86- {" type" , " function" },
87- {" function" , {
88- {" arguments" , arguments},
89- {" name" , tc.name },
90- }},
91- };
92- if (!tc.id .empty ()) {
93- tool_call[" id" ] = tc.id ;
94- }
95- tool_calls.push_back (tool_call);
75+ static void assert_msg_equals (const common_chat_msg & expected, const common_chat_msg & actual) {
76+ assert_equals (expected.role , actual.role );
77+ assert_equals (expected.content , actual.content );
78+ assert_equals (expected.tool_calls .size (), actual.tool_calls .size ());
79+ for (size_t i = 0 ; i < expected.tool_calls .size (); i++) {
80+ const auto & expected_tool_call = expected.tool_calls [i];
81+ const auto & actual_tool_call = actual.tool_calls [i];
82+ assert_equals (expected_tool_call.name , actual_tool_call.name );
83+ assert_equals (dump (json::parse (expected_tool_call.arguments )), dump (json::parse (actual_tool_call.arguments )));
84+ assert_equals (expected_tool_call.id , actual_tool_call.id );
9685 }
97- // Reparse / dump w/ non-ordered JSON variant.
98- auto expected = nlohmann::json::parse (expected_tool_calls.dump ()).dump ();
99- auto actual = nlohmann::json::parse (tool_calls.dump ()).dump ();
100- assert_equals (expected, actual);
10186}
10287
10388const auto special_function_tool = json::parse(R"( {
@@ -373,7 +358,19 @@ static std::string get_message_prompt_delta(const common_chat_template & tmpl, c
373358
374359static void test_template (const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools, bool skip_grammar_test = false ) {
375360 // auto tool_call_style = common_tool_call_style_detect(tmpl);
376- auto & tool_calls = tool_calling_message.at (" tool_calls" );
361+ common_chat_msg expected_msg {
362+ " assistant" ,
363+ " " ,
364+ {},
365+ };
366+ for (const auto & tc : tool_calling_message.at (" tool_calls" )) {
367+ const auto & arguments = tc.at (" function" ).at (" arguments" );
368+ expected_msg.tool_calls .push_back ({
369+ tc.at (" function" ).at (" name" ).get <std::string>(),
370+ arguments.is_string () ? arguments.get <std::string>() : arguments.dump (),
371+ tc.contains (" id" ) ? tc.at (" id" ).get <std::string>() : " " ,
372+ });
373+ }
377374
378375 // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
379376 // get the diff and try and parse it w/ the grammar.
@@ -398,12 +395,12 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
398395 std::cout << " Full delta:\n ```\n " << full_delta << " \n ```" << std::endl;
399396
400397 const auto msg = chat_data.parser ->parse_final (full_delta);
401- assert_msg_equals (msg, " " , tool_calls );
398+ assert_msg_equals (expected_msg, msg );
402399
403400 auto content_less_delta = get_message_prompt_delta (tmpl, end_tokens, user_message, {
404401 {" role" , " assistant" },
405402 {" content" , {}},
406- {" tool_calls" , tool_calls}
403+ {" tool_calls" , tool_calling_message. at ( " tool_calls" ) }
407404 }, tools);
408405 if (!match_string (content_less_delta, grammar.get ())) {
409406 throw std::runtime_error (" Failed to match content-less delta against grammar:\n\n Content-less delta: " + content_less_delta + " \n\n Grammar: " + chat_data.grammar );
@@ -433,7 +430,9 @@ static void test_grammars() {
433430 {" type" , " function" },
434431 {" function" , {
435432 {" name" , " python" },
436- {" arguments" , " print('hey')" }
433+ {" arguments" , {
434+ {" code" , " print('hey')" },
435+ }},
437436 }},
438437 }}}
439438 };
@@ -442,12 +441,12 @@ static void test_grammars() {
442441 const common_chat_template tmpl (read_file (" tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja" ), " <s>" , " </s>" );
443442 test_template (tmpl, { " </s>" }, tool_call_message_with_id, tools, /* skip_grammar_test= */ true );
444443 }
445- // {
446- // const common_chat_template tmpl(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");
447- // // assert_equals(tmpl.requires_object_arguments_, true);
448- // test_template(tmpl, { "<|im_end|>" }, tool_call_message, tools);
449- // test_template(tmpl, { "<|im_end|>" }, python_tool_call_message, tools);
450- // }
444+ {
445+ const common_chat_template tmpl (read_file (" tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja" ), " <s>" , " </s>" );
446+ // assert_equals(tmpl.requires_object_arguments_, true);
447+ test_template (tmpl, { " <|im_end|>" }, tool_call_message, tools);
448+ test_template (tmpl, { " <|im_end|>" }, python_tool_call_message, tools);
449+ }
451450 {
452451 const common_chat_template tmpl (read_file (" tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja" ), " <s>" , " </s>" );
453452 test_template (tmpl, { " <|im_end|>" }, tool_call_message, tools);
@@ -456,22 +455,22 @@ static void test_grammars() {
456455 const common_chat_template tmpl (read_file (" tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja" ), " <s>" , " </s>" );
457456 test_template (tmpl, { " <|im_end|>" }, tool_call_message, tools);
458457 }
459- // {
460- // const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
461- // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
462- // }
463- // {
464- // const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
465- // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, python_tool_call_message, tools);
466- // }
458+ {
459+ const common_chat_template tmpl (read_file (" tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja" ), " <s>" , " </s>" );
460+ test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, tool_call_message, tools);
461+ }
462+ {
463+ const common_chat_template tmpl (read_file (" tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja" ), " <s>" , " </s>" );
464+ test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, python_tool_call_message, tools);
465+ }
467466 {
468467 const common_chat_template tmpl (read_file (" tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja" ), " <s>" , " </s>" );
469468 test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, tool_call_message, tools);
470469 }
471- // {
472- // const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>");
473- // test_template(tmpl, { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);
474- // }
470+ {
471+ const common_chat_template tmpl (read_file (" tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja" ), " <s>" , " </s>" );
472+ test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, tool_call_message, tools);
473+ }
475474 {
476475 const common_chat_template tmpl (read_file (" tests/chat/templates/meetkai-functionary-medium-v3.1.jinja" ), " <s>" , " </s>" );
477476 test_template (tmpl, { " <|eom_id|>" , " <|eot_id|>" }, tool_call_message, tools);
0 commit comments