@@ -49,25 +49,25 @@ static json normalize_tools(const json & tools) {
 
 std::string common_tool_call_style_name(common_tool_call_style style) {
     switch (style) {
-        case common_tool_call_style::None:
+        case COMMON_TOOL_CALL_STYLE_NONE:
             return "None";
-        case common_tool_call_style::Generic:
+        case COMMON_TOOL_CALL_STYLE_GENERIC:
             return "Generic";
-        case common_tool_call_style::Llama31:
+        case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
             return "Llama-3.1";
-        case common_tool_call_style::Llama32:
+        case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
             return "Llama-3.2";
-        case common_tool_call_style::FunctionaryV3Llama3:
+        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
             return "FunctionaryV3Llama3";
-        case common_tool_call_style::FunctionaryV3Llama31:
+        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
             return "FunctionaryV3Llama3.1";
-        case common_tool_call_style::Hermes2Pro:
+        case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
             return "Hermes2Pro";
-        case common_tool_call_style::CommandRPlus:
+        case COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS:
             return "CommandRPlus";
-        case common_tool_call_style::MistralNemo:
+        case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
             return "MistralNemo";
-        case common_tool_call_style::FirefunctionV2:
+        case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
             return "FirefunctionV2";
         default:
             return "Unknown";
@@ -78,26 +78,26 @@ common_tool_call_style common_tool_call_style_detect(const common_chat_template
     const auto & src = chat_template.source();
 
     if (src.find("<tool_call>") != std::string::npos) {
-        return Hermes2Pro;
+        return COMMON_TOOL_CALL_STYLE_HERMES_2_PRO;
     } else if (src.find(">>>all") != std::string::npos) {
-        return FunctionaryV3Llama3;
+        return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3;
     } else if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
-        return FunctionaryV3Llama31;
+        return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1;
     } else if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
         if (src.find("<|python_tag|>") != std::string::npos) {
-            return Llama31;
+            return COMMON_TOOL_CALL_STYLE_LLAMA_3_1;
         } else {
-            return Llama32;
+            return COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
         }
     } else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
-        return CommandRPlus;
+        return COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS;
     } else if (src.find("[TOOL_CALLS]") != std::string::npos) {
-        return MistralNemo;
+        return COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO;
     } else if (src.find("functools[") != std::string::npos) {
-        return FirefunctionV2;
+        return COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2;
     } else {
-        return Generic;
+        return COMMON_TOOL_CALL_STYLE_GENERIC;
     }
 }
 
@@ -356,23 +356,23 @@ static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& inp
 common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
     fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
     switch (style) {
-        case common_tool_call_style::None:
+        case COMMON_TOOL_CALL_STYLE_NONE:
             return {input, {}};
-        case common_tool_call_style::Generic:
+        case COMMON_TOOL_CALL_STYLE_GENERIC:
             return parse_generic_tool_calls(input);
-        case common_tool_call_style::Llama31:
+        case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
             return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
-        case common_tool_call_style::Llama32:
+        case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
             return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ false);
-        case common_tool_call_style::FunctionaryV3Llama3:
+        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
             return parse_functionary_v3_tool_calls(tools, input);
-        case common_tool_call_style::FunctionaryV3Llama31:
+        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
             return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
-        case common_tool_call_style::Hermes2Pro:
+        case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
             return parse_hermes_tool_calls(input);
-        case common_tool_call_style::MistralNemo:
+        case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
             return parse_mistral_nemo_tool_calls(input);
-        case common_tool_call_style::FirefunctionV2:
+        case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
             return parse_firefunction_v2_tool_calls(input);
         default:
             throw std::runtime_error("Unsupported tool call style");
@@ -410,10 +410,10 @@ common_tool_call_handler common_tool_call_handler_init(
     auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
 
     switch (style) {
-        case common_tool_call_style::None:
+        case COMMON_TOOL_CALL_STYLE_NONE:
             handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
             break;
-        case common_tool_call_style::Generic: {
+        case COMMON_TOOL_CALL_STYLE_GENERIC: {
             auto actual_tools = normalize_tools(tools);
             auto tool_call_schemas = json::array();
             for (const auto & tool : actual_tools) {
@@ -493,7 +493,7 @@ common_tool_call_handler common_tool_call_handler_init(
             handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }
-        case common_tool_call_style::MistralNemo: {
+        case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
             auto actual_tools = normalize_tools(tools);
             handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
                 auto schemas = json::array();
@@ -534,7 +534,7 @@ common_tool_call_handler common_tool_call_handler_init(
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }
-        case common_tool_call_style::FirefunctionV2: {
+        case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
             auto actual_tools = normalize_tools(tools);
             handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
                 auto schemas = json::array();
@@ -568,8 +568,8 @@ common_tool_call_handler common_tool_call_handler_init(
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }
-        case common_tool_call_style::Llama31:
-        case common_tool_call_style::Llama32: {
+        case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
+        case COMMON_TOOL_CALL_STYLE_LLAMA_3_2: {
             auto builtin_tools = json {"wolfram_alpha", "brave_search"};
             for (const auto & tool : tools) {
                 if (!tool.contains("type")) {
@@ -582,13 +582,13 @@ common_tool_call_handler common_tool_call_handler_init(
             }
             auto actual_tools = normalize_tools(tools);
 
-            auto uses_python_tag = style == common_tool_call_style::Llama31;
+            auto uses_python_tag = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1;
 
             // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
             // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
             // as it seems to be outputting some JSON.
             // TODO: make this conditional on a very small model (e.g. 1B / 3B).
-            auto eagerly_match_any_json = style == common_tool_call_style::Llama32;
+            auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
 
             handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
                 std::vector<std::string> tool_rules;
@@ -639,7 +639,7 @@ common_tool_call_handler common_tool_call_handler_init(
             });
             break;
         }
-        case common_tool_call_style::FunctionaryV3Llama3: {
+        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3: {
             // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
             // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
             auto actual_tools = normalize_tools(tools);
@@ -670,7 +670,7 @@ common_tool_call_handler common_tool_call_handler_init(
             // handler.parser = parse_functionary_3_2_tool_calls;
             break;
         }
-        case common_tool_call_style::FunctionaryV3Llama31: {
+        case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1: {
             // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
             // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
             // TODO: handle tool {type: code_interpreter} as python
@@ -700,7 +700,7 @@ common_tool_call_handler common_tool_call_handler_init(
             // handler.parser = parse_functionary_3_2_tool_calls;
             break;
         }
-        case common_tool_call_style::Hermes2Pro: {
+        case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO: {
            // NousResearchHermesPro_2
            // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
            auto actual_tools = normalize_tools(tools);