@@ -1768,42 +1768,18 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
17681768 return text;
17691769}
17701770
// NOTE(review): the '-' lines below are a scraped diff hunk, not compilable
// source (old-file line numbers and diff markers are fused into the text).
// They show common_chat_verify_template being REMOVED from old lines
// 1771-1795. Comparing token-for-token with the '+' block re-added further
// down (new lines 1810-1834), the function body is identical — this is pure
// code motion within the file, not a behavior change.
1771- //
1772- // Chat template utils
1773- //
1774-
1775- bool common_chat_verify_template (const std::string & tmpl, bool use_jinja) {
1776- if (use_jinja) {
1777- try {
1778- auto chat_template = common_chat_template (tmpl, " <s>" , " </s>" );
1779- common_chat_inputs inputs;
1780- inputs.messages = json::array ({{
1781- {" role" , " user" },
1782- {" content" , " test" },
1783- }});
1784- common_chat_params_init (chat_template, inputs);
1785- return true ;
1786- } catch (const std::exception & e) {
1787- LOG_ERR (" %s: failed to apply template: %s\n " , __func__, e.what ());
1788- return false ;
1789- }
1790- }
1791- llama_chat_message chat[] = {{" user" , " test" }};
1792- const int res = llama_chat_apply_template (tmpl.c_str (), chat, 1 , true , nullptr , 0 );
1793- return res >= 0 ;
1794- }
1795-
// NOTE(review): scraped diff hunk. The '-' lines are the old helper
// copy_chat_params(src, toolcall::sampling_updater*), which unpacked the
// vocab and sampling params out of the updater struct; the '+' lines are its
// replacement common_chat_grammar_to_sampler(src, vocab, sparams), which
// takes them as explicit parameters instead. The visible copied fields are
// grammar, grammar_lazy, grammar_triggers (tokenized via common_tokenize),
// and preserved_tokens.
// WARNING: new-file lines 1786-1791 fall between the two hunks and are NOT
// visible here — the body of the `ids.size() == 1` trigger branch is missing
// from this view; do not reconstruct it by guessing.
1796- static void copy_chat_params (const common_chat_params & src, toolcall::sampling_updater * update_sparams)
1771+ void common_chat_grammar_to_sampler (const common_chat_params * src,
1772+ const llama_vocab * vocab,
1773+ common_params_sampling * sparams)
17971774{
1798- GGML_ASSERT (update_sparams && update_sparams-> sparams && update_sparams-> vocab );
// precondition: all three pointers must be non-null (hard assert, not error return)
1775+ GGML_ASSERT (src && vocab && sparams );
17991776
1800- auto & dst = *update_sparams->sparams ;
1801- auto vocab = update_sparams->vocab ;
1777+ auto & dst = *sparams;
18021778
1803- dst.grammar = src. grammar ;
1804- dst.grammar_lazy = src. grammar_lazy ;
// old '.' member access becomes '->' because src changed from reference to pointer
1779+ dst.grammar = src-> grammar ;
1780+ dst.grammar_lazy = src-> grammar_lazy ;
18051781
1806- for (const auto & trigger : src. grammar_triggers ) {
1782+ for (const auto & trigger : src-> grammar_triggers ) {
18071783 auto ids = common_tokenize (vocab, trigger.word , false , true );
18081784
18091785 if (ids.size () == 1 ) {
// --- hunk gap: new lines 1786-1791 (single-token trigger handling) not shown ---
@@ -1816,7 +1792,7 @@ static void copy_chat_params(const common_chat_params & src, toolcall::sampling_
18161792 dst.grammar_trigger_words .push_back (trigger);
18171793 }
18181794
1819- for (const auto & preserved : src. preserved_tokens ) {
1795+ for (const auto & preserved : src-> preserved_tokens ) {
18201796 auto ids = common_tokenize (vocab, preserved, false , true );
18211797 if (ids.size () == 1 ) {
18221798 LOG_DBG (" Preserved token: %d\n " , ids[0 ]);
// --- hunk gap: new lines 1799-1806 (multi-token preserved handling) not shown ---
@@ -1831,19 +1807,45 @@ static void copy_chat_params(const common_chat_params & src, toolcall::sampling_
18311807 }
18321808}
18331809
1810+
1811+ //
1812+ // Chat template utils
1813+ //
1814+
1815+ bool common_chat_verify_template (const std::string & tmpl, bool use_jinja) {
1816+ if (use_jinja) {
1817+ try {
1818+ auto chat_template = common_chat_template (tmpl, " <s>" , " </s>" );
1819+ common_chat_inputs inputs;
1820+ inputs.messages = json::array ({{
1821+ {" role" , " user" },
1822+ {" content" , " test" },
1823+ }});
1824+ common_chat_params_init (chat_template, inputs);
1825+ return true ;
1826+ } catch (const std::exception & e) {
1827+ LOG_ERR (" %s: failed to apply template: %s\n " , __func__, e.what ());
1828+ return false ;
1829+ }
1830+ }
1831+ llama_chat_message chat[] = {{" user" , " test" }};
1832+ const int res = llama_chat_apply_template (tmpl.c_str (), chat, 1 , true , nullptr , 0 );
1833+ return res >= 0 ;
1834+ }
1835+
// NOTE(review): scraped diff hunk for common_chat_apply_template. The
// signature change replaces the toolcall handler/sampling-updater pair with
// a const common_chat_inputs* (caller-supplied template inputs, copied as the
// starting point when non-null) and a common_chat_params* out-parameter that
// receives the full chat params alongside the returned prompt string. The
// removed '-' lines show the old behavior this displaces: tool_choice/tools
// were pulled from the handler, sparams were copied via copy_chat_params, and
// the handler could intercept the prompt and return its response instead.
// WARNING: hunk gaps hide part of the body (between new lines 1851-1854, and
// everything after 1864) — this view is incomplete; do not rewrite from it.
18341836std::string common_chat_apply_template (
18351837 const common_chat_templates & tmpl,
18361838 const std::vector<common_chat_msg> & msgs,
18371839 bool add_ass,
18381840 bool use_jinja,
1839- toolcall::handler::ptr handler ,
1840- toolcall::sampling_updater * update_sparams )
1841+ const common_chat_inputs * inputs_ ,
1842+ common_chat_params * out_params )
18411843{
// old code required a live handler before selecting the tool-use template;
// new code selects it whenever jinja is on and the tool-use template exists
1842- bool use_tool_template = ( use_jinja && handler != nullptr ) && tmpl.template_tool_use ;
1844+ bool use_tool_template = use_jinja && tmpl.template_tool_use ;
18431845 const auto & tmpl_selected = use_tool_template ? *tmpl.template_tool_use : *tmpl.template_default ;
18441846
18451847 if (use_jinja) {
1846- common_chat_inputs inputs;
// caller-provided inputs (tools, tool_choice, ...) seed the render when given
1848+ common_chat_inputs inputs = inputs_ ? *inputs_ : common_chat_inputs () ;
18471849
18481850 auto messages = json::array ();
18491851 for (const auto & msg : msgs) {
// --- hunk gap: loop body / message conversion (new lines ~1852-1853) not shown ---
@@ -1852,35 +1854,11 @@ std::string common_chat_apply_template(
18521854 inputs.messages = messages;
18531855 inputs.add_generation_prompt = add_ass;
18541856
1855- if (handler != nullptr ) {
1856- auto choice = handler->tool_choice ();
1857- if (std::holds_alternative<std::string>(choice)) {
1858- inputs.tool_choice = std::get<std::string>(choice);
1859-
1860- } else {
1861- auto choice_ptr = std::get<toolcall::json_ptr>(choice);
1862- if (choice_ptr != nullptr ) {
1863- inputs.tool_choice = *choice_ptr;
1864- }
1865- }
1866-
1867- inputs.tools = handler->tool_list ();
1868- }
1869-
18701857 auto chat_params = common_chat_params_init (tmpl_selected, inputs);
1871- if (update_sparams) {
1872- copy_chat_params (chat_params, update_sparams);
1873- }
1874-
1875- auto prompt = chat_params.prompt ;
1876- if (handler != nullptr ) {
1877- json response;
1878- handler->call (prompt, response);
1879- return response; // Caller will determine what to do based upon last_action
1880-
1881- } else {
1882- return prompt;
// new code: surface the computed params to the caller instead of invoking a handler
1858+ if (out_params != nullptr ) {
1859+ *out_params = chat_params;
18831860 }
1861+ return chat_params.prompt ;
18841862 }
18851863
// non-jinja fallback path continues beyond this hunk (not visible here)
18861864 int alloc_size = 0 ;
// NOTE(review): scraped diff hunk for common_chat_format_single — the return
// type and opening of the signature are visible only in the '@@' header, and
// a second hunk gap hides lines between new 1906 and 1910, so this definition
// is incomplete in this view. The visible change mirrors the
// common_chat_apply_template signature change: the toolcall handler /
// sampling-updater pair is replaced by const common_chat_inputs* and a
// common_chat_params* out-parameter, forwarded to both
// common_chat_apply_template calls (out_params only on the second, which
// renders with the new message). The function formats past_msg and
// past_msg+new_msg, then returns the suffix diff between the two renders.
@@ -1918,12 +1896,12 @@ std::string common_chat_format_single(
19181896 const common_chat_msg & new_msg,
19191897 bool add_ass,
19201898 bool use_jinja,
1921- toolcall::handler::ptr handler ,
1922- toolcall::sampling_updater * update_sparams )
1899+ const common_chat_inputs * inputs ,
1900+ common_chat_params * out_params )
19231901{
19241902 std::ostringstream ss;
19251903 auto fmt_past_msg = past_msg.empty () ? " "
1926- : common_chat_apply_template (tmpl, past_msg, false , use_jinja, handler, update_sparams );
1904+ : common_chat_apply_template (tmpl, past_msg, false , use_jinja, inputs );
19271905 std::vector<common_chat_msg> chat_new (past_msg);
19281906 // if the past_msg ends with a newline, we must preserve it in the formatted version
// --- hunk gap: new lines 1907-1909 not shown ---
@@ -1932,7 +1910,7 @@ std::string common_chat_format_single(
19321910 };
19331911 // format chat with new_msg
19341912 chat_new.push_back (new_msg);
1935- auto fmt_new_msg = common_chat_apply_template (tmpl, chat_new, add_ass, use_jinja, handler, update_sparams );
1913+ auto fmt_new_msg = common_chat_apply_template (tmpl, chat_new, add_ass, use_jinja, inputs, out_params );
19361914 // get the diff part
19371915 ss << fmt_new_msg.substr (fmt_past_msg.size (), fmt_new_msg.size () - fmt_past_msg.size ());
19381916 return ss.str ();
0 commit comments