@@ -1830,12 +1830,51 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
18301830 return res >= 0 ;
18311831}
18321832
1833+ static void copy_chat_params (const common_chat_params & src, common_chat_sampling_updater * update_sparams)
1834+ {
1835+ GGML_ASSERT (update_sparams && update_sparams->sparams && update_sparams->vocab );
1836+
1837+ auto & dst = *update_sparams->sparams ;
1838+ auto vocab = update_sparams->vocab ;
1839+
1840+ dst.grammar = src.grammar ;
1841+ dst.grammar_lazy = src.grammar_lazy ;
1842+
1843+ for (const auto & trigger : src.grammar_triggers ) {
1844+ auto ids = common_tokenize (vocab, trigger.word , false , true );
1845+
1846+ if (ids.size () == 1 ) {
1847+ LOG_DBG (" Grammar trigger token: %d (`%s`)\n " , ids[0 ], trigger.word .c_str ());
1848+ dst.grammar_trigger_tokens .push_back (ids[0 ]);
1849+ dst.preserved_tokens .insert (ids[0 ]);
1850+ continue ;
1851+ }
1852+ LOG_DBG (" Grammar trigger word: `%s`\n " , trigger.word .c_str ());
1853+ dst.grammar_trigger_words .push_back (trigger);
1854+ }
1855+
1856+ for (const auto & preserved : src.preserved_tokens ) {
1857+ auto ids = common_tokenize (vocab, preserved, false , true );
1858+ if (ids.size () == 1 ) {
1859+ LOG_DBG (" Preserved token: %d\n " , ids[0 ]);
1860+ dst.preserved_tokens .insert (ids[0 ]);
1861+
1862+ } else {
1863+ // This may happen when using a tool call style meant for a model
1864+ // with special tokens to preserve on a model without said tokens.
1865+ LOG_WRN (" Not preserved because more than 1 token (wrong chat template override?): %s\n " ,
1866+ preserved.c_str ());
1867+ }
1868+ }
1869+ }
1870+
18331871std::string common_chat_apply_template (
18341872 const common_chat_templates & tmpl,
18351873 const std::vector<common_chat_msg> & msgs,
18361874 bool add_ass,
18371875 bool use_jinja,
1838- const common_params_tools & tools)
1876+ const common_params_tools & tools,
1877+ common_chat_sampling_updater * update_sparams)
18391878{
18401879 const auto & tmpl_selected =
18411880 tools.tools () && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default ;
@@ -1865,7 +1904,11 @@ std::string common_chat_apply_template(
18651904
18661905 inputs.messages = messages;
18671906 inputs.add_generation_prompt = add_ass;
1868- return common_chat_params_init (tmpl_selected, inputs).prompt ;
1907+ auto chat_params = common_chat_params_init (tmpl_selected, inputs);
1908+ if (update_sparams) {
1909+ copy_chat_params (chat_params, update_sparams);
1910+ }
1911+ return chat_params.prompt ;
18691912 }
18701913
18711914 int alloc_size = 0 ;
@@ -1903,11 +1946,12 @@ std::string common_chat_format_single(
19031946 const common_chat_msg & new_msg,
19041947 bool add_ass,
19051948 bool use_jinja,
1906- const common_params_tools & tools)
1949+ const common_params_tools & tools,
1950+ common_chat_sampling_updater * update_sparams)
19071951{
19081952 std::ostringstream ss;
19091953 auto fmt_past_msg = past_msg.empty () ? " "
1910- : common_chat_apply_template (tmpl, past_msg, false , use_jinja, tools);
1954+ : common_chat_apply_template (tmpl, past_msg, false , use_jinja, tools, update_sparams );
19111955
19121956 std::vector<common_chat_msg> chat_new (past_msg);
19131957 // if the past_msg ends with a newline, we must preserve it in the formatted version
@@ -1916,7 +1960,7 @@ std::string common_chat_format_single(
19161960 };
19171961 // format chat with new_msg
19181962 chat_new.push_back (new_msg);
1919- auto fmt_new_msg = common_chat_apply_template (tmpl, chat_new, add_ass, use_jinja, tools);
1963+ auto fmt_new_msg = common_chat_apply_template (tmpl, chat_new, add_ass, use_jinja, tools, update_sparams );
19201964 // get the diff part
19211965 ss << fmt_new_msg.substr (fmt_past_msg.size (), fmt_new_msg.size () - fmt_past_msg.size ());
19221966 return ss.str ();
0 commit comments