Commit e983c9d

Author: ochafik

tool-call: fix llama_chat_apply_template signature / test-chat-template

1 parent 97d0620

4 files changed: +16 −13 lines

common/common.cpp

Lines changed: 7 additions & 7 deletions
@@ -1521,7 +1521,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::vector<llama_chat_msg> & msgs,
         bool add_ass,
         bool use_jinja,
-        const std::string & tools,
+        const char * tools,
         const char * bos_token,
         const char * eos_token) {
     int alloc_size = 0;
@@ -1536,7 +1536,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     std::vector<char> buf(alloc_size);

     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools.empty() ? nullptr : tools.data(), bos_token, eos_token);
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);

     // error: chat template is not supported
     if (res < 0) {
@@ -1546,7 +1546,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
             throw std::runtime_error("this custom template is not supported");
         } else {
             // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
             fallback = true;
         }
     }
@@ -1557,7 +1557,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(
             fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
-            chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, bos_token, eos_token);
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
     }

     std::string formatted_chat(buf.data(), res);
@@ -1570,19 +1570,19 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const llama_chat_msg & new_msg,
         bool add_ass,
         bool use_jinja,
-        const std::string & tools,
+        const char * tools,
         const char * bos_token,
         const char * eos_token) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, bos_token, eos_token);
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token);
     std::vector<llama_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
         ss << "\n";
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, bos_token, eos_token);
+    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();

common/common.h

Lines changed: 2 additions & 2 deletions
@@ -493,7 +493,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::vector<llama_chat_msg> & chat,
         bool add_ass,
         bool use_jinja = false,
-        const std::string & tools = "",
+        const char * tools = nullptr,
         const char * bos_token = nullptr,
         const char * eos_token = nullptr);

@@ -504,7 +504,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const llama_chat_msg & new_msg,
         bool add_ass,
         bool use_jinja = false,
-        const std::string & tools = "",
+        const char * tools = nullptr,
         const char * bos_token = nullptr,
         const char * eos_token = nullptr);
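
For callers of the two declarations above, the change means tools is now passed as a JSON C string, or nullptr when no tools are offered. A minimal caller-side sketch, assuming a loaded model and a tmpl string obtained elsewhere; the messages and the tools payload below are made up for illustration:

    std::vector<llama_chat_msg> chat {
        {"system", "You are a helpful assistant"},
        {"user",   "What is the weather in Paris?"},
    };

    // Hypothetical tools payload, already serialized to a JSON C string.
    const char * tools_json = "[{\"type\": \"function\", \"function\": {\"name\": \"get_weather\"}}]";

    // With tools (Jinja templates can render them), and without any tools.
    auto with_tools    = llama_chat_apply_template(model, tmpl, chat, /* add_ass= */ true, /* use_jinja= */ true,  tools_json);
    auto without_tools = llama_chat_apply_template(model, tmpl, chat, /* add_ass= */ true, /* use_jinja= */ false, /* tools= */ nullptr);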

examples/server/utils.hpp

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.emplace_back(std::move(msg));
     }

-    const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? "" : tools.dump());
+    const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? nullptr : tools.dump().c_str());
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());

     return formatted_chat;
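
One detail worth noting about the call above: tools.dump() yields a temporary std::string, and a temporary is destroyed only at the end of the full expression, so the pointer returned by .c_str() stays valid for the duration of the llama_chat_apply_template call. If the callee needed the tools string after returning, it would have to copy the contents rather than hold on to the pointer.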

tests/test-chat-template.cpp

Lines changed: 6 additions & 3 deletions
@@ -27,6 +27,8 @@ int main(void) {
         {"user", "Another question"},
     };

+    std::string tools = "";
+
     std::vector<test_template> templates {
         {
             .name = "teknium/OpenHermes-2.5-Mistral-7B",
@@ -160,7 +162,7 @@ int main(void) {
     int32_t res;

     // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size(), false, "<|im_start|>", "<|im_end|>");
+    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size(), false, /* tools= */ nullptr, "<|im_start|>", "<|im_end|>");
     assert(res < 0);

     for (auto use_jinja : std::vector<bool> { false, true }) {
@@ -182,6 +184,7 @@ int main(void) {
             formatted_chat.data(),
             formatted_chat.size(),
             use_jinja,
+            tools.empty() ? nullptr : tools.c_str(),
             tmpl.bos.c_str(),
             tmpl.eos.c_str()
         );
@@ -210,7 +213,7 @@ int main(void) {
     llama_chat_msg sys_msg{"system", "You are a helpful assistant"};

     auto fmt_sys = [&](std::string tmpl) {
-        auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false, false, "<|im_start|>", "<|im_end|>");
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false, false, /** tools= */ "", "<|im_start|>", "<|im_end|>");
         printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
@@ -229,7 +232,7 @@ int main(void) {
     llama_chat_msg new_msg{"user", "How are you"};

     auto fmt_single = [&](std::string tmpl) {
-        auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true, false, "<|im_start|>", "<|im_end|>");
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true, false, /* tools= */ nullptr, "<|im_start|>", "<|im_end|>");
         printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
