@@ -98,9 +98,10 @@ class chat_formatter {
                     std::vector<common_chat_msg> & chat_msgs,
                     struct common_chat_templates * chat_templates,
                     const llama_vocab * vocab,
-                    toolcall::client::ptr tc_client)
+                    toolcall::client::ptr tc_client,
+                    common_chat_format * chat_format)
 
-        : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client) {}
+        : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client), chat_format_(chat_format) {}
 #endif
 
     std::string operator()(const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) {
@@ -126,6 +127,7 @@ class chat_formatter {
         LOG_DBG("formatted: '%s'\n", formatted.c_str());
 
 #ifdef LLAMA_USE_TOOLCALL
+        if (chat_format_) *chat_format_ = cparams.format;
         common_chat_grammar_to_sampler(&cparams, vocab_, &params_.sampling);
 #endif
         return formatted;
@@ -139,14 +141,16 @@ class chat_formatter {
 #ifdef LLAMA_USE_TOOLCALL
     const llama_vocab * vocab_;
     toolcall::client::ptr tc_client_;
+    common_chat_format * chat_format_;
 #endif
 };
 
 #ifdef LLAMA_USE_TOOLCALL
-static bool call_tool(const std::string & assistant_msg, llama_context * ctx, toolcall::client::ptr tc_client, std::vector<llama_token> & embd_inp)
+static bool call_tool(common_chat_format chat_format, const std::string & assistant_msg, llama_context * ctx,
+                      toolcall::client::ptr tc_client, std::vector<llama_token> & embd_inp)
 {
     bool tool_was_called = false;
-    common_chat_msg msg = common_chat_parse(assistant_msg, COMMON_CHAT_FORMAT_GENERIC);
+    common_chat_msg msg = common_chat_parse(assistant_msg, chat_format);
     if (!msg.tool_calls.empty()) {
         for (const auto & tc : msg.tool_calls) {
             nlohmann::json tc_oai_json {
@@ -371,7 +375,8 @@ int main(int argc, char ** argv) {
     if (tc_client) {
         tc_client->initialize();
     }
-    chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client);
+    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client, &chat_format);
 #else
     chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get());
 #endif
@@ -941,7 +946,7 @@ int main(int argc, char ** argv) {
 #ifdef LLAMA_USE_TOOLCALL
             if ((tc_client && n_past > 0) && (waiting_for_first_input || is_interacting)) {
                 size_t last_len = embd_inp.size();
-                bool was_toolcall = call_tool(assistant_ss.str(), ctx, tc_client, embd_inp);
+                bool was_toolcall = call_tool(chat_format, assistant_ss.str(), ctx, tc_client, embd_inp);
                 if (was_toolcall && last_len < embd_inp.size()) {
                     LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str());
                 }
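
The pattern the diff introduces (capture the format chosen while the chat template is applied, then reuse it when parsing the assistant's reply for tool calls) can be sketched in isolation as follows. This is only an illustrative sketch that assumes the same common/chat.h API the diff itself relies on (a cparams struct carrying a format field, and the two-argument common_chat_parse); the header path and the helper name parse_with_detected_format are hypothetical.

// Illustrative sketch only; mirrors the diff's capture-and-reuse of the chat format.
#include "chat.h"   // common_chat_format, common_chat_msg, common_chat_parse (assumed header)
#include <string>

static common_chat_msg parse_with_detected_format(const common_chat_params & cparams,
                                                  const std::string & assistant_msg) {
    // Default mirrors main(): content-only until a template has been applied.
    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

    // Applying the template (done elsewhere) fills cparams; record its format so
    // model-specific tool-call syntax is parsed with the matching parser instead
    // of COMMON_CHAT_FORMAT_GENERIC.
    chat_format = cparams.format;

    return common_chat_parse(assistant_msg, chat_format);
}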