
Commit cf7bece

Author: ochafik
tool-call: factor chat template away from legacy API
Parent: d7ec84f

File tree

15 files changed: +428 / -396 lines


Makefile

Lines changed: 4 additions & 0 deletions
@@ -934,6 +934,7 @@ OBJ_LLAMA = \
 
 OBJ_COMMON = \
     common/common.o \
+    common/chat-template.o \
     common/arg.o \
     common/log.o \
     common/console.o \
@@ -1170,6 +1171,8 @@ $(LIB_LLAMA_S): \
 common/common.o: \
     common/common.cpp \
     common/common.h \
+    common/chat-template.cpp \
+    common/chat-template.h \
     common/console.h \
     common/sampling.h \
     common/json.hpp \
@@ -1465,6 +1468,7 @@ llama-server: \
     examples/server/prompt-formats.js.hpp \
     examples/server/json-schema-to-grammar.mjs.hpp \
     examples/server/loading.html.hpp \
+    common/chat-template.h \
     common/json.hpp \
     common/stb_image.h \
     $(OBJ_ALL)

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,8 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
+    chat-template.cpp
+    chat-template.h
     common.cpp
     common.h
     console.cpp

common/chat-template.cpp

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@ (new file)

#include "chat-template.h"
#include "minja.hpp"
#include "llama.h"

using json = nlohmann::ordered_json;

static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
    std::string piece;
    piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
    if (n_chars < 0) {
        piece.resize(-n_chars);
        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
        GGML_ASSERT(check == -n_chars);
    }
    else {
        piece.resize(n_chars);
    }

    return piece;
}

static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
    if (tlen > 0) {
        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
            return std::string(curr_tmpl_buf.data(), tlen);
        }
    }
    return "";
}

llama_chat_template llama_chat_template::from_model(
    const struct llama_model * model,
    const std::string & chat_template_override)
{
    // TODO: handle "chatml"?
    auto chat_template = chat_template_override.empty()
        ? llama_model_meta_val_str(model, "tokenizer.chat_template")
        : chat_template_override;
    auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
    auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
    return llama_chat_template(chat_template, bos_token, eos_token);
}

std::string llama_chat_template::apply(
    const json & messages,
    const json & tools,
    bool add_generation_prompt) const
{
    auto actual_messages = messages;

    // First, "fix" messages so they have a chance to be rendered correctly by the template

    if (_requires_object_arguments || !_supports_system_role) {
        std::string pending_system;
        auto flush_sys = [&]() {
            if (!pending_system.empty()) {
                actual_messages.push_back({
                    {"role", "user"},
                    {"content", pending_system},
                });
                pending_system.clear();
            }
        };
        for (auto & message : actual_messages) {
            if (!message.contains("role") || !message.contains("content")) {
                throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
            }
            std::string role = message.at("role");
            std::string content = message.at("content");

            if (!_supports_system_role) {
                if (role == "system") {
                    if (!pending_system.empty()) pending_system += "\n";
                    pending_system += content;
                    continue;
                } else {
                    if (role == "user") {
                        if (!pending_system.empty()) {
                            message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
                            pending_system.clear();
                        }
                    } else {
                        flush_sys();
                    }
                }
            }
            if (_requires_object_arguments && message.contains("tool_calls")) {
                for (auto & tool_call : message.at("tool_calls")) {
                    std::string arguments = tool_call.at("arguments");
                    tool_call["arguments"] = json::parse(arguments);
                }
            }
        }
        flush_sys();
    }

    auto context = minja::Context::make(json({
        {"messages", actual_messages},
        {"add_generation_prompt", add_generation_prompt},
        {"bos_token", _bos_token},
        {"eos_token", _eos_token},
    }));

    if (!tools.is_null() && !tools.empty()) {
        auto tools_val = minja::Value(tools);
        context->set("tools", tools_val);
    }

    auto tmpl_root = minja::Parser::parse(_chat_template, {
        /* .trim_blocks = */ true,
        /* .lstrip_blocks = */ true,
        /* .keep_trailing_newline = */ false,
    });
    return tmpl_root->render(context);
}
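
Illustrative usage sketch (not part of the commit): it constructs a llama_chat_template directly from a toy ChatML-style Jinja string instead of from_model(), then renders a short conversation. The template text, tokens, and main() here are hypothetical; real templates come from the model's tokenizer.chat_template metadata, and the snippet assumes it is linked against the common library (minja.hpp, json.hpp).

#include "chat-template.h"

#include <cstdio>

int main() {
    // Toy ChatML-style template (hypothetical), small enough for minja's for/if subset;
    // real models ship their own tokenizer.chat_template.
    std::string tmpl =
        "{% for message in messages %}"
        "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";

    llama_chat_template chat_template(tmpl, /* bos_token= */ "<s>", /* eos_token= */ "</s>");

    json messages = json::array({
        {{"role", "system"}, {"content", "You are helpful."}},
        {{"role", "user"},   {"content", "Hello!"}},
    });

    // No tool definitions: pass a null json, as llama_chat_verify_template does below.
    std::string prompt = chat_template.apply(messages, json(), /* add_generation_prompt= */ true);
    printf("%s", prompt.c_str());
    return 0;
}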

common/chat-template.h

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@ (new file)

#pragma once

#include <json.hpp>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

enum llama_tool_call_style {
    Unknown,
    Llama31,
    FunctionaryV3Llama3,
    FunctionaryV3Llama31,
    Hermes2Pro,
};

class llama_chat_template {
  public:

  private:
    llama_tool_call_style _tool_call_style = Unknown;
    bool _supports_tools = true;
    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
    bool _requires_object_arguments = false;
    bool _supports_system_role = true;
    std::string _chat_template;
    std::string _bos_token;
    std::string _eos_token;

  public:
    llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
        : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {

        _supports_tools = chat_template.find("tools") != std::string::npos;
        _requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos;
        _supports_system_role = chat_template.find("System role not supported") == std::string::npos;

        if (chat_template.find("<tool_call>") != std::string::npos) {
            _tool_call_style = Hermes2Pro;
        } else if (chat_template.find(">>>all") != std::string::npos) {
            _tool_call_style = FunctionaryV3Llama3;
        } else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
            if (chat_template.find("<function=") != std::string::npos) {
                _tool_call_style = FunctionaryV3Llama31;
            } else if (chat_template.find("<|python_tag|>") != std::string::npos) {
                _tool_call_style = Llama31;
            }
        }
    }

    static llama_chat_template from_model(
        const struct llama_model * model,
        const std::string & chat_template_override);

    llama_tool_call_style tool_call_style() const { return _tool_call_style; }

    const std::string & chat_template() const { return _chat_template; }
    bool supports_tools() const { return _supports_tools; }

    std::string apply(
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools,
        bool add_generation_prompt) const;
};
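
A sketch of how a caller might use the detected style (hypothetical helper, not from this commit; it assumes `model` is an already-loaded llama_model * and that the caller parses tool calls itself):

#include "chat-template.h"
#include "llama.h"

// Hypothetical helper: build a prompt for a loaded model, letting the template
// drive both the rendering and the expected tool-call output format.
static std::string build_prompt(const struct llama_model * model,
                                const json & messages,
                                const json & tools) {
    auto tmpl = llama_chat_template::from_model(model, /* chat_template_override= */ "");

    switch (tmpl.tool_call_style()) {
        case Llama31:              /* expect <|python_tag|>-style calls         */ break;
        case Hermes2Pro:           /* expect <tool_call>...</tool_call> blocks  */ break;
        case FunctionaryV3Llama3:  /* expect >>>function_name sections          */ break;
        case FunctionaryV3Llama31: /* expect <function=name>...</function>      */ break;
        case Unknown:              /* no tool-call parsing                      */ break;
    }

    // Only pass tools if the template actually references them.
    return tmpl.apply(messages, tmpl.supports_tools() ? tools : json(), /* add_generation_prompt= */ true);
}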

common/common.cpp

Lines changed: 23 additions & 20 deletions
@@ -9,6 +9,7 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "chat-template.h"
 
 #include <algorithm>
 #include <cinttypes>
@@ -1511,6 +1512,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
 //
 
 bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = llama_chat_template(tmpl, "<s>", "</s>");
+            chat_template.apply({{
+                {"role", "user"},
+                {"content", "test"},
+            }}, json(), true);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(
         nullptr,
@@ -1519,22 +1534,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
         1,
         /* add_ass= */ true,
         /* buffer= */ nullptr,
-        /* length= */ 0,
-        use_jinja,
-        /* tools= */ nullptr,
-        "<s>",
-        "</s>");
+        /* length= */ 0);
     return res >= 0;
 }
 
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & msgs,
-        bool add_ass,
-        bool use_jinja,
-        const char * tools,
-        const char * bos_token,
-        const char * eos_token) {
+        bool add_ass) {
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
@@ -1547,7 +1554,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1557,7 +1564,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         throw std::runtime_error("this custom template is not supported");
     } else {
         // If the built-in template is not supported, we default to chatml
-        res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
+        res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
         fallback = true;
     }
 }
@@ -1568,7 +1575,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     res = llama_chat_apply_template(
         fallback ? nullptr : model,
         fallback ? "chatml" : ptr_tmpl,
-        chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
+        chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 }
 
 std::string formatted_chat(buf.data(), res);
@@ -1579,21 +1586,17 @@ std::string llama_chat_format_single(const struct llama_model * model,
     const std::string & tmpl,
     const std::vector<llama_chat_msg> & past_msg,
     const llama_chat_msg & new_msg,
-    bool add_ass,
-    bool use_jinja,
-    const char * tools,
-    const char * bos_token,
-    const char * eos_token) {
+    bool add_ass) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token);
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
         ss << "\n";
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token);
+    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
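
For reference, a minimal sketch of how a caller might validate a --chat-template argument with the new two-path signature (hypothetical snippet, not part of this commit; the "empty template" convention is assumed here):

#include "common.h"

#include <cstdio>

// Hypothetical validation helper: with use_jinja the template is parsed and
// test-rendered through llama_chat_template/minja; otherwise it goes through
// the legacy llama_chat_apply_template C API.
static bool check_chat_template_arg(const std::string & chat_template, bool use_jinja) {
    if (chat_template.empty()) {
        return true; // assumed: empty means "use the model's built-in template"
    }
    if (!llama_chat_verify_template(chat_template, use_jinja)) {
        fprintf(stderr, "unsupported chat template (use_jinja=%d)\n", use_jinja ? 1 : 0);
        return false;
    }
    return true;
}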

common/common.h

Lines changed: 4 additions & 19 deletions
@@ -471,44 +471,29 @@ std::string llama_detokenize(
 // Chat template utils
 //
 
-struct llama_chat_msg_tool_call {
-    std::string name;
-    std::string arguments;
-};
-
 // same as llama_chat_message, but uses std::string and std::vector
 struct llama_chat_msg {
     std::string role;
     std::string content;
-    std::string tool;
-    std::vector<struct llama_chat_msg_tool_call> tool_calls;
 };
 
-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false);
+// Check if the template is supported or not. Returns true if it's valid
+bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja);
 
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & chat,
-        bool add_ass,
-        bool use_jinja = false,
-        const char * tools = nullptr,
-        const char * bos_token = nullptr,
-        const char * eos_token = nullptr);
+        bool add_ass);
 
 // Format single message, while taking into account the position of that message in chat history
 std::string llama_chat_format_single(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
-        bool add_ass,
-        bool use_jinja = false,
-        const char * tools = nullptr,
-        const char * bos_token = nullptr,
-        const char * eos_token = nullptr);
+        bool add_ass);
 
 // Returns an example of formatted chat
 std::string llama_chat_format_example(const struct llama_model * model,
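
And a sketch of a caller on the slimmed-down legacy path, which now only formats role/content pairs (hypothetical helper; passing an empty tmpl to select the model's built-in template is an assumption, not shown in this diff):

#include "common.h"

// Hypothetical helper: format only the newest user turn, given the prior
// history, using the simplified wrapper signatures from this commit.
static std::string format_user_turn(const struct llama_model * model,
                                    const std::vector<llama_chat_msg> & history,
                                    const std::string & user_text) {
    llama_chat_msg new_msg = { /* role= */ "user", /* content= */ user_text };
    return llama_chat_format_single(model, /* tmpl= */ "", history, new_msg, /* add_ass= */ true);
}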
