Commit 7e017cf
server : add Hermes-3 tool call support
Parent: 7ea8d80

6 files changed: +240 −37 lines

common/common.cpp

Lines changed: 17 additions & 0 deletions

@@ -2253,6 +2253,10 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
     return true;
 }
 
+bool string_contains(std::string haystack, std::string needle) {
+    return haystack.find(needle) != std::string::npos;
+}
+
 //
 // Filesystem utils
 //
@@ -3186,6 +3190,19 @@ std::string llama_chat_format_example(const struct llama_model * model,
     return llama_chat_apply_template(model, tmpl, msgs, true);
 }
 
+std::string llama_get_chat_template(const struct llama_model * model) {
+    std::string template_key = "tokenizer.chat_template";
+    // call with NULL buffer to get the total size of the string
+    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
+    if (res < 0) {
+        return "";
+    } else {
+        std::vector<char> model_template(res, 0);
+        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+        return std::string(model_template.data(), model_template.size());
+    }
+}
+
 //
 // KV cache utils
 //
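
The two helpers above are reused by the server changes further down: `llama_get_chat_template` reads the `tokenizer.chat_template` key from the GGUF metadata (calling `llama_model_meta_val_str` once with a NULL buffer to size the destination, then again to copy the value, and returning an empty string when the model ships no template), and `string_contains` is a plain substring test. A minimal usage sketch, not part of the commit (`resolve_chat_template` and `looks_like_chatml` are illustrative names):

```cpp
// Sketch only: shows how the two new common.cpp helpers could be called once a
// llama_model * has been loaded elsewhere.
#include "common.h"

#include <string>

// Prefer an explicit --chat-template argument; otherwise fall back to the
// template embedded in the model (empty string when none is stored).
static std::string resolve_chat_template(const struct llama_model * model,
                                         const std::string & cli_template) {
    return cli_template.empty() ? llama_get_chat_template(model) : cli_template;
}

// string_contains() is a simple substring check, e.g. to probe the template
// for ChatML-style markers.
static bool looks_like_chatml(const struct llama_model * model) {
    return string_contains(llama_get_chat_template(model), "<|im_start|>");
}
```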

common/common.h

Lines changed: 6 additions & 0 deletions

@@ -320,6 +320,8 @@ static std::vector<T> string_split(const std::string & str, char delim) {
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
 
+bool string_contains(std::string haystack, std::string needle);
+
 //
 // Filesystem utils
 //
@@ -428,6 +430,10 @@ std::string llama_chat_format_single(const struct llama_model * model,
 std::string llama_chat_format_example(const struct llama_model * model,
                                       const std::string & tmpl);
 
+// Returns the chat template stored inside the model
+// (empty string if model does not have built-in chat template)
+std::string llama_get_chat_template(const struct llama_model * model);
+
 //
 // KV cache utils
 //

examples/server/server.cpp

Lines changed: 37 additions & 16 deletions

@@ -4,6 +4,7 @@
 #include "json-schema-to-grammar.h"
 #include "llama.h"
 #include "grammar-parser.h"
+#include "tool-call.hpp"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -157,6 +158,7 @@ struct server_slot {
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
+    enum llama_response_state response_state = LLAMA_RESPONSE_STATE_UNKNOWN;
 
     bool infill = false;
     bool embedding = false;
@@ -207,6 +209,7 @@ struct server_slot {
         infill = false;
         ga_i = 0;
         n_past_se = 0;
+        response_state = LLAMA_RESPONSE_STATE_UNKNOWN;
 
         generated_token_probs.clear();
     }
@@ -625,6 +628,7 @@ struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     std::vector<llama_lora_adapter_container> lora_adapters;
+    llama_tool_format tool_format = LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
 
     gpt_params params;
 
@@ -1217,7 +1221,13 @@ struct server_context {
             break;
         }
 
-        if (!incomplete) {
+        if (slot.response_state == LLAMA_RESPONSE_STATE_UNKNOWN) {
+            slot.response_state = check_response_state(tool_format, slot.generated_text);
+        }
+
+        // if response is tool call, we cannot stream it
+        // instead, we wait for the full response, then extract JSON
+        if (!incomplete && slot.response_state == LLAMA_RESPONSE_STATE_TEXT) {
            size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
            const std::string str_test = slot.generated_text.substr(pos);
@@ -1247,9 +1257,7 @@ struct server_context {
            if (slot.params.stream) {
                send_partial_response(slot, result);
            }
-        }
-
-        if (incomplete) {
+        } else {
            slot.has_next_token = true;
        }
 
@@ -1396,6 +1404,10 @@ struct server_context {
            {"multimodal", false}
        };
 
+        if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
+            res.data["tool_calls"] = parse_tool_response(tool_format, tkn.text_to_send);
+        }
+
        if (slot.sparams.n_probs > 0) {
            const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
            const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
@@ -1444,6 +1456,10 @@ struct server_context {
            {"timings", slot.get_formated_timings()}
        };
 
+        if (slot.response_state == LLAMA_RESPONSE_STATE_TOOL_CALL) {
+            res.data["tool_calls"] = parse_tool_response(tool_format, slot.generated_text);
+        }
+
        if (slot.sparams.n_probs > 0) {
            std::vector<completion_token_output> probs;
            if (!slot.params.stream && slot.stopped_word) {
@@ -2937,19 +2953,14 @@ int main(int argc, char ** argv) {
    };
 
    const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
-        std::string template_key = "tokenizer.chat_template", curr_tmpl;
-        int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
-        if (tlen > 0) {
-            std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-            if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-                curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
-            }
-        }
+        std::string chat_tmpl = ctx_server.params.chat_template.empty()
+            ? llama_get_chat_template(ctx_server.model)
+            : ctx_server.params.chat_template;
        json data = {
            { "system_prompt",               ctx_server.system_prompt.c_str() },
            { "default_generation_settings", ctx_server.default_generation_settings_for_props },
            { "total_slots",                 ctx_server.params.n_parallel },
-            { "chat_template",               curr_tmpl.c_str() }
+            { "chat_template",               chat_tmpl },
        };
 
        res.set_content(data.dump(), MIMETYPE_JSON);
@@ -3056,7 +3067,13 @@ int main(int argc, char ** argv) {
            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        json body = json::parse(req.body);
+
+        if (body.contains("tools") && ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
+            body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools"));
+        }
+
+        json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template);
 
        const int id_task = ctx_server.queue_tasks.get_new_id();
 
@@ -3423,11 +3440,15 @@ int main(int argc, char ** argv) {
        }
    }
 
+    // decide if we can enable tool calls
+    ctx_server.tool_format = get_tool_format(ctx_server.ctx);
+
    // print sample chat example to make it clear which template is used
    {
        LOG_INFO("chat template", {
-            {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
-            {"built_in", params.chat_template.empty()},
+            {"chat_example",       llama_chat_format_example(ctx_server.model, params.chat_template)},
+            {"built_in",           params.chat_template.empty()},
+            {"tool_call_enabled",  ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED },
        });
    }
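
The server-side flow is: detect the tool-call format once at startup with `get_tool_format` and cache it in `ctx_server.tool_format`; when a `/v1/chat/completions` request carries a `tools` array, pre-render the prompt through `format_chat_with_tool`; and when the generated text turns out to be a tool call, hold back partial streaming and attach the parsed result as a `tool_calls` field in the response. As a rough illustration, a request body that would take this path could be built with the same nlohmann JSON type the server uses; this is not code from the commit, and the `get_weather` tool and its schema are invented:

```cpp
// Illustrative request body for POST /v1/chat/completions; all values are examples.
#include "json.hpp"

using json = nlohmann::ordered_json;

json make_tool_call_request() {
    return json{
        {"messages", json::array({
            json{{"role", "user"}, {"content", "What is the weather like in Hanoi?"}}
        })},
        {"tools", json::array({
            json{
                {"type", "function"},
                {"function", json{
                    {"name", "get_weather"},
                    {"description", "Get the current weather for a city"},
                    {"parameters", json{
                        {"type", "object"},
                        {"properties", json{{"city", json{{"type", "string"}}}}},
                        {"required", json::array({"city"})}
                    }}
                }}
            }
        })},
        // optional: once a tool call is detected the server stops streaming the
        // partial text anyway and only emits the parsed tool_calls at the end
        {"stream", false}
    };
}
```

Because the `<tool_call>...</tool_call>` block is only valid JSON once it is complete, the commit classifies the response as soon as generation starts and buffers tool-call responses instead of streaming token deltas.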

examples/server/tool-call.hpp

Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@
+#pragma once
+
+#include "llama.h"
+#include "common.h"
+#include "utils.hpp"
+
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
+#include "json.hpp"
+
+#include <string>
+#include <vector>
+#include <sstream>
+
+using json = nlohmann::ordered_json;
+
+enum llama_tool_format {
+    LLAMA_TOOL_FORMAT_NOT_SUPPORTED,
+    LLAMA_TOOL_FORMAT_HERMES_3,
+};
+
+enum llama_response_state {
+    LLAMA_RESPONSE_STATE_UNKNOWN,
+    LLAMA_RESPONSE_STATE_TEXT,
+    LLAMA_RESPONSE_STATE_TOOL_CALL,
+};
+
+// get the tool call format for the loaded model
+// this function does linear search, so do not call it repeatedly
+inline enum llama_tool_format get_tool_format(const struct llama_context * ctx) {
+    auto model = llama_get_model(ctx);
+    auto has_token = [&](std::string piece) {
+        for (int i = 0; i < llama_n_vocab(model); i++) {
+            const std::string token_str = llama_token_to_piece(ctx, i, true);
+            if (token_str == piece) {
+                return true;
+            }
+        }
+        return false;
+    };
+    if (has_token("<|im_start|>") && has_token("<tool_call>")) {
+        return LLAMA_TOOL_FORMAT_HERMES_3;
+    }
+    return LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
+}
+
+inline std::string format_chat_with_tool(enum llama_tool_format format, const std::vector<json> & messages, json tools) {
+    if (!tools.is_array()) {
+        throw std::runtime_error("tools must be an array");
+    }
+    std::stringstream ss;
+    auto chat = parse_chat_messages(messages);
+    if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
+        ss << "<|im_start|>system\n\n";
+        ss << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools>\n\n";
+        for (auto tool : tools) {
+            ss << tool.dump(1, '\t') << "\n\n";
+        }
+        ss << "</tools> Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}, \"name\": {\"title\": \"Name\", \"type\": \"string\"}}, \"required\": [\"arguments\", \"name\"], \"title\": \"FunctionCall\", \"type\": \"object\"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n";
+        ss << "<tool_call>\n";
+        ss << "{\"arguments\": <args-dict>, \"name\": <function-name>}\n";
+        ss << "</tool_call><|im_end|>\n";
+        for (auto & message : chat) {
+            std::string role(message.role);
+            if (role == "system") {
+                continue; // for optimal performance, we skip user-defined system message
+            }
+            ss << "<|im_start|>" << role << "\n\n";
+            if (role == "tool") {
+                ss << "<tool_response>\n" << string_strip(message.content) << "\n</tool_response>\n";
+            } else {
+                ss << string_strip(message.content) << "<|im_end|>\n";
+            }
+        }
+        ss << "<|im_start|>assistant\n\n";
+    } else {
+        throw std::runtime_error("tool_call is not supported by this model");
+    }
+    LOG_VERBOSE("format_chat_with_tool", {{"text", ss.str()}});
+    return ss.str();
+}
+
+// check if the response is text or tool_call
+// if it is tool_call, we may have to disable streaming, because we must parse the whole JSON response
+inline enum llama_response_state check_response_state(enum llama_tool_format format, const std::string & generated_text) {
+    if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
+        return LLAMA_RESPONSE_STATE_TEXT;
+    } else if (format == LLAMA_TOOL_FORMAT_HERMES_3 && generated_text.rfind("<tool_call>", 0) == 0) {
+        return LLAMA_RESPONSE_STATE_TOOL_CALL;
+    }
+    return LLAMA_RESPONSE_STATE_TEXT;
+}
+
+// convert model's response to OAI format
+inline json parse_tool_response(enum llama_tool_format format, const std::string & generated_text) {
+    if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
+        return json{};
+    } else if (format == LLAMA_TOOL_FORMAT_HERMES_3) {
+        std::string tmp(generated_text);
+        string_replace_all(tmp, "<tool_call>", "");
+        string_replace_all(tmp, "</tool_call>", "");
+        json tool = json::parse(tmp);
+        std::vector<json> tool_calls = {json{
+            {"id", tool.at("name")},
+            {"type", "function"},
+            {"function", {
+                {"name", tool.at("name")},
+                {"arguments", tool.at("arguments").dump()}, // OAI requires this to be JSON-stringified
+            }},
+        }};
+        return tool_calls;
+    }
+    return generated_text;
+}
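
Putting the pieces of this header together, a hypothetical end-to-end wiring could look like the sketch below. It is not taken from the commit: `run_tool_chat` is an invented name and `generate` stands in for the server's actual sampling loop.

```cpp
// Hypothetical glue around the tool-call helpers; generate() is a placeholder
// for the real decoding loop.
#include "tool-call.hpp"

#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

json run_tool_chat(struct llama_context * ctx,
                   const std::vector<json> & messages,
                   const json & tools,
                   const std::function<std::string(const std::string &)> & generate) {
    // the vocab scan is linear, so detect the format once and reuse the result
    const enum llama_tool_format format = get_tool_format(ctx);
    if (format == LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
        throw std::runtime_error("model has no recognized tool-call tokens");
    }

    // render the Hermes-3 style prompt with the tool signatures injected
    const std::string prompt = format_chat_with_tool(format, messages, tools);

    // run the model (placeholder), then classify and convert the reply
    const std::string output = generate(prompt);
    if (check_response_state(format, output) == LLAMA_RESPONSE_STATE_TOOL_CALL) {
        return json{{"tool_calls", parse_tool_response(format, output)}};
    }
    return json{{"content", output}};
}
```

This mirrors the split the server uses: `get_tool_format` runs once at startup, `format_chat_with_tool` runs per request, and `check_response_state` / `parse_tool_response` run on the generated text.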
