Skip to content

Commit 850e043

Browse files
committed
Refactor tool/call response
1 parent ce5c46c commit 850e043

File tree

3 files changed

+49
-55
lines changed

3 files changed

+49
-55
lines changed

examples/main/main.cpp

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,26 @@ class chat_formatter {
106106
#endif
107107

108108
std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) {
109+
109110
common_chat_msg new_msg;
110111
new_msg.role = role;
111112
new_msg.content = content;
112113

114+
#ifdef LLAMA_USE_TOOLCALL
115+
if (params_.use_jinja && use_toolcalls) {
116+
if (tc_handler_ != nullptr) {
117+
if (nlohmann::json::accept(content)) { // Need a better way to know this is for a toolcall
118+
toolcall::result_set res = tc_handler_->call(content);
119+
std::string new_content;
120+
for (const auto & r : res) {
121+
new_content += (r.data + "\n");
122+
}
123+
new_msg.content = new_content; // TODO: this is not wiring correctly into the prompt
124+
}
125+
}
126+
}
127+
#endif
128+
113129
common_chat_params cparams;
114130
common_chat_templates_inputs cinputs;
115131
#ifdef LLAMA_USE_TOOLCALL
@@ -126,21 +142,8 @@ class chat_formatter {
126142
chat_msgs_.push_back(new_msg);
127143
LOG_DBG("formatted: '%s'\n", formatted.c_str());
128144

129-
#ifdef LLAMA_USE_TOOLCALL
130-
if (params_.use_jinja && use_toolcalls) {
131-
common_chat_grammar_to_sampler(&cparams, vocab_, &params_.sampling);
132-
if (tc_handler_ != nullptr) {
133-
if (nlohmann::json::accept(formatted)) { // May need a better way to ensure
134-
std::string response; // this is intended for a tool-call.
135-
tc_handler_->call(formatted, response);
136-
return std::string(response);
145+
common_chat_grammar_to_sampler(&cparams, vocab_, &params_.sampling);
137146

138-
} else {
139-
return formatted;
140-
}
141-
}
142-
}
143-
#endif
144147
return formatted;
145148
}
146149

@@ -855,22 +858,9 @@ int main(int argc, char ** argv) {
855858
}
856859

857860
if (params.enable_chat_template) {
858-
#ifdef LLAMA_USE_TOOLCALL
859-
auto output = chat_add_and_format("assistant", assistant_ss.str(), true);
860-
if (tc_handler == nullptr || tc_handler->last_action() != toolcall::ACCEPT) {
861-
is_interacting = true;
862-
LOG("\n");
863-
864-
} else {
865-
LOG_DBG("tokenizing toolcall response");
866-
auto response = common_tokenize(ctx, output, false, true);
867-
embd_inp.insert(embd_inp.end(), response.begin(), response.end());
868-
}
869-
#else
870-
chat_add_and_format("assistant", assistant_ss.str());
861+
chat_add_and_format("assistant", assistant_ss.str(), true);
871862
is_interacting = true;
872863
LOG("\n");
873-
#endif
874864
}
875865
}
876866
}

toolcall/handler.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <stdexcept>
66

77
#ifdef LLAMA_USE_CURL
8-
# include "mcp_sse_transport.h"
8+
# include "mcp_sse_transport.h"
99
#endif
1010

1111
#include "mcp_stdio_transport.h"
@@ -39,19 +39,14 @@ bool toolcall::handler::tool_list_dirty() const {
3939
return impl_->tool_list_dirty();
4040
}
4141

42-
toolcall::action toolcall::handler::call(const std::string & request, std::string & response) {
43-
last_action_ = impl_->call(request, response);
44-
return last_action_;
42+
toolcall::result_set toolcall::handler::call(const std::string & request) {
43+
return impl_->call(request);
4544
}
4645

4746
const std::string & toolcall::handler::tool_choice() const {
4847
return impl_->tool_choice();
4948
}
5049

51-
toolcall::action toolcall::handler::last_action() const {
52-
return last_action_;
53-
}
54-
5550
void toolcall::handler::initialize() {
5651
impl_->initialize();
5752
}
@@ -194,26 +189,32 @@ static mcp::tools_call_request tools_call_request_from_local_json(nlohmann::json
194189
return mcp::tools_call_request(id, j["name"], args);
195190
}
196191

197-
static std::string tools_call_response_to_local_json(const mcp::tools_call_response & resp) {
198-
return resp.toJson().dump(-1); // The AI will figure it out?
192+
static toolcall::result_set tools_call_response_to_result(const mcp::tools_call_response & resp) {
193+
toolcall::result_set result;
194+
for (const auto & res : resp.tool_result()) {
195+
result.push_back(toolcall::result{
196+
res.type, res.value, res.mime_type.value_or("text/plain"), res.uri, resp.tool_error()
197+
});
198+
}
199+
return std::move(result);
199200
}
200201

201-
toolcall::action toolcall::mcp_impl::call(const std::string & request, std::string & response) {
202+
toolcall::result_set toolcall::mcp_impl::call(const std::string & request) {
202203
using on_response = toolcall::callback<mcp::tools_call_response>;
203204

204205
if (transport_ == nullptr) {
205-
return toolcall::DEFER;
206+
return toolcall::result_set();
206207
}
207208
std::unique_lock<std::mutex> lock(tools_mutex_);
208209

209-
response.clear();
210+
toolcall::result_set response;
210211
on_response set_response = [this, &response] (const mcp::tools_call_response & resp) {
211212
std::unique_lock<std::mutex> lock(tools_mutex_);
212-
response = tools_call_response_to_local_json(resp);
213+
response = tools_call_response_to_result(resp);
213214
tools_populating_.notify_one();
214215
};
215216
transport_->send(tools_call_request_from_local_json(next_id_++, request), set_response);
216217
tools_populating_.wait_for(lock, std::chrono::seconds(15), [&response] { return ! response.empty(); });
217218

218-
return toolcall::ACCEPT;
219+
return response;
219220
}

toolcall/toolcall-handler.h

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,34 @@
1010

1111
namespace toolcall
1212
{
13-
enum action {
14-
ACCEPT,
15-
PENDING,
16-
DEFER
13+
struct result {
14+
std::string type;
15+
std::string data;
16+
std::string mime_type;
17+
std::optional<std::string> uri;
18+
bool error;
1719
};
1820

21+
using result_set = std::vector<result>;
22+
1923
class handler_impl;
2024
class handler {
2125
public:
2226
using ptr = std::shared_ptr<handler>;
2327

2428
handler(std::unique_ptr<handler_impl> impl) : impl_(std::move(impl)) {}
2529

26-
action call(const std::string & request, std::string & response);
30+
result_set call(const std::string & request);
2731

2832
std::string tool_list();
2933
bool tool_list_dirty() const;
3034

3135
const std::string & tool_choice() const;
32-
action last_action() const;
3336

3437
void initialize();
3538

3639
private:
3740
std::unique_ptr<handler_impl> impl_;
38-
action last_action_;
3941
};
4042

4143
std::shared_ptr<toolcall::handler> create_handler(const toolcall::params & params);
@@ -53,7 +55,7 @@ namespace toolcall
5355
return tool_list_dirty_;
5456
}
5557

56-
virtual action call(const std::string & request, std::string & response) = 0;
58+
virtual result_set call(const std::string & request) = 0;
5759

5860
const std::string & tool_choice() const { return tool_choice_; }
5961

@@ -74,9 +76,10 @@ namespace toolcall
7476
return tools_;
7577
}
7678

77-
virtual action call(const std::string & request, std::string & response) override {
78-
response = request;
79-
return toolcall::DEFER;
79+
virtual result_set call(const std::string & request) override {
80+
return {
81+
{"text", request, "text/plain", std::nullopt, false}
82+
};
8083
}
8184

8285
private:
@@ -90,7 +93,7 @@ namespace toolcall
9093
mcp_impl(std::vector<std::string> argv, std::string tool_choice);
9194

9295
virtual std::string tool_list() override;
93-
virtual action call(const std::string & request, std::string & response) override;
96+
virtual result_set call(const std::string & request) override;
9497

9598
virtual void initialize() override;
9699

0 commit comments

Comments
 (0)