Skip to content

Commit 4706bdb

Browse files
author
ochafik
committed
tool-call: support Functionary v3 vs. v3-llama3.1 variants
1 parent 41103c0 commit 4706bdb

File tree

2 files changed

+93
-7
lines changed

2 files changed

+93
-7
lines changed

common/tool-call.cpp

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,18 @@
1212

1313
using json = nlohmann::ordered_json;
1414

15-
static bool needs_functionary_3_2_tool_call(const std::string & chat_template) {
15+
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3.llama3.txt
// Detects the Functionary v3 chat template: a llama3-style header marker plus
// the ">>>all" routing token that this template family uses.
static bool needs_functionary_v3_tool_call(const std::string & chat_template) {
    const bool has_llama3_header = chat_template.find("<|start_header_id|>") != std::string::npos;
    const bool has_all_marker    = chat_template.find(">>>all") != std::string::npos;
    return has_llama3_header && has_all_marker;
}
1920

21+
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// Detects the Functionary v3 llama-3.1 variant: a llama3-style header marker
// plus the "<function=" open tag this variant wraps tool calls in.
static bool needs_functionary_v3_llama_3_1_tool_call(const std::string & chat_template) {
    const bool has_llama3_header = chat_template.find("<|start_header_id|>") != std::string::npos;
    const bool has_function_tag  = chat_template.find("<function=") != std::string::npos;
    return has_llama3_header && has_function_tag;
}
26+
2027
static bool needs_llama_3_1_tool_call(const std::string & chat_template) {
2128
return chat_template.find("<|start_header_id|>") != std::string::npos
2229
&& chat_template.find("<|python_tag|>") != std::string::npos;
@@ -148,8 +155,42 @@ static llama_tool_calls parse_llama_3_1_tool_calls(const json & tools, const std
148155
return {input, {}};
149156
}
150157

158+
// Parses Functionary v3-llama3.1 style output, where tool calls appear inline
// as "<function=NAME>{json args}</function>". Free text around the calls is
// accumulated into result.content; each call is appended to result.tool_calls
// as {name, serialized arguments}. Throws std::runtime_error when the JSON
// arguments fail to parse or the closing "</function>" tag is missing.
static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const std::string& input) {
    static std::regex function_regex(R"(<function=(\w+)>)");
    static std::regex close_regex(R"(</function>)");

    llama_tool_calls result;
    auto cursor = input.begin();
    const auto input_end = input.end();

    for (;;) {
        // No further "<function=" tag: the remainder is plain content.
        std::smatch open_match;
        if (cursor == input_end || !std::regex_search(cursor, input_end, open_match, function_regex)) {
            result.content += std::string(cursor, input_end);
            break;
        }

        // Text before the open tag is content; resume scanning after the tag.
        result.content += std::string(cursor, open_match.prefix().second);
        cursor = open_match.suffix().first;

        const auto fn_name = open_match.str(1);

        // NOTE(review): parse_json takes the iterator by reference and is
        // expected to advance it past the JSON on success — the close-tag
        // search below resumes from `cursor`. Confirm against its definition.
        json args;
        if (!parse_json(cursor, input_end, args)) {
            throw std::runtime_error("Failed to parse json tool call arguments");
        }
        // Require (and skip past) the matching close tag; anything between the
        // parsed JSON and "</function>" is silently discarded.
        std::smatch close_match;
        if (!std::regex_search(cursor, input_end, close_match, close_regex)) {
            throw std::runtime_error("Malformed input, missing closing pattern");
        }
        cursor = close_match.suffix().first;
        result.tool_calls.push_back({fn_name, args.dump()});
    }
    return result;
}
192+
193+
static llama_tool_calls parse_functionary_v3_tool_calls(const std::string& input) {
153194
static std::regex python_tag_regex(R"(>>>(\w+)\n((?!>>>)[\s\S\n]*))");
154195
std::smatch match;
155196
llama_tool_calls result;
@@ -172,8 +213,10 @@ llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_t
172213
return parse_hermes_tool_calls(input);
173214
} else if (needs_llama_3_1_tool_call(chat_template)) {
174215
return parse_llama_3_1_tool_calls(tools, input);
175-
} else if (needs_functionary_3_2_tool_call(chat_template)) {
176-
return parse_functionary_3_2_tool_calls(input);
216+
} else if (needs_functionary_v3_tool_call(chat_template)) {
217+
return parse_functionary_v3_tool_calls(input);
218+
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
219+
return parse_functionary_v3_llama_3_1_tool_calls(input);
177220
} else {
178221
throw std::runtime_error("Unsupported chat template for tool calls");
179222
}
@@ -187,7 +230,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
187230
{
188231
llama_tool_call_handler handler;
189232

190-
if (needs_functionary_3_2_tool_call(chat_template)) {
233+
if (needs_functionary_v3_tool_call(chat_template)) {
191234
// MeetKaiFunctionary_3_2
192235
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
193236
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
@@ -208,6 +251,25 @@ llama_tool_call_handler llama_tool_call_handler_init(
208251
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
209252
});
210253
// handler.parser = parse_functionary_v3_tool_calls;
254+
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
255+
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
256+
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
257+
std::vector<std::string> tool_rules;
258+
for (size_t i = 0, n = tools.size(); i < n; i++) {
259+
auto & tool = tools[i];
260+
const auto & function = tool["function"];
261+
std::string name = function["name"];
262+
auto parameters = function["parameters"];
263+
auto tool_rule = builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\"");
264+
tool_rules.push_back(tool_rule);
265+
}
266+
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
267+
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
268+
if (allow_content) {
269+
handler.grammar_trigger_words.push_back("<function=");
270+
}
271+
});
272+
// handler.parser = parse_functionary_v3_llama_3_1_tool_calls;
211273
} else if (needs_hermes_pro_tool_call(chat_template)) {
212274
// NousResearchHermesPro_2
213275
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*

tests/test-tool-call.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ static void assert_equals(const std::string & expected, const std::string & actu
2121
*/
2222

2323
static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
24+
std::cout << "# Testing: " << input << std::endl << std::flush;
2425
auto result = parse_tool_calls(tools, chat_template, input);
2526
assert_equals(expected_content, result.content);
2627
auto tool_calls = json::array();
@@ -71,8 +72,8 @@ int main() {
7172
}}
7273
}});
7374

74-
std::string functionary_3_2_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
75-
test_parse_tool_call(tools, functionary_3_2_like_tmpl,
75+
std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
76+
test_parse_tool_call(tools, functionary_v3_like_tmpl,
7677
">>>ipython\nprint('Hello, world!')",
7778
"",
7879
json {{
@@ -84,6 +85,29 @@ int main() {
8485
}}
8586
}});
8687

88+
std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some <function=foo>{...}</function> inside it";
89+
test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
90+
"Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
91+
"Hello, world!",
92+
json {
93+
{
94+
{"function", {
95+
{"name", "foo"},
96+
{"arguments", (json {
97+
{"arg1", 1}
98+
}).dump()}
99+
}}
100+
},
101+
{
102+
{"function", {
103+
{"name", "bar"},
104+
{"arguments", (json {
105+
{"arg2", 2}
106+
}).dump()}
107+
}}
108+
},
109+
});
110+
87111
std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it";
88112
test_parse_tool_call(tools, llama_3_1_like_tmpl,
89113
"<|python_tag|>this could be anything",

0 commit comments

Comments
 (0)