@@ -47,34 +47,34 @@ static json normalize_tools(const json & tools) {
     return results;
 }

-std::string llama_tool_call_style_name(llama_tool_call_style style) {
+std::string common_tool_call_style_name(common_tool_call_style style) {
     switch (style) {
-        case llama_tool_call_style::None:
+        case common_tool_call_style::None:
             return "None";
-        case llama_tool_call_style::Generic:
+        case common_tool_call_style::Generic:
             return "Generic";
-        case llama_tool_call_style::Llama31:
+        case common_tool_call_style::Llama31:
             return "Llama-3.1";
-        case llama_tool_call_style::Llama32:
+        case common_tool_call_style::Llama32:
             return "Llama-3.2";
-        case llama_tool_call_style::FunctionaryV3Llama3:
+        case common_tool_call_style::FunctionaryV3Llama3:
             return "FunctionaryV3Llama3";
-        case llama_tool_call_style::FunctionaryV3Llama31:
+        case common_tool_call_style::FunctionaryV3Llama31:
             return "FunctionaryV3Llama3.1";
-        case llama_tool_call_style::Hermes2Pro:
+        case common_tool_call_style::Hermes2Pro:
             return "Hermes2Pro";
-        case llama_tool_call_style::CommandRPlus:
+        case common_tool_call_style::CommandRPlus:
             return "CommandRPlus";
-        case llama_tool_call_style::MistralNemo:
+        case common_tool_call_style::MistralNemo:
             return "MistralNemo";
-        case llama_tool_call_style::FirefunctionV2:
+        case common_tool_call_style::FirefunctionV2:
             return "FirefunctionV2";
         default:
             return "Unknown";
     }
 }

-llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template) {
+common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template) {
     const auto & src = chat_template.source();

     if (src.find("<tool_call>") != std::string::npos) {
@@ -150,10 +150,10 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
  * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
  * Aggregates the prefix, suffix and in-between text into the content.
  */
-static llama_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
+static common_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
     std::smatch match;

-    llama_tool_calls result;
+    common_tool_calls result;
     auto end = input.end();
     auto it = input.begin();

@@ -202,7 +202,7 @@ static llama_tool_calls parse_json_tool_calls(const json & tools, const std::str
     return result;
 }

-static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
+static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
     try {
         std::regex start_pattern(R"([\n\s]*<tool_call>)");
         std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@@ -215,7 +215,7 @@ static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
             return {input, {}};
         }

-        llama_tool_calls result;
+        common_tool_calls result;
         result.content = rit->prefix();

         auto it = rit->suffix().first;
@@ -246,7 +246,7 @@ static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
     }
 }

-static llama_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
+static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
     if (allow_python_tag) {
         static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
         std::smatch match;
@@ -268,7 +268,7 @@ static llama_tool_calls parse_llama_3_tool_calls(const json & tools, const std::
     return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
 }

-static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
+static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
     // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
     static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
     std::smatch match;
@@ -289,15 +289,15 @@ static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & t
     return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false);
 }

-static llama_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
+static common_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
     static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
     static std::regex close_regex(R"($|(?=>>>))");
     return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
 }

-static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
+static common_tool_calls parse_generic_tool_calls(const std::string& input) {
     json data = json::parse(input);
-    llama_tool_calls result;
+    common_tool_calls result;
     if (data.contains("tool_calls")) {
         for (const auto & tool_call : data["tool_calls"]) {
             result.tool_calls.push_back({
@@ -319,11 +319,11 @@ static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
     return result;
 }

-static llama_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
+static common_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
     auto content_end = input.find(prefix);
     size_t tc_start = std::string::npos;

-    llama_tool_calls result;
+    common_tool_calls result;
     const auto process_tool_calls = [&](const json & tool_calls) {
         for (const auto & tool_call : tool_calls) {
             const auto & arguments = tool_call["arguments"];
@@ -345,34 +345,34 @@ static llama_tool_calls parse_prefixed_json_tool_call_array(const std::string& i
     return result;
 }

-static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
+static common_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
     return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
 }

-static llama_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) {
+static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) {
     return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
 }

-llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
-    fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", llama_tool_call_style_name(style).c_str(), input.c_str());
+common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
+    fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
     switch (style) {
-        case llama_tool_call_style::None:
+        case common_tool_call_style::None:
             return {input, {}};
-        case llama_tool_call_style::Generic:
+        case common_tool_call_style::Generic:
             return parse_generic_tool_calls(input);
-        case llama_tool_call_style::Llama31:
+        case common_tool_call_style::Llama31:
             return parse_llama_3_tool_calls(tools, input, /* allow_python_tag= */ true);
-        case llama_tool_call_style::Llama32:
+        case common_tool_call_style::Llama32:
             return parse_llama_3_tool_calls(tools, input, /* allow_python_tag= */ false);
-        case llama_tool_call_style::FunctionaryV3Llama3:
+        case common_tool_call_style::FunctionaryV3Llama3:
             return parse_functionary_v3_tool_calls(tools, input);
-        case llama_tool_call_style::FunctionaryV3Llama31:
+        case common_tool_call_style::FunctionaryV3Llama31:
             return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
-        case llama_tool_call_style::Hermes2Pro:
+        case common_tool_call_style::Hermes2Pro:
             return parse_hermes_tool_calls(input);
-        case llama_tool_call_style::MistralNemo:
+        case common_tool_call_style::MistralNemo:
             return parse_mistral_nemo_tool_calls(input);
-        case llama_tool_call_style::FirefunctionV2:
+        case common_tool_call_style::FirefunctionV2:
             return parse_firefunction_v2_tool_calls(input);
         default:
             throw std::runtime_error("Unsupported tool call style");
@@ -397,23 +397,23 @@ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages
     return messages_with_system;
 }

-llama_tool_call_handler llama_tool_call_handler_init(
-    llama_tool_call_style style,
+common_tool_call_handler common_tool_call_handler_init(
+    common_tool_call_style style,
     const common_chat_template & tmpl,
     bool allow_content,
     const nlohmann::ordered_json & parallel_tool_calls,
     const nlohmann::ordered_json & messages,
     const nlohmann::ordered_json & tools,
     const nlohmann::ordered_json & json_schema)
 {
-    llama_tool_call_handler handler;
+    common_tool_call_handler handler;
     auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();

     switch (style) {
-        case llama_tool_call_style::None:
+        case common_tool_call_style::None:
             handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
             break;
-        case llama_tool_call_style::Generic: {
+        case common_tool_call_style::Generic: {
             auto actual_tools = normalize_tools(tools);
             auto tool_call_schemas = json::array();
             for (const auto & tool : actual_tools) {
@@ -493,7 +493,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }
-        case llama_tool_call_style::MistralNemo: {
+        case common_tool_call_style::MistralNemo: {
             auto actual_tools = normalize_tools(tools);
             handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
                 auto schemas = json::array();
@@ -534,7 +534,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }
-        case llama_tool_call_style::FirefunctionV2: {
+        case common_tool_call_style::FirefunctionV2: {
             auto actual_tools = normalize_tools(tools);
             handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
                 auto schemas = json::array();
@@ -568,8 +568,8 @@ llama_tool_call_handler llama_tool_call_handler_init(
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }
-        case llama_tool_call_style::Llama31:
-        case llama_tool_call_style::Llama32: {
+        case common_tool_call_style::Llama31:
+        case common_tool_call_style::Llama32: {
             auto builtin_tools = json {"wolfram_alpha", "brave_search"};
             for (const auto & tool : tools) {
                 if (!tool.contains("type")) {
@@ -582,13 +582,13 @@ llama_tool_call_handler llama_tool_call_handler_init(
             }
             auto actual_tools = normalize_tools(tools);

-            auto uses_python_tag = style == llama_tool_call_style::Llama31;
+            auto uses_python_tag = style == common_tool_call_style::Llama31;

             // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
             // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
             // as it seems to be outputting some JSON.
             // TODO: make this conditional on a very small model (e.g. 1B / 3B).
-            auto eagerly_match_any_json = style == llama_tool_call_style::Llama32;
+            auto eagerly_match_any_json = style == common_tool_call_style::Llama32;

             handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
                 std::vector<std::string> tool_rules;
@@ -639,7 +639,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             });
             break;
         }
-        case llama_tool_call_style::FunctionaryV3Llama3: {
+        case common_tool_call_style::FunctionaryV3Llama3: {
             // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
             // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
             auto actual_tools = normalize_tools(tools);
@@ -670,7 +670,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             // handler.parser = parse_functionary_3_2_tool_calls;
             break;
         }
-        case llama_tool_call_style::FunctionaryV3Llama31: {
+        case common_tool_call_style::FunctionaryV3Llama31: {
             // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
             // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
             // TODO: handle tool {type: code_interpreter} as python
@@ -700,7 +700,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             // handler.parser = parse_functionary_3_2_tool_calls;
             break;
         }
-        case llama_tool_call_style::Hermes2Pro: {
+        case common_tool_call_style::Hermes2Pro: {
             // NousResearchHermesPro_2
             // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
             auto actual_tools = normalize_tools(tools);