Skip to content

Commit 27f1b9b

Browse files
ericcurtinengelmi
andcommitted
When llama_chat_apply_template doesn't work
Try minja. With granite-code if we fall back to jinja on failure, it's fine. Co-authored-by: Michael Engel <[email protected]> Signed-off-by: Eric Curtin <[email protected]>
1 parent d774ab3 commit 27f1b9b

File tree

1 file changed

+29
-24
lines changed

1 file changed

+29
-24
lines changed

examples/run/run.cpp

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -837,37 +837,42 @@ static void add_message(const char * role, const std::string & text, LlamaData &
837837
llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
838838
}
839839

840+
// Function to handle Jinja template application
841+
static int handle_jinja_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append) {
842+
json messages = json::array();
843+
for (const auto & msg : llama_data.messages) {
844+
messages.push_back({
845+
{ "role", msg.role },
846+
{ "content", msg.content },
847+
});
848+
}
849+
850+
try {
851+
auto result = tmpl.apply(messages, /* tools= */ json(), append);
852+
llama_data.fmtted.resize(result.size() + 1);
853+
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
854+
return result.size();
855+
} catch (const std::exception & e) {
856+
printe("failed to render the chat template: %s\n", e.what());
857+
}
858+
859+
return -1;
860+
}
861+
840862
// Function to apply the chat template and resize `formatted` if needed
841863
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
842864
if (use_jinja) {
843-
json messages = json::array();
844-
for (const auto & msg : llama_data.messages) {
845-
messages.push_back({
846-
{"role", msg.role},
847-
{"content", msg.content},
848-
});
849-
}
850-
try {
851-
minja::chat_template_inputs tmpl_inputs;
852-
tmpl_inputs.messages = messages;
853-
tmpl_inputs.add_generation_prompt = append;
854-
855-
minja::chat_template_options tmpl_opts;
856-
tmpl_opts.use_bos_token = false;
857-
tmpl_opts.use_eos_token = false;
858-
859-
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
860-
llama_data.fmtted.resize(result.size() + 1);
861-
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
862-
return result.size();
863-
} catch (const std::exception & e) {
864-
printe("failed to render the chat template: %s\n", e.what());
865-
return -1;
866-
}
865+
return handle_jinja_template(tmpl, llama_data, append);
867866
}
867+
868868
int result = llama_chat_apply_template(
869869
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
870870
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
871+
// If llama_chat_apply_template fails to apply template, fallback to using jinja
872+
if (result < 0) {
873+
return handle_jinja_template(tmpl, llama_data, append);
874+
}
875+
871876
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
872877
llama_data.fmtted.resize(result);
873878
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),

0 commit comments

Comments
 (0)