@@ -143,39 +143,37 @@ class chat_formatter {
 };
 
 #ifdef LLAMA_USE_TOOLCALL
-static bool call_tool(common_params & params, const std::string & assistant_msg,
-                      llama_context * ctx, toolcall::client::ptr tc_client, std::vector<llama_token> & embd_inp)
+static bool call_tool(const std::string & assistant_msg, llama_context * ctx, toolcall::client::ptr tc_client, std::vector<llama_token> & embd_inp)
 {
-    auto should_use_toolcall = [&params, tc_client](const std::string & asst_msg) {
-        if (!params.use_jinja || tc_client == nullptr) {
-            return false;
-        }
-        try {
-            nlohmann::json j = nlohmann::json::parse(asst_msg);
-            return (j.contains("name") && j.contains("parameters"));
-
-        } catch (const nlohmann::json::exception & err) {
-            return false;
-        }
-    };
-
-    if (should_use_toolcall(assistant_msg)) {
-        toolcall::result_set res = tc_client->call(assistant_msg);
-        if (!res.empty()) {
-            std::string toolcall_result_str;
-            for (const auto & r : res) {
-                toolcall_result_str += ("\n" + r.data); // Although more complex results can be
-                                                        // returned (resources, images, etc.),
-                                                        // for now simply append the data. Later
-                                                        // on support for specific models may
-                                                        // allow for unpacking Base64 data.
+    bool tool_was_called = false;
+    common_chat_msg msg = common_chat_parse(assistant_msg, COMMON_CHAT_FORMAT_GENERIC);
+    if (!msg.tool_calls.empty()) {
+        for (const auto & tc : msg.tool_calls) {
+            nlohmann::json tc_oai_json {
+                {"type", "function"},
+                {"function", {
+                    {"name", tc.name},
+                    {"arguments", tc.arguments},
+                }},
+                {"id", tc.id},
+            };
+            toolcall::result_set res = tc_client->call(tc_oai_json);
+            if (!res.empty()) {
+                std::string toolcall_result_str;
+                for (const auto & r : res) {
+                    toolcall_result_str += ("\n" + r.data); // Although more complex results can be
+                                                            // returned (resources, images, etc.),
+                                                            // for now simply append the data. Later
+                                                            // on support for specific models may
+                                                            // allow for unpacking Base64 data.
+                }
+                auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true);
+                embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end());
             }
-            auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true);
-            embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end());
+            tool_was_called = true;
         }
-        return true;
     }
-    return false;
+    return tool_was_called;
 }
 #endif
 
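Note: for reference, a minimal sketch of the OAI-style envelope that call_tool now builds per parsed tool call before handing it to tc_client->call(); the free-standing helper make_oai_toolcall and the example values are hypothetical, the diff constructs the same JSON inline:

    // Hypothetical helper mirroring the tc_oai_json construction above.
    #include <nlohmann/json.hpp>
    #include <string>

    static nlohmann::json make_oai_toolcall(const std::string & name,
                                            const std::string & arguments,
                                            const std::string & id) {
        return {
            {"type", "function"},
            {"function", {
                {"name", name},           // tool/function to invoke
                {"arguments", arguments}, // arguments as a JSON-encoded string
            }},
            {"id", id},                   // correlates the result with this call
        };
    }

    // make_oai_toolcall("get_weather", "{\"city\":\"Oslo\"}", "call_0").dump() yields
    // (nlohmann::json orders keys alphabetically by default):
    // {"function":{"arguments":"{\"city\":\"Oslo\"}","name":"get_weather"},"id":"call_0","type":"function"}

This also replaces the old hand-rolled check for top-level "name"/"parameters" keys with common_chat_parse using the generic chat format, so detection and argument extraction share one code path.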
@@ -923,13 +921,8 @@ int main(int argc, char ** argv) {
 
             if (params.enable_chat_template) {
                 chat_add_and_format("assistant", assistant_ss.str(), true);
-#ifdef LLAMA_USE_TOOLCALL
-                is_interacting = !call_tool(params, assistant_ss.str(), ctx, tc_client, embd_inp);
-                LOG("\n");
-#else
                 is_interacting = true;
                 LOG("\n");
-#endif
             }
         }
     }
@@ -945,6 +938,16 @@ int main(int argc, char ** argv) {
             }
         }
 
+#ifdef LLAMA_USE_TOOLCALL
+        if ((tc_client && n_past > 0) && (waiting_for_first_input || is_interacting)) {
+            size_t last_len = embd_inp.size();
+            bool was_toolcall = call_tool(assistant_ss.str(), ctx, tc_client, embd_inp);
+            if (was_toolcall && last_len < embd_inp.size()) {
+                LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str());
+            }
+        }
+#endif
+
         if ((n_past > 0 || waiting_for_first_input) && is_interacting) {
             LOG_DBG("waiting for user input\n");
 
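For clarity, the gate added above as a standalone predicate (hypothetical helper; the diff writes the condition inline against main()'s locals):

    // Hypothetical predicate equivalent to the inline condition above: probe
    // for tool calls only when a toolcall client is configured, at least one
    // token has been evaluated, and generation has paused for input.
    static bool should_probe_toolcall(const toolcall::client::ptr & tc_client, int n_past,
                                      bool waiting_for_first_input, bool is_interacting) {
        return (tc_client && n_past > 0) && (waiting_for_first_input || is_interacting);
    }

Running the check here, just before waiting for user input, lets a single call site cover both the interactive and single-turn paths, which is what allows the two earlier #ifdef LLAMA_USE_TOOLCALL blocks to be removed.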
@@ -1049,16 +1052,6 @@ int main(int argc, char ** argv) {
             }
         }
 
-#ifdef LLAMA_USE_TOOLCALL
-        if (params.single_turn) {
-            size_t last_len = embd_inp.size();
-            bool was_toolcall = call_tool(params, assistant_ss.str(), ctx, tc_client, embd_inp);
-            if (was_toolcall && last_len < embd_inp.size()) {
-                LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str());
-            }
-        }
-#endif
-
         // end of generation
         if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) {
             LOG("[end of text]\n");