Skip to content

Commit 9e8b43f

Browse files
author
Olivier Chafik
committed
follow enum naming style for tool call styles
1 parent 5268ec8 commit 9e8b43f

File tree

5 files changed

+86
-86
lines changed

5 files changed

+86
-86
lines changed

common/tool-call.cpp

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,25 @@ static json normalize_tools(const json & tools) {
4949

5050
std::string common_tool_call_style_name(common_tool_call_style style) {
5151
switch (style) {
52-
case common_tool_call_style::None:
52+
case COMMON_TOOL_CALL_STYLE_NONE:
5353
return "None";
54-
case common_tool_call_style::Generic:
54+
case COMMON_TOOL_CALL_STYLE_GENERIC:
5555
return "Generic";
56-
case common_tool_call_style::Llama31:
56+
case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
5757
return "Llama-3.1";
58-
case common_tool_call_style::Llama32:
58+
case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
5959
return "Llama-3.2";
60-
case common_tool_call_style::FunctionaryV3Llama3:
60+
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
6161
return "FunctionaryV3Llama3";
62-
case common_tool_call_style::FunctionaryV3Llama31:
62+
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
6363
return "FunctionaryV3Llama3.1";
64-
case common_tool_call_style::Hermes2Pro:
64+
case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
6565
return "Hermes2Pro";
66-
case common_tool_call_style::CommandRPlus:
66+
case COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS:
6767
return "CommandRPlus";
68-
case common_tool_call_style::MistralNemo:
68+
case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
6969
return "MistralNemo";
70-
case common_tool_call_style::FirefunctionV2:
70+
case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
7171
return "FirefunctionV2";
7272
default:
7373
return "Unknown";
@@ -78,26 +78,26 @@ common_tool_call_style common_tool_call_style_detect(const common_chat_template
7878
const auto & src = chat_template.source();
7979

8080
if (src.find("<tool_call>") != std::string::npos) {
81-
return Hermes2Pro;
81+
return COMMON_TOOL_CALL_STYLE_HERMES_2_PRO;
8282
} else if (src.find(">>>all") != std::string::npos) {
83-
return FunctionaryV3Llama3;
83+
return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3;
8484
} else if (src.find("<|start_header_id|>") != std::string::npos
8585
&& src.find("<function=") != std::string::npos) {
86-
return FunctionaryV3Llama31;
86+
return COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1;
8787
} else if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
8888
if (src.find("<|python_tag|>") != std::string::npos) {
89-
return Llama31;
89+
return COMMON_TOOL_CALL_STYLE_LLAMA_3_1;
9090
} else {
91-
return Llama32;
91+
return COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
9292
}
9393
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
94-
return CommandRPlus;
94+
return COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS;
9595
} else if (src.find("[TOOL_CALLS]") != std::string::npos) {
96-
return MistralNemo;
96+
return COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO;
9797
} else if (src.find(" functools[") != std::string::npos) {
98-
return FirefunctionV2;
98+
return COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2;
9999
} else {
100-
return Generic;
100+
return COMMON_TOOL_CALL_STYLE_GENERIC;
101101
}
102102
}
103103

@@ -356,23 +356,23 @@ static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& inp
356356
common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
357357
fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
358358
switch (style) {
359-
case common_tool_call_style::None:
359+
case COMMON_TOOL_CALL_STYLE_NONE:
360360
return {input, {}};
361-
case common_tool_call_style::Generic:
361+
case COMMON_TOOL_CALL_STYLE_GENERIC:
362362
return parse_generic_tool_calls(input);
363-
case common_tool_call_style::Llama31:
363+
case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
364364
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
365-
case common_tool_call_style::Llama32:
365+
case COMMON_TOOL_CALL_STYLE_LLAMA_3_2:
366366
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ false);
367-
case common_tool_call_style::FunctionaryV3Llama3:
367+
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3:
368368
return parse_functionary_v3_tool_calls(tools, input);
369-
case common_tool_call_style::FunctionaryV3Llama31:
369+
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1:
370370
return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
371-
case common_tool_call_style::Hermes2Pro:
371+
case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO:
372372
return parse_hermes_tool_calls(input);
373-
case common_tool_call_style::MistralNemo:
373+
case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO:
374374
return parse_mistral_nemo_tool_calls(input);
375-
case common_tool_call_style::FirefunctionV2:
375+
case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2:
376376
return parse_firefunction_v2_tool_calls(input);
377377
default:
378378
throw std::runtime_error("Unsupported tool call style");
@@ -410,10 +410,10 @@ common_tool_call_handler common_tool_call_handler_init(
410410
auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
411411

412412
switch (style) {
413-
case common_tool_call_style::None:
413+
case COMMON_TOOL_CALL_STYLE_NONE:
414414
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
415415
break;
416-
case common_tool_call_style::Generic: {
416+
case COMMON_TOOL_CALL_STYLE_GENERIC: {
417417
auto actual_tools = normalize_tools(tools);
418418
auto tool_call_schemas = json::array();
419419
for (const auto & tool : actual_tools) {
@@ -493,7 +493,7 @@ common_tool_call_handler common_tool_call_handler_init(
493493
handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
494494
break;
495495
}
496-
case common_tool_call_style::MistralNemo: {
496+
case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
497497
auto actual_tools = normalize_tools(tools);
498498
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
499499
auto schemas = json::array();
@@ -534,7 +534,7 @@ common_tool_call_handler common_tool_call_handler_init(
534534
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
535535
break;
536536
}
537-
case common_tool_call_style::FirefunctionV2: {
537+
case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
538538
auto actual_tools = normalize_tools(tools);
539539
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
540540
auto schemas = json::array();
@@ -568,8 +568,8 @@ common_tool_call_handler common_tool_call_handler_init(
568568
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
569569
break;
570570
}
571-
case common_tool_call_style::Llama31:
572-
case common_tool_call_style::Llama32: {
571+
case COMMON_TOOL_CALL_STYLE_LLAMA_3_1:
572+
case COMMON_TOOL_CALL_STYLE_LLAMA_3_2: {
573573
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
574574
for (const auto & tool : tools) {
575575
if (!tool.contains("type")) {
@@ -582,13 +582,13 @@ common_tool_call_handler common_tool_call_handler_init(
582582
}
583583
auto actual_tools = normalize_tools(tools);
584584

585-
auto uses_python_tag = style == common_tool_call_style::Llama31;
585+
auto uses_python_tag = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_1;
586586

587587
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
588588
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
589589
// as it seems to be outputting some JSON.
590590
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
591-
auto eagerly_match_any_json = style == common_tool_call_style::Llama32;
591+
auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
592592

593593
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
594594
std::vector<std::string> tool_rules;
@@ -639,7 +639,7 @@ common_tool_call_handler common_tool_call_handler_init(
639639
});
640640
break;
641641
}
642-
case common_tool_call_style::FunctionaryV3Llama3: {
642+
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3: {
643643
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
644644
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
645645
auto actual_tools = normalize_tools(tools);
@@ -670,7 +670,7 @@ common_tool_call_handler common_tool_call_handler_init(
670670
// handler.parser = parse_functionary_3_2_tool_calls;
671671
break;
672672
}
673-
case common_tool_call_style::FunctionaryV3Llama31: {
673+
case COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1: {
674674
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
675675
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
676676
// TODO: handle tool {type: code_interpreter} as python
@@ -700,7 +700,7 @@ common_tool_call_handler common_tool_call_handler_init(
700700
// handler.parser = parse_functionary_3_2_tool_calls;
701701
break;
702702
}
703-
case common_tool_call_style::Hermes2Pro: {
703+
case COMMON_TOOL_CALL_STYLE_HERMES_2_PRO: {
704704
// NousResearchHermesPro_2
705705
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
706706
auto actual_tools = normalize_tools(tools);

common/tool-call.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88
#include "json.hpp"
99

1010
enum common_tool_call_style {
11-
UnknownToolCallStyle,
12-
None,
13-
Generic,
14-
Llama31,
15-
Llama32,
16-
FunctionaryV3Llama3,
17-
FunctionaryV3Llama31,
18-
Hermes2Pro,
19-
CommandRPlus,
20-
MistralNemo,
21-
FirefunctionV2,
11+
COMMON_TOOL_CALL_STYLE_UNKNOWN,
12+
COMMON_TOOL_CALL_STYLE_NONE,
13+
COMMON_TOOL_CALL_STYLE_GENERIC,
14+
COMMON_TOOL_CALL_STYLE_LLAMA_3_1,
15+
COMMON_TOOL_CALL_STYLE_LLAMA_3_2,
16+
COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3,
17+
COMMON_TOOL_CALL_STYLE_FUNCTIONARY_V3_LLAMA_3_1,
18+
COMMON_TOOL_CALL_STYLE_HERMES_2_PRO,
19+
COMMON_TOOL_CALL_STYLE_COMMAND_R_PLUS,
20+
COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO,
21+
COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2,
2222
};
2323

2424
struct common_tool_call {

examples/server/server.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ struct slot_params {
118118
std::string oaicompat_model;
119119
std::string oaicompat_cmpl_id;
120120
json oaicompat_tools;
121-
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::None;
121+
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
122122

123123
json to_json() const {
124124
std::vector<std::string> samplers;
@@ -589,7 +589,7 @@ struct server_task_result_cmpl_final : server_task_result {
589589
std::string oaicompat_model;
590590
std::string oaicompat_cmpl_id;
591591
json oaicompat_tools;
592-
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::None;
592+
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE;
593593

594594
virtual int get_index() override {
595595
return index;
@@ -690,7 +690,7 @@ struct server_task_result_cmpl_final : server_task_result {
690690
common_tool_calls parsed_tool_calls;
691691
json tool_calls;
692692
json message_content;
693-
if (oaicompat_tool_call_style != common_tool_call_style::None && !oaicompat_tools.is_null()) {
693+
if (oaicompat_tool_call_style != common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE && !oaicompat_tools.is_null()) {
694694
parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
695695
if (!parsed_tool_calls.tool_calls.empty()) {
696696
finish_reason = "tool_calls";
@@ -3772,7 +3772,7 @@ int main(int argc, char ** argv) {
37723772
std::function<bool()> is_connection_closed,
37733773
httplib::Response & res,
37743774
oaicompat_type oaicompat,
3775-
common_tool_call_style tool_call_style = common_tool_call_style::None) {
3775+
common_tool_call_style tool_call_style = common_tool_call_style::COMMON_TOOL_CALL_STYLE_NONE) {
37763776
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
37773777

37783778
if (ctx_server.params_base.embedding) {

examples/server/utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ static json oaicompat_completion_params_parse(
595595
throw std::runtime_error("Cannot use tools with stream");
596596
}
597597
if (use_jinja) {
598-
if (tool_call_style == common_tool_call_style::UnknownToolCallStyle) {
598+
if (tool_call_style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_UNKNOWN) {
599599
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
600600
}
601601
} else {

0 commit comments

Comments
 (0)