Skip to content

Commit a9e3404

Browse files
committed
Separate tool-call from template application
1 parent ff44762 commit a9e3404

File tree

3 files changed

+84
-81
lines changed

3 files changed

+84
-81
lines changed

common/common.cpp

Lines changed: 46 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,42 +1768,18 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
17681768
return text;
17691769
}
17701770

1771-
//
1772-
// Chat template utils
1773-
//
1774-
1775-
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
1776-
if (use_jinja) {
1777-
try {
1778-
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
1779-
common_chat_inputs inputs;
1780-
inputs.messages = json::array({{
1781-
{"role", "user"},
1782-
{"content", "test"},
1783-
}});
1784-
common_chat_params_init(chat_template, inputs);
1785-
return true;
1786-
} catch (const std::exception & e) {
1787-
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
1788-
return false;
1789-
}
1790-
}
1791-
llama_chat_message chat[] = {{"user", "test"}};
1792-
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
1793-
return res >= 0;
1794-
}
1795-
1796-
static void copy_chat_params(const common_chat_params & src, toolcall::sampling_updater * update_sparams)
1771+
void common_chat_grammar_to_sampler(const common_chat_params * src,
1772+
const llama_vocab * vocab,
1773+
common_params_sampling * sparams)
17971774
{
1798-
GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
1775+
GGML_ASSERT(src && vocab && sparams);
17991776

1800-
auto & dst = *update_sparams->sparams;
1801-
auto vocab = update_sparams->vocab;
1777+
auto & dst = *sparams;
18021778

1803-
dst.grammar = src.grammar;
1804-
dst.grammar_lazy = src.grammar_lazy;
1779+
dst.grammar = src->grammar;
1780+
dst.grammar_lazy = src->grammar_lazy;
18051781

1806-
for (const auto & trigger : src.grammar_triggers) {
1782+
for (const auto & trigger : src->grammar_triggers) {
18071783
auto ids = common_tokenize(vocab, trigger.word, false, true);
18081784

18091785
if (ids.size() == 1) {
@@ -1816,7 +1792,7 @@ static void copy_chat_params(const common_chat_params & src, toolcall::sampling_
18161792
dst.grammar_trigger_words.push_back(trigger);
18171793
}
18181794

1819-
for (const auto & preserved : src.preserved_tokens) {
1795+
for (const auto & preserved : src->preserved_tokens) {
18201796
auto ids = common_tokenize(vocab, preserved, false, true);
18211797
if (ids.size() == 1) {
18221798
LOG_DBG("Preserved token: %d\n", ids[0]);
@@ -1831,19 +1807,45 @@ static void copy_chat_params(const common_chat_params & src, toolcall::sampling_
18311807
}
18321808
}
18331809

1810+
1811+
//
1812+
// Chat template utils
1813+
//
1814+
1815+
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
1816+
if (use_jinja) {
1817+
try {
1818+
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
1819+
common_chat_inputs inputs;
1820+
inputs.messages = json::array({{
1821+
{"role", "user"},
1822+
{"content", "test"},
1823+
}});
1824+
common_chat_params_init(chat_template, inputs);
1825+
return true;
1826+
} catch (const std::exception & e) {
1827+
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
1828+
return false;
1829+
}
1830+
}
1831+
llama_chat_message chat[] = {{"user", "test"}};
1832+
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
1833+
return res >= 0;
1834+
}
1835+
18341836
std::string common_chat_apply_template(
18351837
const common_chat_templates & tmpl,
18361838
const std::vector<common_chat_msg> & msgs,
18371839
bool add_ass,
18381840
bool use_jinja,
1839-
toolcall::handler::ptr handler,
1840-
toolcall::sampling_updater * update_sparams)
1841+
const common_chat_inputs * inputs_,
1842+
common_chat_params * out_params)
18411843
{
1842-
bool use_tool_template = (use_jinja && handler != nullptr) && tmpl.template_tool_use;
1844+
bool use_tool_template = use_jinja && tmpl.template_tool_use;
18431845
const auto & tmpl_selected = use_tool_template ? *tmpl.template_tool_use : *tmpl.template_default;
18441846

18451847
if (use_jinja) {
1846-
common_chat_inputs inputs;
1848+
common_chat_inputs inputs = inputs_ ? *inputs_ : common_chat_inputs();
18471849

18481850
auto messages = json::array();
18491851
for (const auto & msg : msgs) {
@@ -1852,35 +1854,11 @@ std::string common_chat_apply_template(
18521854
inputs.messages = messages;
18531855
inputs.add_generation_prompt = add_ass;
18541856

1855-
if (handler != nullptr) {
1856-
auto choice = handler->tool_choice();
1857-
if (std::holds_alternative<std::string>(choice)) {
1858-
inputs.tool_choice = std::get<std::string>(choice);
1859-
1860-
} else {
1861-
auto choice_ptr = std::get<toolcall::json_ptr>(choice);
1862-
if (choice_ptr != nullptr) {
1863-
inputs.tool_choice = *choice_ptr;
1864-
}
1865-
}
1866-
1867-
inputs.tools = handler->tool_list();
1868-
}
1869-
18701857
auto chat_params = common_chat_params_init(tmpl_selected, inputs);
1871-
if (update_sparams) {
1872-
copy_chat_params(chat_params, update_sparams);
1873-
}
1874-
1875-
auto prompt = chat_params.prompt;
1876-
if (handler != nullptr) {
1877-
json response;
1878-
handler->call(prompt, response);
1879-
return response; // Caller will determine what to do based upon last_action
1880-
1881-
} else {
1882-
return prompt;
1858+
if (out_params != nullptr) {
1859+
*out_params = chat_params;
18831860
}
1861+
return chat_params.prompt;
18841862
}
18851863

18861864
int alloc_size = 0;
@@ -1918,12 +1896,12 @@ std::string common_chat_format_single(
19181896
const common_chat_msg & new_msg,
19191897
bool add_ass,
19201898
bool use_jinja,
1921-
toolcall::handler::ptr handler,
1922-
toolcall::sampling_updater * update_sparams)
1899+
const common_chat_inputs * inputs,
1900+
common_chat_params * out_params)
19231901
{
19241902
std::ostringstream ss;
19251903
auto fmt_past_msg = past_msg.empty() ? ""
1926-
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, handler, update_sparams);
1904+
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, inputs);
19271905

19281906
std::vector<common_chat_msg> chat_new(past_msg);
19291907
// if the past_msg ends with a newline, we must preserve it in the formatted version
@@ -1932,7 +1910,7 @@ std::string common_chat_format_single(
19321910
};
19331911
// format chat with new_msg
19341912
chat_new.push_back(new_msg);
1935-
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, handler, update_sparams);
1913+
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, inputs, out_params);
19361914
// get the diff part
19371915
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
19381916
return ss.str();

common/common.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,13 @@ std::string common_detokenize(
618618
const std::vector<llama_token> & tokens,
619619
bool special = true);
620620

621+
struct common_chat_params;
622+
struct common_chat_inputs;
623+
void common_chat_grammar_to_sampler(const common_chat_params * src,
624+
const llama_vocab * vocab,
625+
common_params_sampling * sparams);
626+
627+
621628
//
622629
// Chat template utils
623630
//
@@ -651,13 +658,6 @@ struct common_chat_templates {
651658
std::unique_ptr<common_chat_template> template_tool_use;
652659
};
653660

654-
namespace toolcall {
655-
struct sampling_updater {
656-
common_params_sampling * sparams;
657-
const llama_vocab * vocab;
658-
};
659-
}
660-
661661
// CPP wrapper for llama_chat_apply_template
662662
// If the built-in template is not supported, we default to chatml
663663
// If the custom "tmpl" is not supported, we throw an error
@@ -666,8 +666,8 @@ std::string common_chat_apply_template(
666666
const std::vector<common_chat_msg> & chat,
667667
bool add_ass,
668668
bool use_jinja,
669-
toolcall::handler::ptr handler = nullptr,
670-
toolcall::sampling_updater * update_sparams = nullptr);
669+
const common_chat_inputs * inputs = nullptr,
670+
common_chat_params * out_params = nullptr);
671671

672672
// Format single message, while taking into account the position of that message in chat history
673673
std::string common_chat_format_single(
@@ -676,8 +676,8 @@ std::string common_chat_format_single(
676676
const common_chat_msg & new_msg,
677677
bool add_ass,
678678
bool use_jinja,
679-
toolcall::handler::ptr handler = nullptr,
680-
toolcall::sampling_updater * update_sparams = nullptr);
679+
const common_chat_inputs * inputs = nullptr,
680+
common_chat_params * out_params = nullptr);
681681

682682
// Returns an example of formatted chat
683683
std::string common_chat_format_example(

examples/main/main.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "arg.h"
2+
#include "chat.hpp"
23
#include "common.h"
34
#include "console.h"
45
#include "log.h"
@@ -273,13 +274,37 @@ int main(int argc, char ** argv) {
273274

274275
common_chat_msg new_msg{role, content, {}};
275276

276-
toolcall::sampling_updater updater{&sparams, vocab};
277+
common_chat_inputs cinputs;
278+
if (handler != nullptr) {
279+
auto choice = handler->tool_choice();
280+
if (std::holds_alternative<std::string>(choice)) {
281+
cinputs.tool_choice = std::get<std::string>(choice);
282+
283+
} else {
284+
auto choice_ptr = std::get<toolcall::json_ptr>(choice);
285+
if (choice_ptr != nullptr) {
286+
cinputs.tool_choice = *choice_ptr;
287+
}
288+
}
289+
cinputs.tools = handler->tool_list();
290+
}
291+
292+
common_chat_params cparams;
277293
auto formatted =
278294
common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja,
279-
handler, &updater);
295+
&cinputs, &cparams);
280296

281297
chat_msgs.push_back({role, content, {}});
282298
LOG_DBG("formatted: '%s'\n", formatted.c_str());
299+
300+
if (g_params->use_jinja) {
301+
common_chat_grammar_to_sampler(&cparams, vocab, &sparams);
302+
if (handler != nullptr) {
303+
json response;
304+
handler->call(formatted, response);
305+
return std::string(response);
306+
}
307+
}
283308
return formatted;
284309
};
285310

0 commit comments

Comments
 (0)