@@ -76,7 +76,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
             if (type == "function") {
                 tool_names.push_back(tool["function"]["name"]);
             } else if (type == "code_interpreter") {
-                tool_names.push_back("ipython");
+                tool_names.push_back("python");
             }
         }
     }
@@ -171,6 +171,10 @@ class text_chat_parser : public common_chat_parser {
             /* .tool_calls = */ {},
         };
     }
+
+    std::unique_ptr<common_chat_parser> clone() const override {
+        return std::make_unique<text_chat_parser>();
+    }
 };
 
 class monolithic_chat_parser : public common_chat_parser {
@@ -192,13 +196,48 @@ class monolithic_chat_parser : public common_chat_parser {
         input_buffer_.clear();
         return out;
     }
+
+    std::unique_ptr<common_chat_parser> clone() const override {
+        return std::make_unique<monolithic_chat_parser>(parse_final_);
+    }
 };
 
-static common_chat_data build_generic_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+const auto python_tool = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "python",
+    "description": "an ipython interpreter",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "code": {
+          "type": "string",
+          "description": "Python code to execute."
+        }
+      },
+      "required": ["code"]
+    }
+  }
+})");
+
+static void foreach_normalized_tool(const json & tools, const std::function<void(const json &)> & fn) {
+    for (const auto & tool : tools) {
+        if (!tool.contains("type")) {
+            continue;
+        }
+        if (tool["type"] == "code_interpreter") {
+            fn(python_tool);
+        } else {
+            fn(tool);
+        }
+    }
+}
+
+static common_chat_data common_chat_init_generic_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;
 
     auto tool_call_schemas = json::array();
-    for (const auto & tool : params.tools) {
+    foreach_normalized_tool(params.tools, [&](const json & tool) {
         const auto & function = tool["function"];
         auto tool_schema = json {
             {"type", "object"},
@@ -222,7 +261,7 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa
            tool_schema["required"].push_back("id");
         }
         tool_call_schemas.emplace_back(tool_schema);
-    }
+    });
     const auto tool_call =
         params.parallel_tool_calls
             ? json {
@@ -276,7 +315,7 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa
         "Respond in JSON format, either with a request to call tools or with a response to the user's request. Here is the schema for all responses:\n\n```json\n" + schema.dump(2) + "\n```");
 
     data.prompt = tmpl.apply(tweaked_messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
+    data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) {
         json data = json::parse(input);
         common_chat_msg result;
         result.role = "assistant";
@@ -303,13 +342,11 @@ static common_chat_data build_generic_tool_call_handler(const common_chat_templa
     return data;
 }
 
-static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_mistral_nemo_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;
-    auto builtin_tools = json {"wolfram_alpha", "brave_search"};
-
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
-        for (const auto & tool : params.tools) {
+        foreach_normalized_tool(params.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             schemas.push_back({
                 {"type", "object"},
@@ -329,7 +366,7 @@ static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_t
                 }},
                 {"required", json::array({"name", "arguments", "id"})},
             });
-        }
+        });
         auto schema = json {
             {"type", "array"},
             {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
@@ -344,24 +381,14 @@ static common_chat_data build_mistral_nemo_tool_call_handler(const common_chat_t
         data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
     }
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
         return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
     });
     return data;
 }
 
-static common_chat_data build_llama_3_tool_calls_handler(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
+static common_chat_data common_chat_init_llama_3_tool_calls(const common_chat_template & tmpl, const struct common_chat_params & params, bool uses_python_tag, bool eagerly_match_any_json) {
     auto builtin_tools = json {"wolfram_alpha", "brave_search"};
-    for (const auto & tool : params.tools) {
-        if (!tool.contains("type")) {
-            continue;
-        }
-        if (tool["type"] == "code_interpreter") {
-            builtin_tools.push_back("code_interpreter");
-            break;
-        }
-    }
-
     common_chat_data data;
 
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -375,6 +402,7 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ
             }
 
             if (tool["type"] == "code_interpreter") {
+                builtin_tools.push_back("code_interpreter");
                 has_python = true;
             } else if (tool["type"] == "function" && tool.contains("function")) {
                 const auto & function = tool["function"];
@@ -422,8 +450,10 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true, {
+        {"builtin_tools", builtin_tools},
+    });
+    data.parser = std::make_unique<monolithic_chat_parser>([params, uses_python_tag](const std::string & input) -> common_chat_msg {
         if (uses_python_tag) {
             static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
             std::smatch match;
@@ -448,11 +478,11 @@ static common_chat_data build_llama_3_tool_calls_handler(const common_chat_templ
     return data;
 }
 
-static common_chat_data build_firefunction_v2_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_firefunction_v2_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
-        for (const auto & tool : params.tools) {
+        foreach_normalized_tool(params.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             schemas.push_back({
                 {"type", "object"},
@@ -465,7 +495,7 @@ static common_chat_data build_firefunction_v2_tool_call_handler(const common_cha
                 }},
                 {"required", json::array({"name", "arguments", "id"})},
             });
-        }
+        });
         auto schema = json {
             {"type", "array"},
             {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
@@ -480,13 +510,13 @@ static common_chat_data build_firefunction_v2_tool_call_handler(const common_cha
         data.grammar_triggers.push_back({"functools[", /* .at_start = */ false});
     }
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([](const std::string & input) -> common_chat_msg {
         return parse_prefixed_json_tool_call_array(input, "functools[", /* rstrip_prefix= */ 1);
     });
     return data;
 }
 
-static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_functionary_v3_llama_3_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_data data;
@@ -530,19 +560,20 @@ static common_chat_data build_functionary_v3_llama_3_tool_call_handler(const com
     }, grammar_options);
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
         static std::regex close_regex(R"($|(?=>>>))");
         return parse_json_tool_calls(params.tools, input, function_regex, close_regex, /* check_names= */ true);
     });
     return data;
 }
 
-static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_functionary_v3_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     // TODO: handle tool {type: code_interpreter} as python
     common_chat_data data;
+    json tools = params.tools.is_null() ? params.tools : json::array();
 
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
@@ -578,7 +609,7 @@ static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const c
     }, grammar_options);
 
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([params](const std::string & input) -> common_chat_msg {
         // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
         std::smatch match;
@@ -602,12 +633,12 @@ static common_chat_data build_functionary_v3_llama_3_1_tool_call_handler(const c
     return data;
 }
 
-static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_template & tmpl, const struct common_chat_params & params) {
+static common_chat_data common_chat_init_hermes_2_pro_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
     common_chat_data data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
-        for (const auto & tool : params.tools) {
+        foreach_normalized_tool(params.tools, [&](const json & tool) {
             const auto & function = tool["function"];
             std::string name = function["name"];
             auto parameters = function["parameters"];
@@ -620,8 +651,7 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t
                 }},
                 {"required", json::array({"name", "arguments"})},
             }));
-        }
-
+        });
         auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
         builder.add_rule("root", params.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
         if (params.tool_choice != "required") {
@@ -630,7 +660,7 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t
         }
     }, grammar_options);
     data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-    data.handler = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
+    data.parser = std::make_unique<monolithic_chat_parser>([&](const std::string & input) -> common_chat_msg {
         try {
             std::regex start_pattern(R"([\n\s]*<tool_call>)");
             std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@@ -677,24 +707,40 @@ static common_chat_data build_hermes_2_pro_tool_call_handler(const common_chat_t
     return data;
 }
 
+static common_chat_data common_chat_init_without_tools(const common_chat_template & tmpl, const struct common_chat_params & params) {
+    common_chat_data data;
+    data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
+    data.parser = std::make_unique<text_chat_parser>();
+    if (!params.json_schema.is_null()) {
+        if (!params.grammar.empty()) {
+            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+        }
+        data.grammar = json_schema_to_grammar(params.json_schema);
+    } else {
+        data.grammar = params.grammar;
+    }
+    return data;
+}
+
 common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params) {
     if (params.tools.is_null()) {
-        common_chat_data data;
-        data.prompt = tmpl.apply(params.messages, params.tools.empty() ? json() : params.tools, /* add_generation_prompt= */ true);
-        data.handler = std::make_unique<text_chat_parser>();
-        return data;
+        return common_chat_init_without_tools(tmpl, params);
     }
-    const auto & src = tmpl.source();
 
+    if (!params.grammar.empty()) {
+        throw std::runtime_error("Cannot specify grammar with tools");
+    }
+
+    const auto & src = tmpl.source();
     if (src.find("<tool_call>") != std::string::npos) {
-        return build_hermes_2_pro_tool_call_handler(tmpl, params);
+        return common_chat_init_hermes_2_pro_tool_call(tmpl, params);
     }
     if (src.find(">>>all") != std::string::npos) {
-        return build_functionary_v3_llama_3_tool_call_handler(tmpl, params);
+        return common_chat_init_functionary_v3_llama_3_tool_call(tmpl, params);
     }
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
-        return build_functionary_v3_llama_3_1_tool_call_handler(tmpl, params);
+        return common_chat_init_functionary_v3_llama_3_1_tool_call(tmpl, params);
     }
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
         auto uses_python_tag = src.find("<|python_tag|>") != std::string::npos;
@@ -705,16 +751,16 @@ common_chat_data common_chat_init(const common_chat_template & tmpl, const struc
         // TODO: make this conditional on a very small model (e.g. 1B / 3B).
         auto eagerly_match_any_json = false; // style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
 
-        return build_llama_3_tool_calls_handler(tmpl, params, uses_python_tag, eagerly_match_any_json);
+        return common_chat_init_llama_3_tool_calls(tmpl, params, uses_python_tag, eagerly_match_any_json);
     }
     // if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
     //     TODO: Command-R-Plus
     // }
     if (src.find("[TOOL_CALLS]") != std::string::npos) {
-        return build_mistral_nemo_tool_call_handler(tmpl, params);
+        return common_chat_init_mistral_nemo_tool_call(tmpl, params);
     }
     if (src.find("functools[") != std::string::npos) {
-        return build_firefunction_v2_tool_call_handler(tmpl, params);
+        return common_chat_init_firefunction_v2_tool_call(tmpl, params);
     }
-    return build_generic_tool_call_handler(tmpl, params);
+    return common_chat_init_generic_tool_call(tmpl, params);
 }
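
For context, a rough caller-side sketch of the renamed entry point follows. It is illustrative only: the diff does not show the full definition of common_chat_params, how a common_chat_template is constructed, or the public parsing method of common_chat_parser, so the parse_final() call and the setup below are assumptions rather than part of the patch.

// Hypothetical usage sketch (not part of this patch).
// Assumes `tmpl` is an already-loaded common_chat_template and that the parser
// exposes a parse_final() method (name assumed from the stored parse_final_ callable).
common_chat_params params;
params.messages = json::array({
    {{"role", "user"}, {"content", "Run print(2 + 2) for me."}},
});
// A {"type": "code_interpreter"} tool is mapped to the built-in python tool by foreach_normalized_tool().
params.tools = json::array({{{"type", "code_interpreter"}}});
params.parallel_tool_calls = false;
params.tool_choice = "auto";

common_chat_data chat = common_chat_init(tmpl, params);
// chat.prompt  : rendered prompt to feed to the model
// chat.grammar : optional grammar constraining tool-call output
// chat.parser  : turns the raw completion back into a common_chat_msg
common_chat_msg msg = chat.parser->parse_final(completion_text);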