@@ -207,7 +207,6 @@ static void foreach_function(const json & tools, const std::function<void(const
207207}
208208
209209static common_chat_data common_chat_init_generic_tool_call (const common_chat_template & tmpl, const struct common_chat_params & params) {
210- fprintf (stderr, " [%s]\n " , __func__);
211210 common_chat_data data;
212211
213212 auto tool_call_schemas = json::array ();
@@ -318,7 +317,6 @@ static common_chat_data common_chat_init_generic_tool_call(const common_chat_tem
318317}
319318
320319static common_chat_data common_chat_init_mistral_nemo_tool_call (const common_chat_template & tmpl, const struct common_chat_params & params) {
321- fprintf (stderr, " [%s]\n " , __func__);
322320 common_chat_data data;
323321 data.grammar_lazy = params.tool_choice != " required" ;
324322 data.grammar = build_grammar ([&](const common_grammar_builder & builder) {
@@ -358,25 +356,71 @@ static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_cha
358356 data.prompt = tmpl.apply (params.messages , params.tools .empty () ? json () : params.tools , params.add_generation_prompt );
359357 data.format = " mistral nemo tool calls" ;
360358 data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) {
361- return parse_prefixed_json_tool_call_array (input, " [TOOL_CALLS]" );
362- });
359+ return parse_prefixed_json_tool_call_array (input, " [TOOL_CALLS]" );
360+ });
363361 return data;
364362}
365363
364+ static void expect_tool_parameters (const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
365+ if (!parameters.is_object () || !parameters.contains (" type" ) || parameters[" type" ] != " object" || !parameters.contains (" properties" ) || !parameters.contains (" required" )) {
366+ throw std::runtime_error (" Parameters of tool " + name + " must be an object w/ required properties" );
367+ }
368+ const auto & parameters_properties = parameters.at (" properties" );
369+ const auto & parameters_required = parameters.at (" required" );
370+ for (const auto & prop : expected_properties) {
371+ if (!parameters_properties.contains (prop)) {
372+ throw std::runtime_error (" Parameters of tool " + name + " is missing property: " + prop);
373+ }
374+ if (std::find (parameters_required.begin (), parameters_required.end (), json (prop)) == parameters_required.end ()) {
375+ throw std::runtime_error (" Parameters of tool " + name + " must have property marked as required: " + prop);
376+ }
377+ }
378+ if (parameters_properties.size () != expected_properties.size ()) {
379+ throw std::runtime_error (" Parameters of tool " + name + " must only have these properties:" + string_join (expected_properties, " , " ));
380+ }
381+ }
382+
366383static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls (const common_chat_template & tmpl, const struct common_chat_params & params) {
367- fprintf (stderr, " [%s]\n " , __func__);
368- // TODO: get from request body.
369- auto builtin_tools = json {" wolfram_alpha" , " brave_search" };
384+ auto builtin_tools = json::array ();
370385 common_chat_data data;
371-
372386 data.grammar_lazy = params.tool_choice != " required" ;
373387 data.grammar = build_grammar ([&](const common_grammar_builder & builder) {
374388 std::vector<std::string> tool_rules;
375389
390+ auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
391+ if (name == " wolfram_alpha" ) { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
392+ expect_tool_parameters (name, parameters, {" query" });
393+ } else if (name == " web_search" || name == " brave_search" ) { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
394+ expect_tool_parameters (name, parameters, {" query" });
395+ } else if (name == " python" || name == " code_interpreter" ) { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
396+ expect_tool_parameters (name, parameters, {" code" });
397+ } else {
398+ return false ;
399+ }
400+
401+ std::vector<std::string> kvs;
402+ for (const auto & [key, value] : parameters.at (" properties" ).items ()) {
403+ kvs.push_back (" \" " + key + " =\" " + builder.add_schema (name + " -args-" + key, value));
404+ }
405+
406+ tool_rules.push_back (
407+ builder.add_rule (
408+ name + " -call" ,
409+ " \" <|python_tag|>" + name + " .call(\" " + string_join (kvs, " \" , \" " ) + " \" )\" " ));
410+ builtin_tools.push_back (name);
411+
412+ return true ;
413+ };
414+
376415 foreach_function (params.tools , [&](const json & tool) {
377416 const auto & function = tool[" function" ];
378417 std::string name = function[" name" ];
379418 auto parameters = function[" parameters" ];
419+
420+ // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
421+ if (handle_builtin_tool (name, parameters)) {
422+ return ;
423+ }
380424 builder.resolve_refs (parameters);
381425 tool_rules.push_back (
382426 builder.add_rule (
@@ -388,30 +432,42 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
388432 " \" }\" " ));
389433 data.grammar_triggers .push_back ({" {\" name\" : \" " + name + " \" " , /* .at_start = */ true });
390434 });
391- tool_rules.push_back (builder.add_rule (" builtin-tool-call" , " \" <|python_tag|>\" .*" ));
392- data.grammar_triggers .push_back ({" <|python_tag|>" , /* .at_start = */ false });
435+ if (!builtin_tools.empty ()) {
436+ data.grammar_triggers .push_back ({" <|python_tag|>" , /* .at_start = */ false });
437+ }
393438 builder.add_rule (" root" , string_join (tool_rules, " | " ));
394439 }, grammar_options);
395440 data.additional_stops .push_back (" <|eom_id|>" );
396441 data.prompt = tmpl.apply (params.messages , params.tools .empty () ? json () : params.tools , params.add_generation_prompt , {
397- {" builtin_tools" , builtin_tools},
442+ {" tools_in_user_message" , false },
443+ {" builtin_tools" , builtin_tools.empty () ? json () : builtin_tools},
398444 });
399445 data.format = " llama 3.1 tool calls" ;
400446 data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
401447 static std::regex function_regex (" \\ {(?:\" type\" : \" function\" , |[\\ s\\ n\\ r]*)\" name\" : \" ([^\" ]+)\" , \" parameters\" : " );
402448 static std::regex close_regex (" \\ }" );
403- static std::regex builtin_call_regex (" <\\ |python_tag\\ |>([^.(]+)\((.*)\)" );
449+ static std::regex builtin_call_regex (" <\\ |python_tag\\ |>([^.(]+)\\ .call \\ ((.*)\ \ )" );
404450
405451 std::smatch match;
406452 if (std::regex_match (input, match, builtin_call_regex)) {
407- auto arguments = json::parse (" [" + match[2 ].str () + " ]" );
453+ auto name = match[1 ].str ();
454+ auto raw_args = match[2 ].str ();
455+
456+ // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
457+ auto it_eq = raw_args.find (' =' );
458+ auto arg_name = raw_args.substr (0 , it_eq);
459+ auto arg_value_str = raw_args.substr (it_eq + 1 );
460+ auto arg_value = json::parse (arg_value_str);
461+
408462 return {
409463 /* .role = */ " assistant" ,
410464 /* .content = */ match.prefix ().str (),
411465 /* .tool_calls = */ {
412466 {
413467 /* .name = */ match[1 ],
414- /* .arguments = */ arguments.dump (),
468+ /* .arguments = */ (json {
469+ {arg_name, arg_value},
470+ }).dump (),
415471 /* .id = */ " " ,
416472 },
417473 },
@@ -423,7 +479,6 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
423479}
424480
425481static common_chat_data common_chat_init_llama_3_2_tool_calls (const common_chat_template & tmpl, const struct common_chat_params & params) {
426- fprintf (stderr, " [%s]\n " , __func__);
427482 common_chat_data data;
428483
429484 data.grammar_lazy = params.tool_choice != " required" ;
@@ -462,7 +517,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
462517}
463518
464519static common_chat_data common_chat_init_deepseek_r1_tool_call (const common_chat_template & tmpl, const struct common_chat_params & params) {
465- fprintf (stderr, " [%s]\n " , __func__);
466520 common_chat_data data;
467521 data.grammar_lazy = params.tool_choice != " required" ;
468522 data.grammar = build_grammar ([&](const common_grammar_builder & builder) {
@@ -490,7 +544,6 @@ static common_chat_data common_chat_init_deepseek_r1_tool_call(const common_chat
490544}
491545
492546static common_chat_data common_chat_init_firefunction_v2_tool_call (const common_chat_template & tmpl, const struct common_chat_params & params) {
493- fprintf (stderr, " [%s]\n " , __func__);
494547 common_chat_data data;
495548 data.grammar_lazy = params.tool_choice != " required" ;
496549 data.grammar = build_grammar ([&](const common_grammar_builder & builder) {
@@ -529,7 +582,6 @@ static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_
529582}
530583
531584static common_chat_data common_chat_init_functionary_v3_2_tool_call (const common_chat_template & tmpl, const struct common_chat_params & params) {
532- fprintf (stderr, " [%s]\n " , __func__);
533585 // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
534586 // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
535587 common_chat_data data;
@@ -574,7 +626,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
574626}
575627
576628static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call (const common_chat_template & tmpl, const struct common_chat_params & params) {
577- fprintf (stderr, " [%s]\n " , __func__);
578629 // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
579630 // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
580631 common_chat_data data;
@@ -651,7 +702,6 @@ static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(cons
651702}
652703
653704static common_chat_data common_chat_init_hermes_2_pro_tool_call (const common_chat_template & tmpl, const struct common_chat_params & params) {
654- fprintf (stderr, " [%s]\n " , __func__);
655705 common_chat_data data;
656706 // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
657707 data.grammar_lazy = params.tool_choice != " required" ;
@@ -705,9 +755,11 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
705755 if (!parse_json (it, end, call)) {
706756 throw std::runtime_error (" Failed to parse json tool call" );
707757 }
758+ const auto & arguments = call[" arguments" ];
708759 result.tool_calls .push_back ({
709760 call[" name" ],
710- call[" arguments" ].dump (),
761+ arguments.dump (),
762+ // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
711763 /* id= */ " " ,
712764 });
713765 rit = {it, end, middle_pattern};
@@ -734,7 +786,6 @@ static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_cha
734786}
735787
736788static common_chat_data common_chat_init_without_tools (const common_chat_template & tmpl, const struct common_chat_params & params) {
737- fprintf (stderr, " [%s]\n " , __func__);
738789 common_chat_data data;
739790 data.prompt = tmpl.apply (params.messages , params.tools .empty () ? json () : params.tools , params.add_generation_prompt );
740791 data.format = " content-only" ;
0 commit comments