Skip to content

Commit 66eff76

Browse files
committed
Move tool-call invocation into main loop
1 parent 0b52627 commit 66eff76

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

examples/main/main.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -111,21 +111,6 @@ class chat_formatter {
111111
new_msg.role = role;
112112
new_msg.content = content;
113113

114-
#ifdef LLAMA_USE_TOOLCALL
115-
if (params_.use_jinja && use_toolcalls) {
116-
if (tc_handler_ != nullptr) {
117-
if (nlohmann::json::accept(content)) { // Need a better way to know this is for a toolcall
118-
toolcall::result_set res = tc_handler_->call(content);
119-
std::string new_content;
120-
for (const auto & r : res) {
121-
new_content += (r.data + "\n");
122-
}
123-
new_msg.content = new_content; // TODO: this is not wiring correctly into the prompt
124-
}
125-
}
126-
}
127-
#endif
128-
129114
common_chat_params cparams;
130115
common_chat_templates_inputs cinputs;
131116
#ifdef LLAMA_USE_TOOLCALL
@@ -142,8 +127,9 @@ class chat_formatter {
142127
chat_msgs_.push_back(new_msg);
143128
LOG_DBG("formatted: '%s'\n", formatted.c_str());
144129

130+
#ifdef LLAMA_USE_TOOLCALL
145131
common_chat_grammar_to_sampler(&cparams, vocab_, &params_.sampling);
146-
132+
#endif
147133
return formatted;
148134
}
149135

@@ -859,8 +845,27 @@ int main(int argc, char ** argv) {
859845

860846
if (params.enable_chat_template) {
861847
chat_add_and_format("assistant", assistant_ss.str(), true);
848+
#ifdef LLAMA_USE_TOOLCALL
849+
if (! params.use_jinja || tc_handler == nullptr || ! nlohmann::json::accept(assistant_ss.str())) {
850+
is_interacting = true;
851+
LOG("\n");
852+
853+
} else {
854+
toolcall::result_set res = tc_handler->call(assistant_ss.str());
855+
if (! res.empty()) {
856+
std::string toolcall_result_str;
857+
for (const auto & r : res) {
858+
toolcall_result_str += (r.data + "\n");
859+
}
860+
auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true);
861+
embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end());
862+
}
863+
}
864+
#else
865+
862866
is_interacting = true;
863867
LOG("\n");
868+
#endif
864869
}
865870
}
866871
}

0 commit comments

Comments
 (0)