@@ -1831,22 +1831,27 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
18311831}
18321832
18331833std::string common_chat_apply_template (
1834- const common_chat_template & tmpl,
1834+ const common_chat_templates & tmpl,
18351835 const std::vector<common_chat_msg> & msgs,
18361836 bool add_ass,
18371837 bool use_jinja,
18381838 const common_params_tools & tools)
18391839{
1840+ const auto & tmpl_selected =
1841+ tools.tools () && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default ;
1842+
18401843 if (use_jinja) {
18411844 common_chat_inputs inputs;
18421845
18431846 auto messages = json::array ();
18441847 for (const auto & msg : msgs) {
18451848 messages.push_back ({{" role" , msg.role }, {" content" , msg.content }});
18461849 }
1850+
18471851 if (tools.tools () != nullptr ) {
18481852 inputs.tools = *tools.tools ();
18491853 }
1854+
18501855 auto choice = tools.choice ();
18511856 if (std::holds_alternative<std::string>(choice)) {
18521857 inputs.tool_choice = std::get<std::string>(choice);
@@ -1857,9 +1862,10 @@ std::string common_chat_apply_template(
18571862 inputs.tool_choice = *choice_ptr;
18581863 }
18591864 }
1865+
18601866 inputs.messages = messages;
18611867 inputs.add_generation_prompt = add_ass;
1862- return common_chat_params_init (tmpl , inputs).prompt ;
1868+ return common_chat_params_init (tmpl_selected , inputs).prompt ;
18631869 }
18641870
18651871 int alloc_size = 0 ;
@@ -1872,7 +1878,7 @@ std::string common_chat_apply_template(
18721878 std::vector<char > buf (alloc_size);
18731879
18741880 // run the first time to get the total output length
1875- int32_t res = llama_chat_apply_template (tmpl .source ().c_str (), chat.data (), chat.size (), add_ass, buf.data (), buf.size ());
1881+ int32_t res = llama_chat_apply_template (tmpl_selected .source ().c_str (), chat.data (), chat.size (), add_ass, buf.data (), buf.size ());
18761882
18771883 // error: chat template is not supported
18781884 if (res < 0 ) {
@@ -1884,15 +1890,15 @@ std::string common_chat_apply_template(
18841890 // if it turns out that our buffer is too small, we resize it
18851891 if ((size_t ) res > buf.size ()) {
18861892 buf.resize (res);
1887- res = llama_chat_apply_template (tmpl .source ().c_str (), chat.data (), chat.size (), add_ass, buf.data (), buf.size ());
1893+ res = llama_chat_apply_template (tmpl_selected .source ().c_str (), chat.data (), chat.size (), add_ass, buf.data (), buf.size ());
18881894 }
18891895
18901896 std::string formatted_chat (buf.data (), res);
18911897 return formatted_chat;
18921898}
18931899
18941900std::string common_chat_format_single (
1895- const common_chat_template & tmpl,
1901+ const common_chat_templates & tmpl,
18961902 const std::vector<common_chat_msg> & past_msg,
18971903 const common_chat_msg & new_msg,
18981904 bool add_ass,
@@ -1916,7 +1922,7 @@ std::string common_chat_format_single(
19161922 return ss.str ();
19171923}
19181924
1919- std::string common_chat_format_example (const common_chat_template & tmpl, bool use_jinja) {
1925+ std::string common_chat_format_example (const common_chat_templates & tmpl, bool use_jinja) {
19201926 std::vector<common_chat_msg> msgs = {
19211927 {" system" , " You are a helpful assistant" , {}},
19221928 {" user" , " Hello" , {}},
0 commit comments