Commit d23abdc

ericcurtin and engelmi committed
When llama_chat_apply_template doesn't work
Try minja. With granite-code, if we fall back to jinja on failure, it's fine.

Co-authored-by: Michael Engel <[email protected]>
Signed-off-by: Eric Curtin <[email protected]>
1 parent d774ab3 commit d23abdc

File tree

2 files changed (+70, -43 lines)

examples/run/run.cpp

Lines changed: 37 additions & 24 deletions
@@ -837,37 +837,50 @@ static void add_message(const char * role, const std::string & text, LlamaData &
     llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
 }
 
+// Function to handle Jinja template application
+static int handle_jinja_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append) {
+    json messages = json::array();
+    for (const auto & msg : llama_data.messages) {
+        messages.push_back({
+            { "role", msg.role },
+            { "content", msg.content },
+        });
+    }
+
+    try {
+        minja::chat_template_inputs tmpl_inputs;
+        tmpl_inputs.messages = messages;
+        tmpl_inputs.add_generation_prompt = append;
+
+        minja::chat_template_options tmpl_opts;
+        tmpl_opts.use_bos_token = false;
+        tmpl_opts.use_eos_token = false;
+
+        auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+        llama_data.fmtted.resize(result.size() + 1);
+        memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+        return result.size();
+    } catch (const std::exception & e) {
+        printe("failed to render the chat template: %s\n", e.what());
+    }
+
+    return -1;
+}
+
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
     if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
+        return handle_jinja_template(tmpl, llama_data, append);
     }
 
+
     int result = llama_chat_apply_template(
         tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+    // If llama_chat_apply_template fails to apply template, fallback to using jinja
+    if (result < 0) {
+        return handle_jinja_template(tmpl, llama_data, append);
+    }
+
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
         result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
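Taken together, the run.cpp hunk factors the minja rendering out of apply_chat_template() into a reusable handle_jinja_template(), then calls it whenever llama_chat_apply_template() returns a negative result. The sketch below shows the same try-native-then-fall-back-to-jinja control flow in self-contained form; render_native() and render_jinja() are hypothetical stand-ins for the real llama.cpp entry points, not actual API:

#include <cstdio>
#include <string>

// Illustrative stand-ins, NOT the llama.cpp API: render_native() plays the
// role of llama_chat_apply_template(), which signals failure with a negative
// return value; render_jinja() plays the role of the minja-based
// handle_jinja_template() fallback.
static int render_native(const std::string & msg, std::string & out) {
    (void) msg;
    (void) out;
    return -1; // pretend the built-in engine rejected this template
}

static int render_jinja(const std::string & msg, std::string & out) {
    out = "<|user|>\n" + msg + "\n<|assistant|>\n"; // made-up rendering for the sketch
    return (int) out.size();
}

// Same control flow as the patched apply_chat_template(): honor an explicit
// jinja request, otherwise try the native path and fall back on failure.
static int format_chat(const std::string & msg, bool use_jinja, std::string & out) {
    if (use_jinja) {
        return render_jinja(msg, out);
    }
    const int res = render_native(msg, out);
    if (res < 0) {
        return render_jinja(msg, out);
    }
    return res;
}

int main() {
    std::string formatted;
    const int n = format_chat("test", /*use_jinja=*/false, formatted);
    std::printf("rendered %d bytes:\n%s", n, formatted.c_str());
    return n < 0 ? 1 : 0;
}

Per the commit message, this is what makes granite-code usable: its template trips up the built-in engine, and the negative return now routes the same messages through the minja path instead of failing outright.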

examples/server/server.cpp

Lines changed: 33 additions & 19 deletions
@@ -1895,30 +1895,44 @@ struct server_context {
         return true;
     }
 
+    bool apply_jinja_templates() const {
+        auto templates = common_chat_templates_from_model(model, "");
+        common_chat_inputs inputs;
+        inputs.messages = json::array({
+            {
+                { "role", "user" },
+                { "content", "test" },
+            }
+        });
+        GGML_ASSERT(templates.template_default);
+        try {
+            common_chat_params_init(*templates.template_default, inputs);
+            if (templates.template_tool_use) {
+                common_chat_params_init(*templates.template_tool_use, inputs);
+            }
+
+            return true;
+        } catch (const std::exception & e) {
+            SRV_ERR("failed to apply template: %s\n", e.what());
+
+            return false;
+        }
+    }
+
     bool validate_builtin_chat_template(bool use_jinja) const {
-        llama_chat_message chat[] = {{"user", "test"}};
+        llama_chat_message chat[] = {
+            { "user", "test" }
+        };
 
         if (use_jinja) {
-            auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            GGML_ASSERT(templates.template_default);
-            try {
-                common_chat_params_init(*templates.template_default, inputs);
-                if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, inputs);
-                }
-                return true;
-            } catch (const std::exception & e) {
-                SRV_ERR("failed to apply template: %s\n", e.what());
-                return false;
-            }
+            return apply_jinja_templates();
         } else {
-            const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
+            const char * tmpl      = llama_model_chat_template(model, /* name */ nullptr);
             const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
+            if (chat_res < 0) {
+                return apply_jinja_templates();
+            }
+
             return chat_res > 0;
         }
     }
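The server.cpp hunk applies the same idea at validation time: validate_builtin_chat_template() dry-runs the built-in template and, on a negative result, retries through the newly extracted apply_jinja_templates() instead of reporting the template as unusable. A compact sketch of that probe-then-fallback shape; probe_builtin() and probe_jinja() are illustrative stand-ins, not the server API:

#include <cstdio>

// Illustrative stand-ins, NOT the server API: probe_builtin() plays the role
// of the llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0) dry run,
// which returns the rendered length or a negative value on failure;
// probe_jinja() plays the role of apply_jinja_templates(), which catches
// rendering exceptions and reports success as a bool.
static int probe_builtin() {
    return -1; // pretend the built-in engine cannot apply this template
}

static bool probe_jinja() {
    return true; // pretend the minja path renders the test message fine
}

// Same shape as the patched validate_builtin_chat_template(): if the
// built-in dry run fails, validate through the jinja path instead of
// declaring the template unusable.
static bool validate_template(bool use_jinja) {
    if (use_jinja) {
        return probe_jinja();
    }
    const int res = probe_builtin();
    if (res < 0) {
        return probe_jinja();
    }
    return res > 0;
}

int main() {
    std::printf("chat template usable: %s\n", validate_template(false) ? "yes" : "no");
    return 0;
}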
