Skip to content

Commit 3cfc21e

Browse files
author
ochafik
committed
tool-call: basic Functionary 3.2, Llama 3.1, Hermes 2 Pro grammar generators + parsers
1 parent 26c175b commit 3cfc21e

File tree

6 files changed

+443
-1
lines changed

6 files changed

+443
-1
lines changed

Makefile

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ TEST_TARGETS = \
5555
tests/test-grammar-parser \
5656
tests/test-json-schema-to-grammar \
5757
tests/test-minja \
58+
tests/test-tool-call \
5859
tests/test-llama-grammar \
5960
tests/test-log \
6061
tests/test-model-load-cancel \
@@ -940,7 +941,8 @@ OBJ_COMMON = \
940941
common/sampling.o \
941942
common/train.o \
942943
common/build-info.o \
943-
common/json-schema-to-grammar.o
944+
common/json-schema-to-grammar.o \
945+
common/tool-call.o
944946

945947
OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON)
946948

@@ -1201,6 +1203,11 @@ common/json-schema-to-grammar.o: \
12011203
common/json-schema-to-grammar.h
12021204
$(CXX) $(CXXFLAGS) -c $< -o $@
12031205

1206+
common/tool-call.o: \
1207+
common/tool-call.cpp \
1208+
common/tool-call.h
1209+
$(CXX) $(CXXFLAGS) -c $< -o $@
1210+
12041211
common/train.o: \
12051212
common/train.cpp \
12061213
common/train.h
@@ -1574,6 +1581,11 @@ tests/test-antiprompts: tests/test-antiprompts.cpp \
15741581
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
15751582
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
15761583

1584+
tests/test-tool-call: tests/test-tool-call.cpp \
1585+
$(OBJ_ALL)
1586+
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
1587+
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
1588+
15771589
tests/test-minja: tests/test-minja.cpp \
15781590
$(OBJ_ALL)
15791591
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)

common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ add_library(${TARGET} STATIC
6767
ngram-cache.h
6868
sampling.cpp
6969
sampling.h
70+
tool-call.cpp
7071
train.cpp
7172
train.h
7273
)

common/tool-call.cpp

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
#include "tool-call.h"
2+
#include "json-schema-to-grammar.h"
3+
#include <algorithm>
4+
#include <fstream>
5+
#include <map>
6+
#include <regex>
7+
#include <sstream>
8+
#include <string>
9+
#include <unordered_map>
10+
#include <unordered_set>
11+
#include <vector>
12+
13+
using json = nlohmann::ordered_json;
14+
15+
static bool needs_functionary_3_2_tool_call(const std::string & chat_template) {
16+
return chat_template.find("<|start_header_id|>") != std::string::npos
17+
&& chat_template.find(">>>all") != std::string::npos;
18+
}
19+
20+
static bool needs_llama_3_1_tool_call(const std::string & chat_template) {
21+
return chat_template.find("<|start_header_id|>") != std::string::npos
22+
&& chat_template.find("<|python_tag|>") != std::string::npos;
23+
}
24+
25+
static bool needs_hermes_pro_tool_call(const std::string & chat_template) {
26+
return chat_template.find("<tool_call>") != std::string::npos;
27+
}
28+
29+
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
30+
// // https://json.nlohmann.me/features/parsing/sax_interface/
31+
struct json_error_locator : public nlohmann::json_sax<json> {
32+
std::size_t position;
33+
bool found_error;
34+
35+
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override {
36+
// LOG_WARNING("JSON error (Expected)", {{"position", position}, {"last_token", last_token}, {"error", ex.what()}});
37+
this->position = position - 1;
38+
this->found_error = true;
39+
return false;
40+
}
41+
bool null() override { return true; }
42+
bool boolean(bool) override { return true; }
43+
bool number_integer(number_integer_t) override { return true; }
44+
bool number_unsigned(number_unsigned_t) override { return true; }
45+
bool number_float(number_float_t, const string_t &) override { return true; }
46+
bool string(string_t &) override { return true; }
47+
bool binary(binary_t &) override { return true; }
48+
bool start_object(std::size_t) override { return true; }
49+
bool key(string_t &) override { return true; }
50+
bool end_object() override { return true; }
51+
bool start_array(std::size_t) override { return true; }
52+
bool end_array() override { return true; }
53+
};
54+
json_error_locator err_loc;
55+
json::sax_parse(it, end, &err_loc);
56+
57+
std::string::const_iterator temptative_end;
58+
if (err_loc.found_error) {
59+
temptative_end = it + err_loc.position;
60+
} else {
61+
temptative_end = end;
62+
}
63+
std::string json_sub {it, it + err_loc.position};
64+
// LOG_WARNING("Parsing json", {{"json_sub", json_sub}});
65+
try {
66+
out = json::parse(json_sub);
67+
it = temptative_end;
68+
return true;
69+
} catch (const std::exception & e) {
70+
// LOG_WARNING("Failed to parse tool call", {{"json_sub", json_sub}, {"error", e.what()}});
71+
return false;
72+
}
73+
}
74+
75+
static llama_tool_calls parse_hermes_tool_calls(const std::string& input) {
76+
try {
77+
std::regex start_pattern(R"([\n\s]*<tool_call>)");
78+
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
79+
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
80+
81+
auto end = input.end();
82+
std::sregex_iterator rend;
83+
std::sregex_iterator rit(input.begin(), end, start_pattern);
84+
if (rit == rend) {
85+
return {input, {}};
86+
}
87+
88+
llama_tool_calls result;
89+
result.content = rit->prefix();
90+
91+
auto it = rit->suffix().first;
92+
while (it != end) {
93+
json call;
94+
if (!parse_json(it, end, call)) {
95+
throw std::runtime_error("Failed to parse json tool call");
96+
}
97+
result.tool_calls.push_back({
98+
call["name"],
99+
call["arguments"].dump(),
100+
});
101+
rit = {it, end, middle_pattern};
102+
if (rit != rend) {
103+
it = rit->suffix().first;
104+
} else {
105+
rit = {it, end, end_pattern};
106+
if (rit == rend) {
107+
throw std::runtime_error("Malformed input, missing </tool_call>");
108+
}
109+
break;
110+
}
111+
}
112+
return result;
113+
} catch (const std::exception & e) {
114+
return {input, {}};
115+
}
116+
}
117+
118+
static llama_tool_calls parse_llama_3_1_tool_calls(const json & tools, const std::string& input) {
119+
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
120+
std::smatch match;
121+
if (std::regex_search(input, match, python_tag_regex)) {
122+
return {
123+
match.prefix().str(), {
124+
{"ipython", (json {{"code", match[1].str()}}).dump()},
125+
}
126+
};
127+
}
128+
try {
129+
auto call = json::parse(input);
130+
// Only treat JSON as a tool call if it has a name attribute that matches any of the tools specified in the request.
131+
// There doesn't seem to be any better way to detect a tool call.
132+
if (call.contains("name") && call["name"].is_string()) {
133+
std::string name = call["name"];
134+
for (const auto & tool : tools) {
135+
if (tool.at("function").at("name") == name) {
136+
return {
137+
"",
138+
{
139+
{name, call["parameters"].dump()},
140+
}
141+
};
142+
}
143+
}
144+
}
145+
} catch (const std::exception & e) {
146+
// Do nothing
147+
}
148+
return {input, {}};
149+
}
150+
151+
152+
static llama_tool_calls parse_functionary_3_2_tool_calls(const std::string& input) {
153+
static std::regex python_tag_regex(R"(>>>(\w+)\n((?!>>>)[\s\S\n]*))");
154+
std::smatch match;
155+
llama_tool_calls result;
156+
std::string content;
157+
std::string in = input;
158+
while (std::regex_search(in, match, python_tag_regex)) {
159+
content += match.prefix().str();
160+
result.tool_calls.push_back({
161+
match[1].str(),
162+
(json {{"code", match[2].str()}}).dump(),
163+
});
164+
in = match.suffix().str();
165+
}
166+
result.content = content + in;
167+
return result;
168+
}
169+
170+
llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_template, const std::string& input) {
171+
if (needs_hermes_pro_tool_call(chat_template)) {
172+
return parse_hermes_tool_calls(input);
173+
} else if (needs_llama_3_1_tool_call(chat_template)) {
174+
return parse_llama_3_1_tool_calls(tools, input);
175+
} else if (needs_functionary_3_2_tool_call(chat_template)) {
176+
return parse_functionary_3_2_tool_calls(input);
177+
} else {
178+
throw std::runtime_error("Unsupported chat template for tool calls");
179+
}
180+
}
181+
182+
llama_tool_call_handler llama_tool_call_handler_init(
183+
const std::string & chat_template,
184+
bool allow_content,
185+
bool parallel_tool_calls,
186+
const nlohmann::ordered_json & tools)
187+
{
188+
llama_tool_call_handler handler;
189+
190+
if (needs_functionary_3_2_tool_call(chat_template)) {
191+
// MeetKaiFunctionary_3_2
192+
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
193+
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
194+
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
195+
std::vector<std::string> tool_rules;
196+
for (size_t i = 0, n = tools.size(); i < n; i++) {
197+
auto & tool = tools[i];
198+
const auto & function = tool["function"];
199+
std::string name = function["name"];
200+
auto parameters = function["parameters"];
201+
auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters));
202+
tool_rules.push_back(tool_rule);
203+
if (allow_content) {
204+
handler.grammar_trigger_words.push_back(">>>" + name + "\n");
205+
}
206+
}
207+
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
208+
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
209+
});
210+
// handler.parser = parse_functionary_3_2_tool_calls;
211+
} else if (needs_hermes_pro_tool_call(chat_template)) {
212+
// NousResearchHermesPro_2
213+
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
214+
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
215+
std::vector<std::string> tool_rules;
216+
for (const auto & tool : tools) {
217+
const auto & function = tool["function"];
218+
std::string name = function["name"];
219+
auto parameters = function["parameters"];
220+
builder.resolve_refs(parameters);
221+
tool_rules.push_back(builder.add_schema(name + "-call", {
222+
{"type", "object"},
223+
{"properties", json {
224+
{"name", json {{"const", name}}},
225+
{"arguments", parameters},
226+
}},
227+
{"required", json::array({"name", "arguments"})},
228+
}));
229+
}
230+
231+
auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
232+
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
233+
if (allow_content) {
234+
handler.grammar_trigger_words.push_back("<tool_call>");
235+
}
236+
});
237+
} else if (needs_llama_3_1_tool_call(chat_template)) {
238+
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
239+
static std::vector<std::string> builtin_tools {"wolfram_alpha", "brave_search"};
240+
std::vector<std::string> tool_rules;
241+
242+
for (const auto & tool : tools) {
243+
const auto & function = tool["function"];
244+
std::string name = function["name"];
245+
auto parameters = function["parameters"];
246+
builder.resolve_refs(parameters);
247+
if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) {
248+
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
249+
if (allow_content) {
250+
handler.grammar_trigger_words.push_back("<|python_tag|>");
251+
}
252+
} else {
253+
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
254+
tool_rules.push_back(
255+
builder.add_rule(
256+
name + "-call",
257+
"\"\\n{\\\"name\\\": " + name + "\\\", \\\"parameters\\\", \" " +
258+
builder.add_schema(name + "-args", parameters) +
259+
" \"}\""));
260+
if (allow_content) {
261+
handler.grammar_trigger_words.push_back("\n{\"" + name + "\"");
262+
}
263+
}
264+
}
265+
266+
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
267+
});
268+
handler.additional_stop_words.push_back("<|eom_id|>");
269+
} else {
270+
// TODO: generic thoughtful schema.
271+
throw std::runtime_error("Unsupported tool call style!");
272+
}
273+
return handler;
274+
}

common/tool-call.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#pragma once
2+
3+
#include "ggml.h"
4+
// Change JSON_ASSERT from assert() to GGML_ASSERT:
5+
#define JSON_ASSERT GGML_ASSERT
6+
#include "json.hpp"
7+
8+
struct llama_tool_call {
9+
std::string name;
10+
std::string arguments;
11+
};
12+
13+
struct llama_tool_calls {
14+
std::string content;
15+
std::vector<llama_tool_call> tool_calls;
16+
};
17+
18+
struct llama_tool_call_handler {
19+
std::string grammar;
20+
std::vector<std::string> grammar_trigger_words;
21+
std::vector<std::string> additional_stop_words;
22+
};
23+
24+
llama_tool_calls parse_tool_calls(const nlohmann::ordered_json & tools, const std::string & chat_template, const std::string& input);
25+
26+
llama_tool_call_handler llama_tool_call_handler_init(
27+
const std::string & chat_template,
28+
bool allow_content,
29+
bool parallel_tool_calls,
30+
const nlohmann::ordered_json & tools);

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ llama_target_and_test(test-barrier.cpp)
124124
llama_target_and_test(test-backend-ops.cpp)
125125
llama_target_and_test(test-antiprompts.cpp)
126126
llama_target_and_test(test-minja.cpp)
127+
llama_target_and_test(test-tool-call.cpp)
127128

128129
llama_target_and_test(test-rope.cpp)
129130

0 commit comments

Comments
 (0)