Skip to content

Commit 3972945

Browse files
author
Olivier Chafik
committed
common_tool_call rename
1 parent ef61a4c commit 3972945

File tree

5 files changed

+95
-95
lines changed

5 files changed

+95
-95
lines changed

common/tool-call.cpp

Lines changed: 50 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -47,34 +47,34 @@ static json normalize_tools(const json & tools) {
4747
return results;
4848
}
4949

50-
std::string llama_tool_call_style_name(llama_tool_call_style style) {
50+
std::string common_tool_call_style_name(common_tool_call_style style) {
5151
switch (style) {
52-
case llama_tool_call_style::None:
52+
case common_tool_call_style::None:
5353
return "None";
54-
case llama_tool_call_style::Generic:
54+
case common_tool_call_style::Generic:
5555
return "Generic";
56-
case llama_tool_call_style::Llama31:
56+
case common_tool_call_style::Llama31:
5757
return "Llama-3.1";
58-
case llama_tool_call_style::Llama32:
58+
case common_tool_call_style::Llama32:
5959
return "Llama-3.2";
60-
case llama_tool_call_style::FunctionaryV3Llama3:
60+
case common_tool_call_style::FunctionaryV3Llama3:
6161
return "FunctionaryV3Llama3";
62-
case llama_tool_call_style::FunctionaryV3Llama31:
62+
case common_tool_call_style::FunctionaryV3Llama31:
6363
return "FunctionaryV3Llama3.1";
64-
case llama_tool_call_style::Hermes2Pro:
64+
case common_tool_call_style::Hermes2Pro:
6565
return "Hermes2Pro";
66-
case llama_tool_call_style::CommandRPlus:
66+
case common_tool_call_style::CommandRPlus:
6767
return "CommandRPlus";
68-
case llama_tool_call_style::MistralNemo:
68+
case common_tool_call_style::MistralNemo:
6969
return "MistralNemo";
70-
case llama_tool_call_style::FirefunctionV2:
70+
case common_tool_call_style::FirefunctionV2:
7171
return "FirefunctionV2";
7272
default:
7373
return "Unknown";
7474
}
7575
}
7676

77-
llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template) {
77+
common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template) {
7878
const auto & src = chat_template.source();
7979

8080
if (src.find("<tool_call>") != std::string::npos) {
@@ -150,10 +150,10 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
150150
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
151151
* Aggregates the prefix, suffix and in-between text into the content.
152152
*/
153-
static llama_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
153+
static common_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) {
154154
std::smatch match;
155155

156-
llama_tool_calls result;
156+
common_tool_calls result;
157157
auto end = input.end();
158158
auto it = input.begin();
159159

@@ -202,7 +202,7 @@ static llama_tool_calls parse_json_tool_calls(const json & tools, const std::str
202202
return result;
203203
}
204204

205-
static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
205+
static common_tool_calls parse_hermes_tool_calls(const std::string& input) {
206206
try {
207207
std::regex start_pattern(R"([\n\s]*<tool_call>)");
208208
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
@@ -215,7 +215,7 @@ static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
215215
return {input, {}};
216216
}
217217

218-
llama_tool_calls result;
218+
common_tool_calls result;
219219
result.content = rit->prefix();
220220

221221
auto it = rit->suffix().first;
@@ -246,7 +246,7 @@ static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
246246
}
247247
}
248248

249-
static llama_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
249+
static common_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) {
250250
if (allow_python_tag) {
251251
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
252252
std::smatch match;
@@ -268,7 +268,7 @@ static llama_tool_calls parse_llama_3_tool_calls(const json & tools, const std::
268268
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
269269
}
270270

271-
static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
271+
static common_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) {
272272
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
273273
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
274274
std::smatch match;
@@ -289,15 +289,15 @@ static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & t
289289
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false);
290290
}
291291

292-
static llama_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
292+
static common_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) {
293293
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
294294
static std::regex close_regex(R"($|(?=>>>))");
295295
return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true);
296296
}
297297

298-
static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
298+
static common_tool_calls parse_generic_tool_calls(const std::string& input) {
299299
json data = json::parse(input);
300-
llama_tool_calls result;
300+
common_tool_calls result;
301301
if (data.contains("tool_calls")) {
302302
for (const auto & tool_call : data["tool_calls"]) {
303303
result.tool_calls.push_back({
@@ -319,11 +319,11 @@ static llama_tool_calls parse_generic_tool_calls(const std::string& input) {
319319
return result;
320320
}
321321

322-
static llama_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
322+
static common_tool_calls parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
323323
auto content_end = input.find(prefix);
324324
size_t tc_start = std::string::npos;
325325

326-
llama_tool_calls result;
326+
common_tool_calls result;
327327
const auto process_tool_calls = [&](const json & tool_calls) {
328328
for (const auto & tool_call : tool_calls) {
329329
const auto & arguments = tool_call["arguments"];
@@ -345,34 +345,34 @@ static llama_tool_calls parse_prefixed_json_tool_call_array(const std::string& i
345345
return result;
346346
}
347347

348-
static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
348+
static common_tool_calls parse_mistral_nemo_tool_calls(const std::string& input) {
349349
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
350350
}
351351

352-
static llama_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) {
352+
static common_tool_calls parse_firefunction_v2_tool_calls(const std::string& input) {
353353
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
354354
}
355355

356-
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
357-
fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", llama_tool_call_style_name(style).c_str(), input.c_str());
356+
common_tool_calls parse_tool_calls(common_tool_call_style style, const json & tools, const std::string& input) {
357+
fprintf(stderr, "# parse_tool_calls(%s):\n\n%s\n\n", common_tool_call_style_name(style).c_str(), input.c_str());
358358
switch (style) {
359-
case llama_tool_call_style::None:
359+
case common_tool_call_style::None:
360360
return {input, {}};
361-
case llama_tool_call_style::Generic:
361+
case common_tool_call_style::Generic:
362362
return parse_generic_tool_calls(input);
363-
case llama_tool_call_style::Llama31:
363+
case common_tool_call_style::Llama31:
364364
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true);
365-
case llama_tool_call_style::Llama32:
365+
case common_tool_call_style::Llama32:
366366
return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ false);
367-
case llama_tool_call_style::FunctionaryV3Llama3:
367+
case common_tool_call_style::FunctionaryV3Llama3:
368368
return parse_functionary_v3_tool_calls(tools, input);
369-
case llama_tool_call_style::FunctionaryV3Llama31:
369+
case common_tool_call_style::FunctionaryV3Llama31:
370370
return parse_functionary_v3_llama_3_1_tool_calls(tools, input);
371-
case llama_tool_call_style::Hermes2Pro:
371+
case common_tool_call_style::Hermes2Pro:
372372
return parse_hermes_tool_calls(input);
373-
case llama_tool_call_style::MistralNemo:
373+
case common_tool_call_style::MistralNemo:
374374
return parse_mistral_nemo_tool_calls(input);
375-
case llama_tool_call_style::FirefunctionV2:
375+
case common_tool_call_style::FirefunctionV2:
376376
return parse_firefunction_v2_tool_calls(input);
377377
default:
378378
throw std::runtime_error("Unsupported tool call style");
@@ -397,23 +397,23 @@ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages
397397
return messages_with_system;
398398
}
399399

400-
llama_tool_call_handler llama_tool_call_handler_init(
401-
llama_tool_call_style style,
400+
common_tool_call_handler common_tool_call_handler_init(
401+
common_tool_call_style style,
402402
const common_chat_template & tmpl,
403403
bool allow_content,
404404
const nlohmann::ordered_json & parallel_tool_calls,
405405
const nlohmann::ordered_json & messages,
406406
const nlohmann::ordered_json & tools,
407407
const nlohmann::ordered_json & json_schema)
408408
{
409-
llama_tool_call_handler handler;
409+
common_tool_call_handler handler;
410410
auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
411411

412412
switch (style) {
413-
case llama_tool_call_style::None:
413+
case common_tool_call_style::None:
414414
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
415415
break;
416-
case llama_tool_call_style::Generic: {
416+
case common_tool_call_style::Generic: {
417417
auto actual_tools = normalize_tools(tools);
418418
auto tool_call_schemas = json::array();
419419
for (const auto & tool : actual_tools) {
@@ -493,7 +493,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
493493
handler.prompt = tmpl.apply(tweaked_messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
494494
break;
495495
}
496-
case llama_tool_call_style::MistralNemo: {
496+
case common_tool_call_style::MistralNemo: {
497497
auto actual_tools = normalize_tools(tools);
498498
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
499499
auto schemas = json::array();
@@ -534,7 +534,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
534534
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
535535
break;
536536
}
537-
case llama_tool_call_style::FirefunctionV2: {
537+
case common_tool_call_style::FirefunctionV2: {
538538
auto actual_tools = normalize_tools(tools);
539539
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
540540
auto schemas = json::array();
@@ -568,8 +568,8 @@ llama_tool_call_handler llama_tool_call_handler_init(
568568
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
569569
break;
570570
}
571-
case llama_tool_call_style::Llama31:
572-
case llama_tool_call_style::Llama32: {
571+
case common_tool_call_style::Llama31:
572+
case common_tool_call_style::Llama32: {
573573
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
574574
for (const auto & tool : tools) {
575575
if (!tool.contains("type")) {
@@ -582,13 +582,13 @@ llama_tool_call_handler llama_tool_call_handler_init(
582582
}
583583
auto actual_tools = normalize_tools(tools);
584584

585-
auto uses_python_tag = style == llama_tool_call_style::Llama31;
585+
auto uses_python_tag = style == common_tool_call_style::Llama31;
586586

587587
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
588588
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
589589
// as it seems to be outputting some JSON.
590590
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
591-
auto eagerly_match_any_json = style == llama_tool_call_style::Llama32;
591+
auto eagerly_match_any_json = style == common_tool_call_style::Llama32;
592592

593593
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
594594
std::vector<std::string> tool_rules;
@@ -639,7 +639,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
639639
});
640640
break;
641641
}
642-
case llama_tool_call_style::FunctionaryV3Llama3: {
642+
case common_tool_call_style::FunctionaryV3Llama3: {
643643
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
644644
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
645645
auto actual_tools = normalize_tools(tools);
@@ -670,7 +670,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
670670
// handler.parser = parse_functionary_3_2_tool_calls;
671671
break;
672672
}
673-
case llama_tool_call_style::FunctionaryV3Llama31: {
673+
case common_tool_call_style::FunctionaryV3Llama31: {
674674
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
675675
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
676676
// TODO: handle tool {type: code_interpreter} as python
@@ -700,7 +700,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
700700
// handler.parser = parse_functionary_3_2_tool_calls;
701701
break;
702702
}
703-
case llama_tool_call_style::Hermes2Pro: {
703+
case common_tool_call_style::Hermes2Pro: {
704704
// NousResearchHermesPro_2
705705
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
706706
auto actual_tools = normalize_tools(tools);

common/tool-call.h

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
#define JSON_ASSERT GGML_ASSERT
88
#include "json.hpp"
99

10-
enum llama_tool_call_style {
10+
enum common_tool_call_style {
1111
UnknownToolCallStyle,
1212
None,
1313
Generic,
@@ -21,32 +21,32 @@ enum llama_tool_call_style {
2121
FirefunctionV2,
2222
};
2323

24-
struct llama_tool_call {
24+
struct common_tool_call {
2525
std::string name;
2626
std::string arguments;
2727
std::string id;
2828
};
2929

30-
struct llama_tool_calls {
30+
struct common_tool_calls {
3131
std::string content;
32-
std::vector<llama_tool_call> tool_calls;
32+
std::vector<common_tool_call> tool_calls;
3333
};
3434

35-
struct llama_tool_call_handler {
35+
struct common_tool_call_handler {
3636
std::string prompt;
3737
std::string grammar;
3838
std::vector<std::string> grammar_triggers;
3939
std::vector<std::string> additional_stops;
4040
};
4141

42-
std::string llama_tool_call_style_name(llama_tool_call_style style);
42+
std::string common_tool_call_style_name(common_tool_call_style style);
4343

44-
llama_tool_call_style llama_tool_call_style_detect(const common_chat_template & chat_template);
44+
common_tool_call_style common_tool_call_style_detect(const common_chat_template & chat_template);
4545

46-
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
46+
common_tool_calls parse_tool_calls(common_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
4747

48-
llama_tool_call_handler llama_tool_call_handler_init(
49-
llama_tool_call_style style,
48+
common_tool_call_handler common_tool_call_handler_init(
49+
common_tool_call_style style,
5050
const common_chat_template & tmpl,
5151
bool allow_content,
5252
const nlohmann::ordered_json & parallel_tool_calls,

examples/server/server.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -118,7 +118,7 @@ struct slot_params {
118118
std::string oaicompat_model;
119119
std::string oaicompat_cmpl_id;
120120
json oaicompat_tools;
121-
llama_tool_call_style oaicompat_tool_call_style = llama_tool_call_style::None;
121+
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::None;
122122

123123
json to_json() const {
124124
std::vector<std::string> samplers;
@@ -589,7 +589,7 @@ struct server_task_result_cmpl_final : server_task_result {
589589
std::string oaicompat_model;
590590
std::string oaicompat_cmpl_id;
591591
json oaicompat_tools;
592-
llama_tool_call_style oaicompat_tool_call_style = llama_tool_call_style::None;
592+
common_tool_call_style oaicompat_tool_call_style = common_tool_call_style::None;
593593

594594
virtual int get_index() override {
595595
return index;
@@ -687,10 +687,10 @@ struct server_task_result_cmpl_final : server_task_result {
687687
finish_reason = "stop";
688688
}
689689

690-
llama_tool_calls parsed_tool_calls;
690+
common_tool_calls parsed_tool_calls;
691691
json tool_calls;
692692
json message_content;
693-
if (oaicompat_tool_call_style != llama_tool_call_style::None && !oaicompat_tools.is_null()) {
693+
if (oaicompat_tool_call_style != common_tool_call_style::None && !oaicompat_tools.is_null()) {
694694
parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
695695
if (!parsed_tool_calls.tool_calls.empty()) {
696696
finish_reason = "tool_calls";
@@ -3772,7 +3772,7 @@ int main(int argc, char ** argv) {
37723772
std::function<bool()> is_connection_closed,
37733773
httplib::Response & res,
37743774
oaicompat_type oaicompat,
3775-
llama_tool_call_style tool_call_style = llama_tool_call_style::None) {
3775+
common_tool_call_style tool_call_style = common_tool_call_style::None) {
37763776
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
37773777

37783778
if (ctx_server.params_base.embedding) {
@@ -3979,8 +3979,8 @@ int main(int argc, char ** argv) {
39793979

39803980
auto body = json::parse(req.body);
39813981
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
3982-
auto tool_call_style = llama_tool_call_style_detect(chat_template);
3983-
LOG_INF("Tool call style: %s\n", llama_tool_call_style_name(tool_call_style).c_str());
3982+
auto tool_call_style = common_tool_call_style_detect(chat_template);
3983+
LOG_INF("Tool call style: %s\n", common_tool_call_style_name(tool_call_style).c_str());
39843984

39853985
json data = oaicompat_completion_params_parse(body, chat_template, tool_call_style, params.use_jinja);
39863986

0 commit comments

Comments
 (0)