Skip to content

Commit 3437080

Browse files
committed
Pass template group to common_chat_apply_template
1 parent 4e8beb0 commit 3437080

File tree

6 files changed

+26
-18
lines changed

6 files changed

+26
-18
lines changed

common/common.cpp

Lines changed: 12 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -1831,22 +1831,27 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
18311831
}
18321832

18331833
std::string common_chat_apply_template(
1834-
const common_chat_template & tmpl,
1834+
const common_chat_templates & tmpl,
18351835
const std::vector<common_chat_msg> & msgs,
18361836
bool add_ass,
18371837
bool use_jinja,
18381838
const common_params_tools & tools)
18391839
{
1840+
const auto & tmpl_selected =
1841+
tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
1842+
18401843
if (use_jinja) {
18411844
common_chat_inputs inputs;
18421845

18431846
auto messages = json::array();
18441847
for (const auto & msg : msgs) {
18451848
messages.push_back({{"role", msg.role}, {"content", msg.content}});
18461849
}
1850+
18471851
if (tools.tools() != nullptr) {
18481852
inputs.tools = *tools.tools();
18491853
}
1854+
18501855
auto choice = tools.choice();
18511856
if (std::holds_alternative<std::string>(choice)) {
18521857
inputs.tool_choice = std::get<std::string>(choice);
@@ -1857,9 +1862,10 @@ std::string common_chat_apply_template(
18571862
inputs.tool_choice = *choice_ptr;
18581863
}
18591864
}
1865+
18601866
inputs.messages = messages;
18611867
inputs.add_generation_prompt = add_ass;
1862-
return common_chat_params_init(tmpl, inputs).prompt;
1868+
return common_chat_params_init(tmpl_selected, inputs).prompt;
18631869
}
18641870

18651871
int alloc_size = 0;
@@ -1872,7 +1878,7 @@ std::string common_chat_apply_template(
18721878
std::vector<char> buf(alloc_size);
18731879

18741880
// run the first time to get the total output length
1875-
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
1881+
int32_t res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
18761882

18771883
// error: chat template is not supported
18781884
if (res < 0) {
@@ -1884,15 +1890,15 @@ std::string common_chat_apply_template(
18841890
// if it turns out that our buffer is too small, we resize it
18851891
if ((size_t) res > buf.size()) {
18861892
buf.resize(res);
1887-
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
1893+
res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
18881894
}
18891895

18901896
std::string formatted_chat(buf.data(), res);
18911897
return formatted_chat;
18921898
}
18931899

18941900
std::string common_chat_format_single(
1895-
const common_chat_template & tmpl,
1901+
const common_chat_templates & tmpl,
18961902
const std::vector<common_chat_msg> & past_msg,
18971903
const common_chat_msg & new_msg,
18981904
bool add_ass,
@@ -1916,7 +1922,7 @@ std::string common_chat_format_single(
19161922
return ss.str();
19171923
}
19181924

1919-
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
1925+
std::string common_chat_format_example(const common_chat_templates & tmpl, bool use_jinja) {
19201926
std::vector<common_chat_msg> msgs = {
19211927
{"system", "You are a helpful assistant", {}},
19221928
{"user", "Hello", {}},

common/common.h

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -675,15 +675,15 @@ struct common_chat_templates {
675675
// If the built-in template is not supported, we default to chatml
676676
// If the custom "tmpl" is not supported, we throw an error
677677
std::string common_chat_apply_template(
678-
const common_chat_template & tmpl,
678+
const common_chat_templates & tmpl,
679679
const std::vector<common_chat_msg> & chat,
680680
bool add_ass,
681681
bool use_jinja,
682682
const common_params_tools & tools = common_params_tools());
683683

684684
// Format single message, while taking into account the position of that message in chat history
685685
std::string common_chat_format_single(
686-
const common_chat_template & tmpl,
686+
const common_chat_templates & tmpl,
687687
const std::vector<common_chat_msg> & past_msg,
688688
const common_chat_msg & new_msg,
689689
bool add_ass,
@@ -692,7 +692,7 @@ std::string common_chat_format_single(
692692

693693
// Returns an example of formatted chat
694694
std::string common_chat_format_example(
695-
const common_chat_template & tmpl, bool use_jinja);
695+
const common_chat_templates & tmpl, bool use_jinja);
696696

697697
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
698698

examples/main/main.cpp

Lines changed: 4 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
219219
// print chat template example in conversation mode
220220
if (params.conversation_mode) {
221221
if (params.enable_chat_template) {
222-
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
222+
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates, params.use_jinja).c_str());
223223
} else {
224224
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
225225
}
@@ -268,9 +268,9 @@ int main(int argc, char ** argv) {
268268
const common_params_tools & tools = common_params_tools())
269269
{
270270
common_chat_msg new_msg{role, content, {}};
271-
auto formatted = common_chat_format_single(
272-
*chat_templates.template_default, chat_msgs, new_msg, role == "user",
273-
g_params->use_jinja, tools);
271+
272+
auto formatted = common_chat_format_single(chat_templates, chat_msgs,
273+
new_msg, role == "user", g_params->use_jinja, tools);
274274

275275
chat_msgs.push_back({role, content, {}});
276276
LOG_DBG("formatted: '%s'\n", formatted.c_str());

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -4468,7 +4468,7 @@ int main(int argc, char ** argv) {
44684468
// print sample chat example to make it clear which template is used
44694469
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
44704470
ctx_server.chat_templates.template_default->source().c_str(),
4471-
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
4471+
common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str());
44724472

44734473
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
44744474
ctx_server.process_single_task(task);

examples/server/utils.hpp

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -348,7 +348,7 @@ static llama_tokens format_infill(
348348
}
349349

350350
// Format given chat. If tmpl is empty, we take the template from model metadata
351-
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
351+
inline std::string format_chat(const common_chat_templates & tmpl, const std::vector<json> & messages) {
352352
std::vector<common_chat_msg> chat;
353353

354354
for (size_t i = 0; i < messages.size(); ++i) {
@@ -663,7 +663,7 @@ static json oaicompat_completion_params_parse(
663663
llama_params["stop"].push_back(stop);
664664
}
665665
} else {
666-
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
666+
llama_params["prompt"] = format_chat(chat_templates, body.at("messages"));
667667
}
668668

669669
// Handle "n" field

tests/test-chat-template.cpp

Lines changed: 4 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -339,7 +339,8 @@ int main(void) {
339339
common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
340340

341341
auto fmt_sys = [&](std::string tmpl_str) {
342-
minja::chat_template tmpl(tmpl_str, "", "");
342+
common_chat_templates tmpl;
343+
tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
343344
auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
344345
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
345346
printf("-------------------------\n");
@@ -366,7 +367,8 @@ int main(void) {
366367
common_chat_msg new_msg{"user", "How are you", {}};
367368

368369
auto fmt_single = [&](std::string tmpl_str) {
369-
minja::chat_template tmpl(tmpl_str, "", "");
370+
common_chat_templates tmpl;
371+
tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
370372
auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
371373
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
372374
printf("-------------------------\n");

0 commit comments

Comments (0)