Commit d3830ad

Tool calls support improvements (support null content in messages, correctly handle previous messages that contain tool calls, improvements to the prompt format)
1 parent d25cd7f commit d3830ad

3 files changed, +55 -13 lines changed


common/common.h

Lines changed: 4 additions & 0 deletions
@@ -407,6 +407,10 @@ std::string llama_detokenize(
 struct llama_chat_msg {
     std::string role;
     std::string content;
+    std::string tool_calls;
+
+    llama_chat_msg(std::string role, std::string content, std::string tool_calls = "")
+        : role(role), content(content), tool_calls(tool_calls) {}
 };
 
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
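For context, a minimal sketch (not part of the commit) of how the extended struct can be populated when replaying a conversation; the tool name and values are illustrative:

#include "common.h"
#include <vector>

// Illustrative only: rebuilding a chat where a previous assistant turn
// answered with a tool call instead of text.
static std::vector<llama_chat_msg> make_example_chat() {
    std::vector<llama_chat_msg> chat;
    // The two-argument form still works; tool_calls defaults to "".
    chat.push_back(llama_chat_msg("user", "What is the weather in Paris?"));
    // Assistant turn whose content was empty: the serialized call goes
    // into the new tool_calls field.
    chat.push_back(llama_chat_msg("assistant", "",
        "{\"arguments\": {\"city\": \"Paris\"}, \"name\": \"get_weather\"}"));
    // The tool result fed back to the model on the following turn.
    chat.push_back(llama_chat_msg("tool", "{\"temperature\": 21}"));
    return chat;
}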

examples/server/tool-call.hpp

Lines changed: 15 additions & 9 deletions
@@ -51,11 +51,9 @@ inline std::string format_chat_with_tool(enum llama_tool_format format, const st
     std::stringstream ss;
     auto chat = parse_chat_messages(messages);
     if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
-        ss << "<|im_start|>system\n\n";
-        ss << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools>\n\n";
-        for (auto tool : tools) {
-            ss << tool.dump(1, '\t') << "\n\n";
-        }
+        ss << "<|im_start|>system\n";
+        ss << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> ";
+        ss << tools.dump() << " ";
         ss << "</tools> Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}, \"name\": {\"title\": \"Name\", \"type\": \"string\"}}, \"required\": [\"arguments\", \"name\"], \"title\": \"FunctionCall\", \"type\": \"object\"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n";
         ss << "<tool_call>\n";
         ss << "{\"arguments\": <args-dict>, \"name\": <function-name>}\n";
@@ -65,14 +63,22 @@ inline std::string format_chat_with_tool(enum llama_tool_format format, const st
             if (role == "system") {
                 continue; // for optimal performance, we skip user-defined system message
             }
-            ss << "<|im_start|>" << role << "\n\n";
+            ss << "<|im_start|>" << role << "\n";
             if (role == "tool") {
                 ss << "<tool_response>\n" << string_strip(message.content) << "\n</tool_response>\n";
-            } else {
-                ss << string_strip(message.content) << "<|im_end|>\n";
             }
+            else if (role == "user") {
+                ss << string_strip(message.content);
+            }
+            else if (role == "assistant") {
+                ss << string_strip(message.content);
+                if (!message.tool_calls.empty()) {
+                    ss << "<tool_call>\n" << string_strip(message.tool_calls) << "\n</tool_call>";
+                }
+            }
+            ss << "<|im_end|>\n";
         }
-        ss << "<|im_start|>assistant\n\n";
+        ss << "<|im_start|>assistant\n";
     } else {
         throw std::runtime_error("tool_call is not supported by this model");
     }
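To make the template change concrete: assuming the example chat from above (a user question, an assistant turn with empty content and a single hypothetical get_weather tool call, then a tool result), the message loop now renders roughly the following, preceded by the system block from the first hunk:

<|im_start|>user
What is the weather in Paris?<|im_end|>
<|im_start|>assistant
<tool_call>
{"arguments": {"city": "Paris"}, "name": "get_weather"}
</tool_call><|im_end|>
<|im_start|>tool
<tool_response>
{"temperature": 21}
</tool_response>
<|im_end|>
<|im_start|>assistant

Note the doubled newlines after the <|im_start|> headers are gone, and every turn, including tool responses, is now closed with <|im_end|>.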

examples/server/utils.hpp

Lines changed: 36 additions & 4 deletions
@@ -124,24 +124,56 @@ inline std::vector<llama_chat_msg> parse_chat_messages(const std::vector<json> &
 
         std::string role = json_value(curr_msg, "role", std::string(""));
 
+        bool content_found = false;
+        bool valid_content_found = false;
         std::string content;
         if (curr_msg.contains("content")) {
+            content_found = true;
             if (curr_msg["content"].is_string()) {
+                valid_content_found = true;
                 content = curr_msg["content"].get<std::string>();
             } else if (curr_msg["content"].is_array()) {
+                valid_content_found = true;
                 for (const auto & part : curr_msg["content"]) {
                     if (part.contains("text")) {
                         content += "\n" + part["text"].get<std::string>();
                     }
                 }
-            } else {
-                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
             }
-        } else {
+        }
+
+        std::string tool_calls;
+        if (role == "assistant") {
+            const std::string tool_calls_field_name = "tool_calls";
+            std::string tool_calls_parsed_text;
+            if (curr_msg.contains(tool_calls_field_name)) {
+                content_found = true;
+                if (curr_msg[tool_calls_field_name].is_array()) {
+                    std::size_t tool_calls_count = curr_msg[tool_calls_field_name].size();
+                    if (tool_calls_count > 1) {
+                        throw std::runtime_error("Parallel tool calls are not supported yet");
+                    }
+                    else if (tool_calls_count == 1) {
+                        valid_content_found = true;
+                        json tool_call_function(curr_msg[tool_calls_field_name][0]["function"]);
+                        if (tool_call_function["arguments"].is_string()) {
+                            tool_calls = "{\"arguments\": " + tool_call_function["arguments"].get<std::string>() +
+                                ", \"name\": \"" + tool_call_function["name"].get<std::string>() + "\"}";
+                        }
+                        else tool_calls = tool_call_function["function"].dump();
+                    }
+                }
+            }
+        }
+
+        if (!content_found) {
             throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
         }
+        else if (!valid_content_found) {
+            throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
+        }
 
-        chat.push_back({role, content});
+        chat.push_back({role, content, tool_calls});
     }
     return chat;
 }
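Concretely, the parser now accepts an OpenAI-style assistant message like this one (values illustrative), where content is null and the call sits under tool_calls:

{
    "role": "assistant",
    "content": null,
    "tool_calls": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\": \"Paris\"}"
            }
        }
    ]
}

Because "content" is present but neither a string nor an array, that branch only sets content_found; the single tool call then marks the message valid and is re-serialized as {"arguments": {"city": "Paris"}, "name": "get_weather"}, the shape the Hermes-3 template wraps in <tool_call> tags. Arrays with more than one entry still throw "Parallel tool calls are not supported yet".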
