@@ -9,6 +9,7 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "chat-template.h"
 
 #include <algorithm>
 #include <cinttypes>
@@ -1511,6 +1512,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
 //
 
 bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = llama_chat_template(tmpl, "<s>", "</s>");
+            chat_template.apply({{
+                {"role", "user"},
+                {"content", "test"},
+            }}, json(), true);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(
         nullptr,
@@ -1519,22 +1534,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
         1,
         /* add_ass= */ true,
         /* buffer= */ nullptr,
-        /* length= */ 0,
-        use_jinja,
-        /* tools= */ nullptr,
-        "<s>",
-        "</s>");
+        /* length= */ 0);
     return res >= 0;
 }
 
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & msgs,
-        bool add_ass,
-        bool use_jinja,
-        const char * tools,
-        const char * bos_token,
-        const char * eos_token) {
+        bool add_ass) {
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
@@ -1547,7 +1554,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1557,7 +1564,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
             throw std::runtime_error("this custom template is not supported");
         } else {
             // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
             fallback = true;
         }
     }
@@ -1568,7 +1575,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(
             fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
-            chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
@@ -1579,21 +1586,17 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
-        bool add_ass,
-        bool use_jinja,
-        const char * tools,
-        const char * bos_token,
-        const char * eos_token) {
+        bool add_ass) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token);
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
         ss << "\n";
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token);
+    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
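
For context, a minimal caller-side sketch of how the simplified helpers could be used after this change. The wrapper function and the example messages below are hypothetical and only illustrate the new, trimmed-down signatures (the jinja/tools/BOS/EOS arguments are no longer passed by callers):

// Hypothetical usage sketch, not part of this diff; assumes the declarations
// from common.h shown above (llama_chat_verify_template, llama_chat_msg,
// llama_chat_apply_template).
static std::string format_example(const llama_model * model, const std::string & tmpl) {
    // verify the template once (use_jinja = true exercises the new minja path)
    if (!llama_chat_verify_template(tmpl, /* use_jinja= */ true)) {
        throw std::runtime_error("chat template failed verification");
    }
    std::vector<llama_chat_msg> msgs = {
        {"system", "You are a helpful assistant."},
        {"user",   "Hello!"},
    };
    // format the conversation with the simplified wrapper
    return llama_chat_apply_template(model, tmpl, msgs, /* add_ass= */ true);
}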