Commit 3ebdb2b

Author: ochafik (committed)
tool-call: support tool_use variant in llama_chat_template_from_model + drop llama_get_chat_template
Parent: 92c384a

File tree: 4 files changed, +27 −22 lines

  common/common.cpp
  common/common.h
  examples/server/server.cpp
  examples/server/utils.hpp

common/common.cpp

Lines changed: 13 additions & 4 deletions
@@ -1719,12 +1719,21 @@ static std::string _llama_model_meta_val_str(const struct llama_model * model, c
 
 minja::chat_template llama_chat_template_from_model(
         const struct llama_model * model,
-        const char * chat_template_override)
+        const std::string & chat_template_override,
+        bool prefer_tool_use)
 {
     // TODO: handle "chatml"?
-    std::string chat_template = chat_template_override
-        ? chat_template_override
-        : _llama_model_meta_val_str(model, "tokenizer.chat_template");
+    std::string chat_template = chat_template_override;
+    if (chat_template.empty()) {
+        if (prefer_tool_use) {
+            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+            fprintf(stderr, "# tokenizer.chat_template.tool_use: %s\n", chat_template.c_str());
+        }
+        if (chat_template.empty()) {
+            chat_template = _llama_model_meta_val_str(model, "tokenizer.chat_template");
+            fprintf(stderr, "# tokenizer.chat_template: %s\n", chat_template.c_str());
+        }
+    }
     auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true);
     auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true);
     return {std::move(chat_template), bos_token, eos_token};
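
For context, a minimal usage sketch of the new signature (not part of the commit; model and params are hypothetical stand-ins for a loaded llama_model pointer and the usual common params):

    // Sketch only: assumes `model` is a loaded llama_model * and params.chat_template
    // holds an optional override (empty string means "read the template from GGUF metadata").
    auto tmpl = llama_chat_template_from_model(model, params.chat_template, /* prefer_tool_use= */ true);
    // With prefer_tool_use=true this picks tokenizer.chat_template.tool_use when the model
    // ships one, and falls back to tokenizer.chat_template otherwise.
    fprintf(stderr, "%s\n", tmpl.source().c_str());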

common/common.h

Lines changed: 2 additions & 1 deletion
@@ -529,7 +529,8 @@ std::string common_chat_format_example(const struct llama_model * model,
 
 minja::chat_template llama_chat_template_from_model(
         const struct llama_model * model,
-        const char * chat_template_override = nullptr);
+        const std::string & chat_template_override = "",
+        bool prefer_tool_use = false);
 
 //
 // KV cache utils

examples/server/server.cpp

Lines changed: 12 additions & 4 deletions
@@ -2923,13 +2923,20 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+        auto chat_template = llama_chat_template_from_model(ctx_server.model, ctx_server.params.chat_template, /* prefer_tool_use= */ false);
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params.n_parallel },
             { "bos_token",                   common_token_to_piece(ctx_server.ctx, llama_token_bos(ctx_server.model), true) },
             { "eos_token",                   common_token_to_piece(ctx_server.ctx, llama_token_eos(ctx_server.model), true) },
-            { "chat_template",               llama_get_chat_template(ctx_server.model) },
+            { "chat_template",               chat_template.source()},
         };
+        if (ctx_server.params.use_jinja) {
+            auto tool_use_chat_template = llama_chat_template_from_model(ctx_server.model, ctx_server.params.chat_template, /* prefer_tool_use= */ true);
+            if (tool_use_chat_template.source() != chat_template.source()) {
+                data["chat_template_tool_use"] = tool_use_chat_template.source();
+            }
+        }
 
         res_ok(res, data);
     };
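
Illustratively, with --jinja enabled and a model whose GGUF metadata carries a distinct tool-use template, the /props payload would now include both fields; the values below are placeholders, not output from the commit:

    {
      "default_generation_settings": { ... },
      "total_slots": 1,
      "bos_token": "<s>",
      "eos_token": "</s>",
      "chat_template": "... contents of tokenizer.chat_template ...",
      "chat_template_tool_use": "... contents of tokenizer.chat_template.tool_use ..."
    }
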
@@ -3030,13 +3037,14 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        static auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
-        static auto tool_call_style = llama_tool_call_style_detect(chat_template);
+        auto body = json::parse(req.body);
+        auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template, /* prefer_tool_use= */ body.contains("tools"));
+        auto tool_call_style = llama_tool_call_style_detect(chat_template);
         LOG_INF("Tool call style: %s\n", llama_tool_call_style_name(tool_call_style).c_str());
 
         json data;
         try {
-            data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, tool_call_style, params.use_jinja);
+            data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, tool_call_style, params.use_jinja);
         } catch (const std::exception & e) {
             res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
             return;

examples/server/utils.hpp

Lines changed: 0 additions & 13 deletions
@@ -93,19 +93,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     return formatted_chat;
 }
 
-static std::string llama_get_chat_template(const struct llama_model * model) {
-    std::string template_key = "tokenizer.chat_template";
-    // call with NULL buffer to get the total size of the string
-    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res < 0) {
-        return "";
-    } else {
-        std::vector<char> model_template(res, 0);
-        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size());
-    }
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
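
Call sites that previously used this helper (such as the /props handler above) now go through llama_chat_template_from_model instead; a hedged sketch of an equivalent of the old call, assuming model is a loaded llama_model pointer:

    // Rough stand-in for the removed llama_get_chat_template(model): an empty
    // override and prefer_tool_use=false reproduce the old "read
    // tokenizer.chat_template from the model metadata" behaviour.
    std::string tmpl_src = llama_chat_template_from_model(model, "", /* prefer_tool_use= */ false).source();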
