
Commit a6afb27

Author: ochafik
Update common_chat_format_example to use minja template wrapper
1 parent: c04c50e

4 files changed: +17, -8 lines

common/common.cpp (11 additions, 3 deletions)

@@ -1811,15 +1811,23 @@ std::string common_chat_format_single(const struct llama_model * model,
     return ss.str();
 }
 
-std::string common_chat_format_example(const struct llama_model * model,
-                                       const std::string & tmpl) {
+std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    return common_chat_apply_template(model, tmpl, msgs, true);
+    const auto add_generation_prompt = true;
+    if (use_jinja) {
+        auto messages = json::array();
+        for (const auto & msg : msgs) {
+            messages.push_back({{"role", msg.role}, {"content", msg.content}});
+        }
+        return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt);
+    } else {
+        return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt);
+    }
 }
 
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
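A note on the new Jinja branch above: it repacks the common_chat_msg vector into the JSON array shape that minja::chat_template::apply() consumes. Below is a minimal standalone sketch of just that conversion step, assuming nlohmann::json (which llama.cpp aliases as json) and a stand-in Msg struct in place of common_chat_msg:

#include <nlohmann/json.hpp>

#include <iostream>
#include <string>
#include <vector>

using json = nlohmann::json;

// Stand-in for common_chat_msg (the real struct lives in common/common.h).
struct Msg {
    std::string role;
    std::string content;
};

int main() {
    std::vector<Msg> msgs = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };
    // Same shape the Jinja branch feeds to minja::chat_template::apply().
    auto messages = json::array();
    for (const auto & msg : msgs) {
        messages.push_back({{"role", msg.role}, {"content", msg.content}});
    }
    std::cout << messages.dump(2) << std::endl;
}

Keys print alphabetized here because plain nlohmann::json sorts object keys; llama.cpp typically aliases json to nlohmann::ordered_json, which preserves insertion order.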

common/common.h (1 addition, 1 deletion)

@@ -619,7 +619,7 @@ std::string common_chat_format_single(const struct llama_model * model,
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(const struct llama_model * model,
-                                       const std::string & tmpl);
+                                       const minja::chat_template & tmpl, bool use_jinja);
 
 
 struct llama_chat_templates {
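The hunk cuts off at the opening brace of llama_chat_templates. Judging only from this commit's call sites (chat_templates.default_template.source() and .apply(...)), a shape like the following would satisfy them; this is a hypothetical sketch, not the actual definition:

// Hypothetical: inferred from call sites in this commit only. The real
// struct in common/common.h may carry additional members (e.g. a separate
// tool-use template) and may differ in layout.
struct llama_chat_templates {
    minja::chat_template default_template;  // wraps the Jinja source; exposes source() and apply()
};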

examples/main/main.cpp (3 additions, 2 deletions)

@@ -165,6 +165,7 @@ int main(int argc, char ** argv) {
     }
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
+    auto chat_templates = llama_chat_templates_from_model(model, params.chat_template);
 
     LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
 
@@ -207,7 +208,7 @@ int main(int argc, char ** argv) {
     }
 
     // auto enable conversation mode if chat template is available
-    const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty();
+    const bool has_chat_template = !chat_templates.default_template.source().empty();
     if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
         if (has_chat_template) {
             LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@@ -225,7 +226,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, chat_templates.default_template, params.use_jinja).c_str());
         } else {
            LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
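With this wiring, conversation-mode auto-detection keys off the resolved minja template rather than raw template strings. For a model shipping a ChatML-style template, the example logged at startup would look roughly like the following (illustrative output assuming ChatML; the exact formatting depends on the model's own template):

main: chat template example:
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant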

examples/server/server.cpp (2 additions, 2 deletions)

@@ -4287,8 +4287,8 @@ int main(int argc, char ** argv) {
 
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
-        common_chat_format_example(ctx_server.model, params.chat_template).c_str());
+        get_chat_templates().default_template.source().c_str(),
+        common_chat_format_example(ctx_server.model, get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
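get_chat_templates() is defined elsewhere in server.cpp, outside this diff. One plausible implementation consistent with these call sites is a memoizing accessor over llama_chat_templates_from_model; a hypothetical sketch, not the code from this commit:

// Hypothetical: a lazy, memoizing accessor consistent with the call sites
// above. The real helper in server.cpp lies outside this diff and may differ.
#include <optional>

auto get_chat_templates = [&ctx_server, cached = std::optional<llama_chat_templates>()]() mutable -> const llama_chat_templates & {
    if (!cached) {
        cached = llama_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template);
    }
    return *cached;
};

Caching matters here because the LOG_INF call above invokes the accessor twice in a single statement; resolving the template once keeps the two uses consistent.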
