Skip to content

Commit a726ada

Browse files
committed
Copy sampler parameters from chat template
1 parent a30111b commit a726ada

File tree

3 files changed

+65
-10
lines changed

3 files changed

+65
-10
lines changed

common/common.cpp

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,12 +1830,51 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
18301830
return res >= 0;
18311831
}
18321832

1833+
static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams)
1834+
{
1835+
GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
1836+
1837+
auto & dst = *update_sparams->sparams;
1838+
auto vocab = update_sparams->vocab;
1839+
1840+
dst.grammar = src.grammar;
1841+
dst.grammar_lazy = src.grammar_lazy;
1842+
1843+
for (const auto & trigger : src.grammar_triggers) {
1844+
auto ids = common_tokenize(vocab, trigger.word, false, true);
1845+
1846+
if (ids.size() == 1) {
1847+
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
1848+
dst.grammar_trigger_tokens.push_back(ids[0]);
1849+
dst.preserved_tokens.insert(ids[0]);
1850+
continue;
1851+
}
1852+
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
1853+
dst.grammar_trigger_words.push_back(trigger);
1854+
}
1855+
1856+
for (const auto & preserved : src.preserved_tokens) {
1857+
auto ids = common_tokenize(vocab, preserved, false, true);
1858+
if (ids.size() == 1) {
1859+
LOG_DBG("Preserved token: %d\n", ids[0]);
1860+
dst.preserved_tokens.insert(ids[0]);
1861+
1862+
} else {
1863+
// This may happen when using a tool call style meant for a model
1864+
// with special tokens to preserve on a model without said tokens.
1865+
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n",
1866+
preserved.c_str());
1867+
}
1868+
}
1869+
}
1870+
18331871
std::string common_chat_apply_template(
18341872
const common_chat_templates & tmpl,
18351873
const std::vector<common_chat_msg> & msgs,
18361874
bool add_ass,
18371875
bool use_jinja,
1838-
const common_params_tools & tools)
1876+
const common_params_tools & tools,
1877+
common_chat_sampling_updater * update_sparams)
18391878
{
18401879
const auto & tmpl_selected =
18411880
tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
@@ -1865,7 +1904,11 @@ std::string common_chat_apply_template(
18651904

18661905
inputs.messages = messages;
18671906
inputs.add_generation_prompt = add_ass;
1868-
return common_chat_params_init(tmpl_selected, inputs).prompt;
1907+
auto chat_params = common_chat_params_init(tmpl_selected, inputs);
1908+
if (update_sparams) {
1909+
copy_chat_params(chat_params, update_sparams);
1910+
}
1911+
return chat_params.prompt;
18691912
}
18701913

18711914
int alloc_size = 0;
@@ -1903,11 +1946,12 @@ std::string common_chat_format_single(
19031946
const common_chat_msg & new_msg,
19041947
bool add_ass,
19051948
bool use_jinja,
1906-
const common_params_tools & tools)
1949+
const common_params_tools & tools,
1950+
common_chat_sampling_updater * update_sparams)
19071951
{
19081952
std::ostringstream ss;
19091953
auto fmt_past_msg = past_msg.empty() ? ""
1910-
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools);
1954+
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams);
19111955

19121956
std::vector<common_chat_msg> chat_new(past_msg);
19131957
// if the past_msg ends with a newline, we must preserve it in the formatted version
@@ -1916,7 +1960,7 @@ std::string common_chat_format_single(
19161960
};
19171961
// format chat with new_msg
19181962
chat_new.push_back(new_msg);
1919-
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools);
1963+
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams);
19201964
// get the diff part
19211965
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
19221966
return ss.str();

common/common.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,11 @@ struct common_chat_templates {
671671
std::unique_ptr<common_chat_template> template_tool_use;
672672
};
673673

674+
struct common_chat_sampling_updater {
675+
common_params_sampling * sparams;
676+
const llama_vocab * vocab;
677+
};
678+
674679
// CPP wrapper for llama_chat_apply_template
675680
// If the built-in template is not supported, we default to chatml
676681
// If the custom "tmpl" is not supported, we throw an error
@@ -679,7 +684,8 @@ std::string common_chat_apply_template(
679684
const std::vector<common_chat_msg> & chat,
680685
bool add_ass,
681686
bool use_jinja,
682-
const common_params_tools & tools = common_params_tools());
687+
const common_params_tools & tools = common_params_tools(),
688+
common_chat_sampling_updater * update_sparams = nullptr);
683689

684690
// Format single message, while taking into account the position of that message in chat history
685691
std::string common_chat_format_single(
@@ -688,7 +694,8 @@ std::string common_chat_format_single(
688694
const common_chat_msg & new_msg,
689695
bool add_ass,
690696
bool use_jinja,
691-
const common_params_tools & tools = common_params_tools());
697+
const common_params_tools & tools = common_params_tools(),
698+
common_chat_sampling_updater * update_sparams = nullptr);
692699

693700
// Returns an example of formatted chat
694701
std::string common_chat_format_example(

examples/main/main.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,18 @@ int main(int argc, char ** argv) {
263263

264264
std::vector<llama_token> embd_inp;
265265

266-
auto chat_add_and_format = [&chat_msgs, &chat_templates](
266+
auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab](
267267
const std::string & role, const std::string & content,
268268
const common_params_tools & tools = common_params_tools())
269269
{
270+
bool add_ass = (role == "user");
271+
270272
common_chat_msg new_msg{role, content, {}};
271273

272-
auto formatted = common_chat_format_single(chat_templates, chat_msgs,
273-
new_msg, role == "user", g_params->use_jinja, tools);
274+
common_chat_sampling_updater updater{&sparams, vocab};
275+
auto formatted =
276+
common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja,
277+
tools, &updater);
274278

275279
chat_msgs.push_back({role, content, {}});
276280
LOG_DBG("formatted: '%s'\n", formatted.c_str());

0 commit comments

Comments
 (0)