
Commit b75d062

Author: ochafik (committed)
Refactor common_chat_* functions to accept minja template + use_jinja option
1 parent 3ed670b commit b75d062

7 files changed: +82 -80 lines changed
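
In short, the common_chat_* helpers no longer take a llama_model * plus a raw template string; they now take a pre-loaded llama_chat_template (a typedef for minja::chat_template) together with an explicit use_jinja flag. A minimal caller-side sketch of the new API, based on the signatures in this diff; the format_prompt wrapper and the loaded model pointer are illustrative assumptions, not part of the commit:

// Hedged sketch of the post-refactor call pattern (not part of this commit).
// Assumes `model` is an already-loaded llama_model *.
#include "common.h"

#include <string>
#include <vector>

static std::string format_prompt(const struct llama_model * model, bool use_jinja) {
    // Load the template(s) from model metadata; "" means no --chat-template override.
    llama_chat_templates templates = llama_chat_templates_from_model(model, "");

    std::vector<common_chat_msg> msgs = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };

    // The llama_model * argument is gone; the template object and the use_jinja
    // flag are passed explicitly instead.
    return common_chat_apply_template(templates.default_template, msgs,
                                      /* add_ass= */ true, use_jinja);
}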

common/common.cpp

Lines changed: 35 additions & 42 deletions
@@ -74,6 +74,15 @@
 #endif
 #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
 
+const char * LLAMA_CHATML_TEMPLATE = R"(
+{%- for message in messages -%}
+{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+{{- "<|im_start|>assistant\n" -}}
+{%- endif -%}
+)";
+
 //
 // CURL utils
 //
@@ -1748,86 +1757,77 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     return res >= 0;
 }
 
-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & msgs,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
+    if (use_jinja) {
+        auto messages = json::array();
+        for (const auto & msg : msgs) {
+            messages.push_back({{"role", msg.role}, {"content", msg.content}});
+        }
+        return tmpl.apply(messages, /* tools= */ json(), add_ass);
+    }
+
     int alloc_size = 0;
-    bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
     for (const auto & msg : msgs) {
        chat.push_back({msg.role.c_str(), msg.content.c_str()});
        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model, /* name */ nullptr) : tmpl.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
-        if (ptr_tmpl != nullptr) {
-            // if the custom "tmpl" is not supported, we throw an error
-            // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
-            throw std::runtime_error("this custom template is not supported");
-        }
-
-        // If the built-in template is not supported, we default to chatml
-        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-        fallback = true;
+        // if the custom "tmpl" is not supported, we throw an error
+        // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+        throw std::runtime_error("this custom template is not supported");
     }
 
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(
-            fallback ? "chatml" : ptr_tmpl,
-            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
     return formatted_chat;
 }
 
-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass) {
+        bool add_ass,
+        bool use_jinja) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
         ss << "\n";
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
+    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
 }
 
-std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) {
+std::string common_chat_format_example(const llama_chat_template & tmpl, bool use_jinja) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
         {"user", "How are you?"},
     };
-    const auto add_generation_prompt = true;
-    if (use_jinja) {
-        auto messages = json::array();
-        for (const auto & msg : msgs) {
-            messages.push_back({{"role", msg.role}, {"content", msg.content}});
-        }
-        return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt);
-    } else {
-        return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt);
-    }
+    return common_chat_apply_template(tmpl, msgs, true, use_jinja);
 }
 
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
@@ -1847,14 +1847,7 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
         if (!tool_use_template_src.empty()) {
             default_template_src = tool_use_template_src;
         } else {
-            default_template_src = R"(
-{%- for message in messages -%}
-{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
-{%- endfor -%}
-{%- if add_generation_prompt -%}
-{{- "<|im_start|>assistant\n" -}}
-{%- endif -%}
-)";
+            default_template_src = LLAMA_CHATML_TEMPLATE;
         }
     }
     return {
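
One behavioural note on the hunks above: the silent chatml fallback in the legacy (non-Jinja) path is gone, so an unsupported template now surfaces as a std::runtime_error. A hedged sketch of how a caller could restore the old fallback explicitly using the newly exported LLAMA_CHATML_TEMPLATE; the apply_or_fallback_to_chatml helper is hypothetical and not part of the commit (the new test below confirms llama.cpp's template detection recognizes this source as chatml):

// Hypothetical helper (not in this commit): try the model's template first and
// retry with plain chatml on failure, reproducing the fallback that used to
// live inside common_chat_apply_template.
#include "common.h"
#include "chat-template.hpp"

#include <stdexcept>
#include <string>
#include <vector>

static std::string apply_or_fallback_to_chatml(const llama_chat_template & tmpl,
                                               const std::vector<common_chat_msg> & msgs,
                                               bool add_ass, bool use_jinja) {
    try {
        return common_chat_apply_template(tmpl, msgs, add_ass, use_jinja);
    } catch (const std::runtime_error &) {
        // "this custom template is not supported" -> fall back to the exported
        // chatml source (empty bos/eos tokens suffice for this template).
        llama_chat_template chatml(LLAMA_CHATML_TEMPLATE, /* bos_token= */ "", /* eos_token= */ "");
        return common_chat_apply_template(chatml, msgs, add_ass, use_jinja);
    }
}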

common/common.h

Lines changed: 16 additions & 11 deletions
@@ -26,6 +26,8 @@
 
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
 
+extern const char * LLAMA_CHATML_TEMPLATE;
+
 struct common_adapter_lora_info {
     std::string path;
     float scale;
@@ -602,29 +604,32 @@ struct common_chat_msg {
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
 
+typedef minja::chat_template llama_chat_template;
+
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_apply_template(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & chat,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);
 
 // Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(const struct llama_model * model,
-        const std::string & tmpl,
+std::string common_chat_format_single(
+        const llama_chat_template & tmpl,
         const std::vector<common_chat_msg> & past_msg,
         const common_chat_msg & new_msg,
-        bool add_ass);
+        bool add_ass,
+        bool use_jinja);
 
 // Returns an example of formatted chat
-std::string common_chat_format_example(const struct llama_model * model,
-        const minja::chat_template & tmpl, bool use_jinja);
-
+std::string common_chat_format_example(
+        const llama_chat_template & tmpl, bool use_jinja);
 
 struct llama_chat_templates {
-    minja::chat_template default_template;
-    std::optional<minja::chat_template> tool_use_template;
+    llama_chat_template default_template;
+    std::optional<llama_chat_template> tool_use_template;
 };
 
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
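
Since llama_chat_templates now bundles an optional tool-use variant next to the default template, callers have to pick one before invoking the helpers above. A small sketch of that selection, mirroring the pattern examples/server/server.cpp uses below; select_template and want_tools are illustrative names, not part of the commit:

// Sketch only: pick the tool-use template when the request needs tools and the
// model ships one, otherwise use the default template.
#include "common.h"

static const llama_chat_template & select_template(const llama_chat_templates & templates,
                                                   bool want_tools) {
    if (want_tools && templates.tool_use_template) {
        return *templates.tool_use_template;
    }
    return templates.default_template;
}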

examples/main/main.cpp

Lines changed: 12 additions & 12 deletions
@@ -84,14 +84,6 @@ static void sigint_handler(int signo) {
 }
 #endif
 
-static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
-    common_chat_msg new_msg{role, content};
-    auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
-    chat_msgs.push_back({role, content});
-    LOG_DBG("formatted: '%s'\n", formatted.c_str());
-    return formatted;
-}
-
 int main(int argc, char ** argv) {
     common_params params;
     g_params = &params;
@@ -226,7 +218,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, chat_templates.default_template, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.default_template, params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -270,10 +262,18 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
+    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
+        common_chat_msg new_msg{role, content};
+        auto formatted = common_chat_format_single(chat_templates.default_template, chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        chat_msgs.push_back({role, content});
+        LOG_DBG("formatted: '%s'\n", formatted.c_str());
+        return formatted;
+    };
+
     {
         auto prompt = (params.conversation_mode && params.enable_chat_template)
             // format the system prompt in conversation mode (fallback to default if empty)
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
+            ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
             // otherwise use the prompt as is
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@@ -780,7 +780,7 @@ int main(int argc, char ** argv) {
                 }
 
                 if (params.enable_chat_template) {
-                    chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                    chat_add_and_format("assistant", assistant_ss.str());
                 }
                 is_interacting = true;
                 LOG("\n");
@@ -845,7 +845,7 @@ int main(int argc, char ** argv) {
 
                 bool format_chat = params.conversation_mode && params.enable_chat_template;
                 std::string user_inp = format_chat
-                    ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                    ? chat_add_and_format("user", std::move(buffer))
                     : std::move(buffer);
                 // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);

examples/run/run.cpp

Lines changed: 2 additions & 2 deletions
@@ -714,7 +714,7 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }
 
 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
+static int apply_chat_template(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
     if (use_jinja) {
         json messages = json::array();
         for (const auto & msg : llama_data.messages) {
@@ -868,7 +868,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }
 
 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+static int apply_chat_template_with_error_handling(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
     const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");

examples/server/server.cpp

Lines changed: 2 additions & 2 deletions
@@ -3869,7 +3869,7 @@ int main(int argc, char ** argv) {
         auto body = json::parse(req.body);
         const auto & templates = get_chat_templates();
         const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template;
-        json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja);
+        json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
 
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
@@ -4288,7 +4288,7 @@
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
         get_chat_templates().default_template.source().c_str(),
-        common_chat_format_example(ctx_server.model, get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str());
+        common_chat_format_example(get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));

examples/server/utils.hpp

Lines changed: 4 additions & 5 deletions
@@ -351,7 +351,7 @@ static llama_tokens format_infill(
 }
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const llama_chat_template & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
@@ -379,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.push_back({role, content});
     }
 
-    const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
+    const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
 
     return formatted_chat;
@@ -579,9 +579,8 @@ static json oaicompat_completion_params_parse(const json & body) {
 }
 
 static json oaicompat_completion_params_parse(
-    const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const minja::chat_template & tmpl,
+    const llama_chat_template & tmpl,
     bool use_jinja)
 {
     json llama_params;
@@ -622,7 +621,7 @@ static json oaicompat_completion_params_parse(
     if (use_jinja) {
         llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
     } else {
-        llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
+        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
     }
 
     // Handle "n" field

tests/test-chat-template.cpp

Lines changed: 11 additions & 6 deletions
@@ -8,6 +8,7 @@
 #include "llama.h"
 #include "common.h"
 #include "chat-template.hpp"
+#include "llama-chat.h"
 
 int main(void) {
     std::vector<llama_chat_message> conversation {
@@ -319,9 +320,10 @@ int main(void) {
     std::vector<common_chat_msg> chat2;
     common_chat_msg sys_msg{"system", "You are a helpful assistant"};
 
-    auto fmt_sys = [&](std::string tmpl) {
-        auto output = common_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
-        printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_sys = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
+        printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
@@ -345,9 +347,10 @@ int main(void) {
     chat2.push_back({"assistant", "I am assistant"});
     common_chat_msg new_msg{"user", "How are you"};
 
-    auto fmt_single = [&](std::string tmpl) {
-        auto output = common_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
-        printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_single = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
+        printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
@@ -362,5 +365,7 @@ int main(void) {
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
     assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
 
+    assert(llm_chat_detect_template(LLAMA_CHATML_TEMPLATE) == LLM_CHAT_TEMPLATE_CHATML);
+
     return 0;
 }
