Skip to content

Commit ac1fc31

Browse files
committed
WIP: use common_chat_parse for toolcall
1 parent 46766c1 commit ac1fc31

File tree

1 file changed

+37
-44
lines changed

1 file changed

+37
-44
lines changed

examples/main/main.cpp

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -143,39 +143,37 @@ class chat_formatter {
143143
};
144144

145145
#ifdef LLAMA_USE_TOOLCALL
// Parse `assistant_msg` for tool calls (via common_chat_parse using the
// GENERIC chat format) and, for each parsed call, dispatch it through the
// tool-call client and append the tokenized tool output to `embd_inp` so it
// is fed back to the model on the next decode.
//
//   assistant_msg - the raw assistant message to scan for tool calls
//   ctx           - llama context used for tokenizing the tool results
//   tc_client     - tool-call client; may be null (treated as "no tools")
//   embd_inp      - token buffer the tool results are appended to (in-out)
//
// Returns true when at least one tool call was parsed and dispatched (even
// if a call produced an empty result set), false otherwise.
static bool call_tool(const std::string & assistant_msg, llama_context * ctx, toolcall::client::ptr tc_client, std::vector<llama_token> & embd_inp)
{
    // Defensive guard: the previous implementation checked for a null client
    // before doing anything; keep that invariant here rather than relying on
    // every caller to pre-check, since tc_client is dereferenced below.
    if (tc_client == nullptr) {
        return false;
    }

    bool tool_was_called = false;

    common_chat_msg msg = common_chat_parse(assistant_msg, COMMON_CHAT_FORMAT_GENERIC);
    for (const auto & tc : msg.tool_calls) {
        // Re-encode the parsed call into the OpenAI-style tool-call JSON
        // shape expected by toolcall::client::call.
        nlohmann::json tc_oai_json {
            {"type", "function"},
            {"function", {
                {"name", tc.name},
                {"arguments", tc.arguments},
            }},
            {"id", tc.id},
        };

        toolcall::result_set res = tc_client->call(tc_oai_json);
        if (! res.empty()) {
            std::string toolcall_result_str;
            for (const auto & r : res) {
                toolcall_result_str += ("\n" + r.data); // Although more complex results can be
                                                        // returned (resources, images, etc.),
                                                        // for now simply append the data. Later
                                                        // on support for specific models may
                                                        // allow for unpacking Base64 data.
            }
            auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true);
            embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end());
        }
        // A call was made even when the tool returned nothing.
        tool_was_called = true;
    }

    return tool_was_called;
}
#endif
181179

@@ -923,13 +921,8 @@ int main(int argc, char ** argv) {
923921

924922
if (params.enable_chat_template) {
925923
chat_add_and_format("assistant", assistant_ss.str(), true);
926-
#ifdef LLAMA_USE_TOOLCALL
927-
is_interacting = ! call_tool(params, assistant_ss.str(), ctx, tc_client, embd_inp);
928-
LOG("\n");
929-
#else
930924
is_interacting = true;
931925
LOG("\n");
932-
#endif
933926
}
934927
}
935928
}
@@ -945,6 +938,16 @@ int main(int argc, char ** argv) {
945938
}
946939
}
947940

941+
#ifdef LLAMA_USE_TOOLCALL
942+
if ((tc_client && n_past > 0) && (waiting_for_first_input || is_interacting)) {
943+
size_t last_len = embd_inp.size();
944+
bool was_toolcall = call_tool(assistant_ss.str(), ctx, tc_client, embd_inp);
945+
if (was_toolcall && last_len < embd_inp.size()) {
946+
LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str());
947+
}
948+
}
949+
#endif
950+
948951
if ((n_past > 0 || waiting_for_first_input) && is_interacting) {
949952
LOG_DBG("waiting for user input\n");
950953

@@ -1049,16 +1052,6 @@ int main(int argc, char ** argv) {
10491052
}
10501053
}
10511054

1052-
#ifdef LLAMA_USE_TOOLCALL
1053-
if (params.single_turn) {
1054-
size_t last_len = embd_inp.size();
1055-
bool was_toolcall = call_tool(params, assistant_ss.str(), ctx, tc_client, embd_inp);
1056-
if (was_toolcall && last_len < embd_inp.size()) {
1057-
LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str());
1058-
}
1059-
}
1060-
#endif
1061-
10621055
// end of generation
10631056
if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) {
10641057
LOG(" [end of text]\n");

0 commit comments

Comments
 (0)