Skip to content

Commit 3e46978

Browse files
committed
Impl. initialize and tool_list routines
1 parent 5c0b0cb commit 3e46978

File tree

5 files changed

+156
-17
lines changed

5 files changed

+156
-17
lines changed

toolcall/handler.cpp

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11

22
#include <json.hpp>
33
#include "toolcall-handler.h"
4+
#include <chrono>
5+
#include <stdexcept>
46

57
#ifdef LLAMA_USE_CURL
6-
# include "mcp_sse_transport.h"
8+
# include "mcp_sse_transport.h"
79
#endif
810

911
#include "mcp_stdio_transport.h"
@@ -18,10 +20,12 @@ std::shared_ptr<toolcall::handler> toolcall::create_handler(const toolcall::para
1820
if (params) {
1921
if (params.has_uri()) {
2022
#ifdef LLAMA_USE_CURL
21-
handler.reset(new toolcall::handler(std::make_unique<toolcall::mcp_impl>(tools, choice)));
23+
handler.reset(new toolcall::handler(
24+
std::make_unique<toolcall::mcp_impl>(tools, choice)));
2225
#endif
2326
} else {
24-
handler.reset(new toolcall::handler(std::make_unique<toolcall::loopback_impl>(tools, choice)));
27+
handler.reset(new toolcall::handler(
28+
std::make_unique<toolcall::loopback_impl>(tools, choice)));
2529
}
2630
}
2731
return handler;
@@ -31,6 +35,10 @@ std::string toolcall::handler::tool_list() {
3135
return impl_->tool_list();
3236
}
3337

38+
bool toolcall::handler::tool_list_dirty() const {
39+
return impl_->tool_list_dirty();
40+
}
41+
3442
toolcall::action toolcall::handler::call(const std::string & request, std::string & response) {
3543
last_action_ = impl_->call(request, response);
3644
return last_action_;
@@ -39,20 +47,33 @@ toolcall::action toolcall::handler::call(const std::string & request, std::strin
3947
const std::string & toolcall::handler::tool_choice() const {
4048
return impl_->tool_choice();
4149
}
50+
4251
toolcall::action toolcall::handler::last_action() const {
4352
return last_action_;
4453
}
4554

55+
void toolcall::handler::initialize() {
56+
impl_->initialize();
57+
}
58+
4659
#ifdef LLAMA_USE_CURL
4760
toolcall::mcp_impl::mcp_impl(std::string server_uri, std::string tool_choice)
4861
: handler_impl(tool_choice),
49-
transport_(new mcp_sse_transport(server_uri))
62+
transport_(new mcp_sse_transport(server_uri)),
63+
tools_("[]"),
64+
tools_mutex_(),
65+
tools_populating_(),
66+
next_id_(1)
5067
{
51-
transport_->start();
5268
}
5369
#else
5470
toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, std::string tool_choice)
55-
: handler_impl(tool_choice)
71+
: handler_impl(tool_choice),
72+
transport_(nullptr),
73+
tools_("[]"),
74+
tools_mutex_(),
75+
tools_populating_(),
76+
next_id_(1)
5677
{
5778
}
5879
#endif
@@ -61,15 +82,94 @@ toolcall::mcp_impl::mcp_impl(std::vector<std::string> argv, std::string tool_cho
6182
: handler_impl(tool_choice),
6283
transport_(new mcp_stdio_transport(argv))
6384
{
85+
}
86+
87+
void toolcall::mcp_impl::initialize() {
88+
using on_response = toolcall::callback<mcp::initialize_response>;
89+
using on_list_changed = toolcall::callback<mcp::tools_list_changed_notification>;
90+
91+
if (transport_ == nullptr) return;
92+
std::unique_lock<std::mutex> lock(tools_mutex_);
93+
6494
transport_->start();
95+
96+
mcp::capabilities caps;
97+
on_response set_caps = [this, &caps] (const mcp::initialize_response & resp) {
98+
std::unique_lock<std::mutex> lock(tools_mutex_);
99+
caps = resp.capabilities();
100+
tools_populating_.notify_one();
101+
};
102+
103+
transport_->subscribe(set_caps);
104+
105+
mcp::initialize_request req(next_id_++);
106+
transport_->send(req.toJson());
107+
108+
tools_populating_.wait_for(lock, std::chrono::seconds(15));
109+
transport_->unsubscribe(set_caps);
110+
111+
on_list_changed update_dirty = [this] (const mcp::tools_list_changed_notification &) {
112+
tool_list_dirty_ = true;
113+
};
114+
115+
bool has_tools = false;
116+
for (const auto & cap : caps) {
117+
if (cap.name == "tools") {
118+
has_tools = true;
119+
if (cap.listChanged) {
120+
transport_->subscribe(update_dirty);
121+
}
122+
break;
123+
}
124+
}
125+
if (! has_tools) {
126+
throw std::runtime_error("MCP server does not support toolcalls!");
127+
}
128+
}
129+
130+
static std::string tools_list_to_oai_json(const mcp::tools_list & tools) {
131+
return "[]"; // TODO
65132
}
66133

67134
std::string toolcall::mcp_impl::tool_list() {
68-
// Construct tools/list call and send to transport
69-
return "[]";// TODO
135+
using on_response = toolcall::callback<mcp::tools_list_response>;
136+
137+
if (tool_list_dirty_) {
138+
std::unique_lock<std::mutex> lock(tools_mutex_);
139+
140+
mcp::tools_list tools;
141+
on_response set_tools = [this, &tools] (const mcp::tools_list_response & resp) {
142+
std::unique_lock<std::mutex> lock(tools_mutex_);
143+
144+
tools.insert(tools.end(), resp.tools().begin(), resp.tools().end());
145+
auto cursor = resp.next_cursor();
146+
if (! cursor.empty()) {
147+
mcp::tools_list_request req(std::to_string(next_id_++), cursor);
148+
transport_->send(req.toJson());
149+
return;
150+
}
151+
tool_list_dirty_ = false;
152+
lock.unlock();
153+
tools_populating_.notify_one();
154+
};
155+
156+
transport_->subscribe(set_tools);
157+
158+
mcp::tools_list_request req(std::to_string(next_id_++));
159+
transport_->send(req.toJson());
160+
161+
tools_populating_.wait_for(lock, std::chrono::seconds(15));
162+
transport_->unsubscribe(set_tools);
163+
164+
tools_ = tools_list_to_oai_json(tools);
165+
}
166+
return tools_;
70167
}
71168

72169
toolcall::action toolcall::mcp_impl::call(const std::string & /*request*/, std::string & /*response*/) {
170+
if (transport_ == nullptr) {
171+
return toolcall::DEFER;
172+
}
73173
// Construct tool call and send to transport
74174
return toolcall::ACCEPT; // TODO
75175
}

toolcall/mcp_messages.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ namespace mcp
118118

119119
class initialize_request : public request {
120120
public:
121-
initialize_request(nlohmann::json id, mcp::capabilities caps);
121+
initialize_request(nlohmann::json id, mcp::capabilities caps = mcp::capabilities{});
122122

123123
const std::string & name() const { return ClientName; }
124124
const std::string & version() const { return ClientVersion; }
@@ -206,21 +206,28 @@ namespace mcp
206206
const tools_list & tools() const { return tools_; }
207207

208208
void next_cursor(std::string next_cursor);
209-
const std::string & next_cursor() { return next_cursor_; }
209+
const std::string & next_cursor() const { return next_cursor_; }
210210

211211
private:
212212
void refreshResult();
213213
tools_list tools_;
214214
std::string next_cursor_;
215215
};
216216

217+
class tools_list_changed_notification : public notification {
218+
public:
219+
tools_list_changed_notification()
220+
: notification("notifications/tools/list_changed") {}
221+
};
222+
217223
using message_variant =
218224
std::variant<std::monostate,
219225
initialize_request,
220226
initialize_response,
221227
initialized_notification,
222228
tools_list_request,
223-
tools_list_response>;
229+
tools_list_response,
230+
tools_list_changed_notification>;
224231

225232
bool create_message(const std::string & data, message_variant & message);
226233
}

toolcall/mcp_sse_transport.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11

2-
#include <iostream>
3-
#include <sstream>
42
#include "mcp_sse_transport.h"
53
#include <log.h>
64
#include <chrono>
@@ -73,6 +71,8 @@ static size_t sse_callback(char * data, size_t size, size_t nmemb, void * client
7371
}
7472

7573
void toolcall::mcp_sse_transport::parse_field_value(std::string field, std::string value) {
74+
LOG_DBG("SSE: field \"%s\"; value \"%s\"", field.c_str(), value.c_str());
75+
7676
if (field == "event") {
7777
// Set the event type buffer to field value.
7878
event_.type = std::move(value);

toolcall/mcp_transport.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ namespace toolcall
2020
vec.push_back(std::move(callback));
2121
}
2222

23+
template <typename T>
24+
void unsubscribe(callback<T> callback) {
25+
auto& vec = std::get<std::vector<toolcall::callback<T>>>(subscribers_);
26+
auto found = std::find(vec.begin(), vec.end(), callback);
27+
if (found != vec.end()) {
28+
vec.erase(found);
29+
}
30+
}
31+
2332
template <typename T>
2433
void notify(const T & message) const {
2534
const auto& vec = std::get<std::vector<toolcall::callback<T>>>(subscribers_);
@@ -49,8 +58,8 @@ namespace toolcall
4958
mcp::initialize_response,
5059
mcp::initialized_notification,
5160
mcp::tools_list_request,
52-
mcp::tools_list_response>
53-
{
61+
mcp::tools_list_response,
62+
mcp::tools_list_changed_notification> {
5463
public:
5564
virtual ~mcp_transport() = default;
5665
virtual void start() = 0;

toolcall/toolcall-handler.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <variant>
66
#include <memory>
77
#include <vector>
8+
#include <condition_variable>
9+
#include <mutex>
810

911
namespace toolcall
1012
{
@@ -21,11 +23,16 @@ namespace toolcall
2123

2224
handler(std::unique_ptr<handler_impl> impl) : impl_(std::move(impl)) {}
2325

24-
std::string tool_list();
2526
action call(const std::string & request, std::string & response);
27+
28+
std::string tool_list();
29+
bool tool_list_dirty() const;
30+
2631
const std::string & tool_choice() const;
2732
action last_action() const;
2833

34+
void initialize();
35+
2936
private:
3037
std::unique_ptr<handler_impl> impl_;
3138
action last_action_;
@@ -36,16 +43,25 @@ namespace toolcall
3643
class handler_impl {
3744
public:
3845
handler_impl(std::string tool_choice)
39-
: tool_choice_(std::move(tool_choice)) {}
46+
: tool_choice_(std::move(tool_choice)), tool_list_dirty_(true) {}
4047

4148
virtual ~handler_impl() = default;
49+
4250
virtual std::string tool_list() = 0;
51+
52+
virtual bool tool_list_dirty() const {
53+
return tool_list_dirty_;
54+
}
55+
4356
virtual action call(const std::string & request, std::string & response) = 0;
4457

4558
const std::string & tool_choice() const { return tool_choice_; }
4659

60+
virtual void initialize() {}
61+
4762
protected:
4863
std::string tool_choice_;
64+
bool tool_list_dirty_;
4965
};
5066

5167
class loopback_impl : public handler_impl {
@@ -54,6 +70,7 @@ namespace toolcall
5470
: handler_impl(tool_choice), tools_(std::move(tools)) {}
5571

5672
virtual std::string tool_list() override {
73+
tool_list_dirty_ = false;
5774
return tools_;
5875
}
5976

@@ -75,7 +92,13 @@ namespace toolcall
7592
virtual std::string tool_list() override;
7693
virtual action call(const std::string & request, std::string & response) override;
7794

95+
virtual void initialize() override;
96+
7897
private:
7998
std::unique_ptr<mcp_transport> transport_;
99+
std::string tools_;
100+
std::mutex tools_mutex_;
101+
std::condition_variable tools_populating_;
102+
int next_id_;
80103
};
81104
}

0 commit comments

Comments
 (0)