@@ -837,37 +837,50 @@ static void add_message(const char * role, const std::string & text, LlamaData &
837837 llama_data.messages .push_back ({ role, llama_data.msg_strs .back ().c_str () });
838838}
839839
840+ // Function to handle Jinja template application
841+ static int handle_jinja_template (const common_chat_template & tmpl, LlamaData & llama_data, const bool append) {
842+ json messages = json::array ();
843+ for (const auto & msg : llama_data.messages ) {
844+ messages.push_back ({
845+ { " role" , msg.role },
846+ { " content" , msg.content },
847+ });
848+ }
849+
850+ try {
851+ minja::chat_template_inputs tmpl_inputs;
852+ tmpl_inputs.messages = messages;
853+ tmpl_inputs.add_generation_prompt = append;
854+
855+ minja::chat_template_options tmpl_opts;
856+ tmpl_opts.use_bos_token = false ;
857+ tmpl_opts.use_eos_token = false ;
858+
859+ auto result = tmpl.apply (tmpl_inputs, tmpl_opts);
860+ llama_data.fmtted .resize (result.size () + 1 );
861+ memcpy (llama_data.fmtted .data (), result.c_str (), result.size () + 1 );
862+ return result.size ();
863+ } catch (const std::exception & e) {
864+ printe (" failed to render the chat template: %s\n " , e.what ());
865+ }
866+
867+ return -1 ;
868+ }
869+
840870// Function to apply the chat template and resize `formatted` if needed
841871static int apply_chat_template (const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
842872 if (use_jinja) {
843- json messages = json::array ();
844- for (const auto & msg : llama_data.messages ) {
845- messages.push_back ({
846- {" role" , msg.role },
847- {" content" , msg.content },
848- });
849- }
850- try {
851- minja::chat_template_inputs tmpl_inputs;
852- tmpl_inputs.messages = messages;
853- tmpl_inputs.add_generation_prompt = append;
854-
855- minja::chat_template_options tmpl_opts;
856- tmpl_opts.use_bos_token = false ;
857- tmpl_opts.use_eos_token = false ;
858-
859- auto result = tmpl.apply (tmpl_inputs, tmpl_opts);
860- llama_data.fmtted .resize (result.size () + 1 );
861- memcpy (llama_data.fmtted .data (), result.c_str (), result.size () + 1 );
862- return result.size ();
863- } catch (const std::exception & e) {
864- printe (" failed to render the chat template: %s\n " , e.what ());
865- return -1 ;
866- }
873+ return handle_jinja_template (tmpl, llama_data, append);
867874 }
875+
868876 int result = llama_chat_apply_template (
869877 tmpl.source ().c_str (), llama_data.messages .data (), llama_data.messages .size (), append,
870878 append ? llama_data.fmtted .data () : nullptr , append ? llama_data.fmtted .size () : 0 );
879+ // If llama_chat_apply_template fails to apply template, fallback to using jinja
880+ if (result < 0 ) {
881+ return handle_jinja_template (tmpl, llama_data, append);
882+ }
883+
871884 if (append && result > static_cast <int >(llama_data.fmtted .size ())) {
872885 llama_data.fmtted .resize (result);
873886 result = llama_chat_apply_template (tmpl.source ().c_str (), llama_data.messages .data (),
0 commit comments