1212
using json = nlohmann::ordered_json;
1414
// Detects the MeetKai Functionary v3 chat template, which routes tool calls
// through a ">>>all" / ">>>fn_name" channel prefix on top of the llama3
// header markers.
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3.llama3.txt
static bool needs_functionary_v3_tool_call(const std::string & chat_template) {
    // Both markers must be present: the llama3 header alone also matches
    // other templates (see needs_llama_3_1_tool_call), so ">>>all" is the
    // discriminator.
    return chat_template.find("<|start_header_id|>") != std::string::npos
        && chat_template.find(">>>all") != std::string::npos;
}
1920
// Detects the MeetKai Functionary v3 template variant for Llama 3.1, which
// wraps tool calls in <function=NAME>...</function> tags.
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
static bool needs_functionary_v3_llama_3_1_tool_call(const std::string & chat_template) {
    // "<function=" distinguishes this variant from the plain v3 template
    // (which uses ">>>all") and from the native llama 3.1 template (which
    // uses "<|python_tag|>").
    return chat_template.find("<|start_header_id|>") != std::string::npos
        && chat_template.find("<function=") != std::string::npos;
}
26+
2027static bool needs_llama_3_1_tool_call (const std::string & chat_template) {
2128 return chat_template.find (" <|start_header_id|>" ) != std::string::npos
2229 && chat_template.find (" <|python_tag|>" ) != std::string::npos;
@@ -148,8 +155,42 @@ static llama_tool_calls parse_llama_3_1_tool_calls(const json & tools, const std
148155 return {input, {}};
149156}
150157
158+ static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls (const std::string& input) {
159+ static std::regex function_regex (R"( <function=(\w+)>)" );
160+ static std::regex close_regex (R"( </function>)" );
161+ std::smatch match;
151162
152- static llama_tool_calls parse_functionary_3_2_tool_calls (const std::string& input) {
163+ llama_tool_calls result;
164+ auto end = input.end ();
165+ auto it = input.begin ();
166+
167+ while (it != end) {
168+ std::sregex_iterator rend;
169+ std::sregex_iterator rit (it, end, function_regex);
170+ if (rit == rend) {
171+ result.content += std::string (it, end);
172+ break ;
173+ }
174+
175+ result.content += std::string (it, rit->prefix ().second );
176+ it = rit->suffix ().first ;
177+
178+ auto name = rit->str (1 );
179+
180+ json arguments;
181+ if (!parse_json (it, end, arguments)) {
182+ throw std::runtime_error (" Failed to parse json tool call arguments" );
183+ }
184+ if (!std::regex_search (it, end, match, close_regex)) {
185+ throw std::runtime_error (" Malformed input, missing closing pattern" );
186+ }
187+ it = match.suffix ().first ;
188+ result.tool_calls .push_back ({name, arguments.dump ()});
189+ }
190+ return result;
191+ }
192+
193+ static llama_tool_calls parse_functionary_v3_tool_calls (const std::string& input) {
153194 static std::regex python_tag_regex (R"( >>>(\w+)\n((?!>>>)[\s\S\n]*))" );
154195 std::smatch match;
155196 llama_tool_calls result;
@@ -172,8 +213,10 @@ llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_t
172213 return parse_hermes_tool_calls (input);
173214 } else if (needs_llama_3_1_tool_call (chat_template)) {
174215 return parse_llama_3_1_tool_calls (tools, input);
175- } else if (needs_functionary_3_2_tool_call (chat_template)) {
176- return parse_functionary_3_2_tool_calls (input);
216+ } else if (needs_functionary_v3_tool_call (chat_template)) {
217+ return parse_functionary_v3_tool_calls (input);
218+ } else if (needs_functionary_v3_llama_3_1_tool_call (chat_template)) {
219+ return parse_functionary_v3_llama_3_1_tool_calls (input);
177220 } else {
178221 throw std::runtime_error (" Unsupported chat template for tool calls" );
179222 }
@@ -187,7 +230,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
187230{
188231 llama_tool_call_handler handler;
189232
190- if (needs_functionary_3_2_tool_call (chat_template)) {
233+ if (needs_functionary_v3_tool_call (chat_template)) {
191234 // MeetKaiFunctionary_3_2
192235 // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
193236 // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
@@ -208,6 +251,25 @@ llama_tool_call_handler llama_tool_call_handler_init(
208251 builder.add_rule (" root" , parallel_tool_calls ? " (" + tool_call + " )+" : tool_call);
209252 });
210253 // handler.parser = parse_functionary_3_2_tool_calls;
254+ } else if (needs_functionary_v3_llama_3_1_tool_call (chat_template)) {
255+ // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
256+ handler.grammar = build_grammar ([&](const llama_grammar_builder & builder) {
257+ std::vector<std::string> tool_rules;
258+ for (size_t i = 0 , n = tools.size (); i < n; i++) {
259+ auto & tool = tools[i];
260+ const auto & function = tool[" function" ];
261+ std::string name = function[" name" ];
262+ auto parameters = function[" parameters" ];
263+ auto tool_rule = builder.add_rule (name + " -call" , " \" <function=" + name + " >\" " + builder.add_schema (name + " -args" , parameters) + " \" </function>\" " );
264+ tool_rules.push_back (tool_rule);
265+ }
266+ auto tool_call = builder.add_rule (" tool_call" , join (tool_rules.begin (), tool_rules.end (), " | " )) + " space" ;
267+ builder.add_rule (" root" , parallel_tool_calls ? " (" + tool_call + " )+" : tool_call);
268+ if (allow_content) {
269+ handler.grammar_trigger_words .push_back (" <function=" );
270+ }
271+ });
272+ // handler.parser = parse_functionary_3_2_tool_calls;
211273 } else if (needs_hermes_pro_tool_call (chat_template)) {
212274 // NousResearchHermesPro_2
213275 // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
0 commit comments