Commit 7c0adb5

ericcurtin and engelmi committed
When llama_chat_apply_template doesn't work
Try minja. With granite-code, if we fall back to jinja on failure, it's fine.

Co-authored-by: Michael Engel <[email protected]>
Signed-off-by: Eric Curtin <[email protected]>
1 parent: d774ab3 · commit: 7c0adb5

File tree: 1 file changed (+37, -24 lines)

examples/run/run.cpp

Lines changed: 37 additions & 24 deletions
@@ -837,37 +837,50 @@ static void add_message(const char * role, const std::string & text, LlamaData &
     llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
 }
 
+// Function to handle Jinja template application
+static int handle_jinja_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append) {
+    json messages = json::array();
+    for (const auto & msg : llama_data.messages) {
+        messages.push_back({
+            { "role", msg.role },
+            { "content", msg.content },
+        });
+    }
+
+    try {
+        minja::chat_template_inputs tmpl_inputs;
+        tmpl_inputs.messages = messages;
+        tmpl_inputs.add_generation_prompt = append;
+
+        minja::chat_template_options tmpl_opts;
+        tmpl_opts.use_bos_token = false;
+        tmpl_opts.use_eos_token = false;
+
+        auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+        llama_data.fmtted.resize(result.size() + 1);
+        memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+        return result.size();
+    } catch (const std::exception & e) {
+        printe("failed to render the chat template: %s\n", e.what());
+    }
+
+    return -1;
+}
+
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
     if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
+        return handle_jinja_template(tmpl, llama_data, append);
     }
+
     int result = llama_chat_apply_template(
         tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+    // If llama_chat_apply_template fails to apply template, fallback to using jinja
+    if (result < 0) {
+        return handle_jinja_template(tmpl, llama_data, append);
+    }
+
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
         result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
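To make the new control flow easier to follow in isolation, here is a minimal, self-contained sketch of the fallback pattern this commit introduces. The stub formatters below (native_apply_template, jinja_apply_template) are hypothetical stand-ins for llama_chat_apply_template and the minja-based handle_jinja_template; only the try-native-then-fall-back-to-Jinja shape mirrors the real code in run.cpp.

    // Sketch of the commit's fallback pattern; the stub functions are
    // hypothetical stand-ins, not the real llama.cpp/minja APIs.
    #include <cstdio>
    #include <string>

    // Stand-in for llama_chat_apply_template: returns the formatted length
    // on success, or a negative value when the template is not supported
    // by the built-in parser.
    static int native_apply_template(const std::string & tmpl, std::string & out) {
        (void) tmpl;
        (void) out;
        return -1; // simulate a template the native formatter rejects (e.g. granite-code's)
    }

    // Stand-in for the minja-based handle_jinja_template: renders via the
    // Jinja engine and returns the formatted length, or -1 if rendering throws.
    static int jinja_apply_template(const std::string & tmpl, std::string & out) {
        out = "<rendered by jinja from: " + tmpl + ">";
        return static_cast<int>(out.size());
    }

    // Mirrors apply_chat_template after this commit: honor an explicit Jinja
    // request, otherwise try the native formatter and fall back to Jinja on failure.
    static int apply_template(const std::string & tmpl, std::string & out, bool use_jinja) {
        if (use_jinja) {
            return jinja_apply_template(tmpl, out);
        }

        const int result = native_apply_template(tmpl, out);
        if (result < 0) {
            return jinja_apply_template(tmpl, out);
        }

        return result;
    }

    int main() {
        std::string out;
        const int n = apply_template("{{ messages }}", out, /*use_jinja=*/false);
        std::printf("%d bytes: %s\n", n, out.c_str());
        return n < 0 ? 1 : 0;
    }

The key design point is that the Jinja path is factored into a single helper, so the explicit use_jinja branch and the failure fallback share one implementation instead of duplicating the minja rendering logic.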
