Skip to content

Commit ba098af

Browse files
committed
Extract toolcall format from model
1 parent ac1fc31, commit ba098af

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

examples/main/main.cpp

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -98,9 +98,10 @@ class chat_formatter {
9898
std::vector<common_chat_msg> & chat_msgs,
9999
struct common_chat_templates * chat_templates,
100100
const llama_vocab * vocab,
101-
toolcall::client::ptr tc_client)
101+
toolcall::client::ptr tc_client,
102+
common_chat_format chat_format)
102103

103-
: params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client) {}
104+
: params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client), chat_format_(chat_format) {}
104105
#endif
105106

106107
std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) {
@@ -126,6 +127,7 @@ class chat_formatter {
126127
LOG_DBG("formatted: '%s'\n", formatted.c_str());
127128

128129
#ifdef LLAMA_USE_TOOLCALL
130+
if (chat_format_) *chat_format_ = cparams.format;
129131
common_chat_grammar_to_sampler(&cparams, vocab_, &params_.sampling);
130132
#endif
131133
return formatted;
@@ -139,14 +141,16 @@ class chat_formatter {
139141
#ifdef LLAMA_USE_TOOLCALL
140142
const llama_vocab * vocab_;
141143
toolcall::client::ptr tc_client_;
144+
common_chat_format * chat_format_;
142145
#endif
143146
};
144147

145148
#ifdef LLAMA_USE_TOOLCALL
146-
static bool call_tool(const std::string & assistant_msg, llama_context * ctx, toolcall::client::ptr tc_client, std::vector<llama_token> & embd_inp)
149+
static bool call_tool(common_chat_format chat_format, const std::string & assistant_msg, llama_context * ctx,
150+
toolcall::client::ptr tc_client, std::vector<llama_token> & embd_inp)
147151
{
148152
bool tool_was_called = false;
149-
common_chat_msg msg = common_chat_parse(assistant_msg, COMMON_CHAT_FORMAT_GENERIC);
153+
common_chat_msg msg = common_chat_parse(assistant_msg, chat_format);
150154
if (! msg.tool_calls.empty()) {
151155
for (const auto & tc : msg.tool_calls) {
152156
nlohmann::json tc_oai_json {
@@ -371,7 +375,8 @@ int main(int argc, char ** argv) {
371375
if (tc_client) {
372376
tc_client->initialize();
373377
}
374-
chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client);
378+
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
379+
chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client, &chat_format);
375380
#else
376381
chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get());
377382
#endif
@@ -941,7 +946,7 @@ int main(int argc, char ** argv) {
941946
#ifdef LLAMA_USE_TOOLCALL
942947
if ((tc_client && n_past > 0) && (waiting_for_first_input || is_interacting)) {
943948
size_t last_len = embd_inp.size();
944-
bool was_toolcall = call_tool(assistant_ss.str(), ctx, tc_client, embd_inp);
949+
bool was_toolcall = call_tool(chat_format, assistant_ss.str(), ctx, tc_client, embd_inp);
945950
if (was_toolcall && last_len < embd_inp.size()) {
946951
LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str());
947952
}

0 commit comments

Comments
 (0)