Skip to content

Commit 606993d

Browse files
committed
Move transport message-dispatch to base type
1 parent 52b98d8 commit 606993d

File tree

5 files changed

+56
-76
lines changed

5 files changed

+56
-76
lines changed

toolcall/handler.cpp

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,24 +102,20 @@ void toolcall::mcp_impl::initialize() {
102102
tools_populating_.notify_one();
103103
};
104104

105-
transport_->subscribe("set_caps", set_caps);
106-
107-
mcp::initialize_request req(next_id_++);
108-
transport_->send(req);
109-
105+
transport_->send(mcp::initialize_request(next_id_++), set_caps);
110106
tools_populating_.wait_for(lock, std::chrono::seconds(15), [&caps_received] { return caps_received; });
111-
transport_->unsubscribe<mcp::initialize_response>("set_caps");
112107

113-
on_list_changed update_dirty = [this] (const mcp::tools_list_changed_notification &) {
108+
on_list_changed update_dirty = [&update_dirty, this] (const mcp::tools_list_changed_notification &) {
114109
tool_list_dirty_ = true;
110+
transport_->subscribe("notifications/tools/list_changed", update_dirty);
115111
};
116112

117113
bool has_tools = false;
118114
for (const auto & cap : caps) {
119115
if (cap.name == "tools") {
120116
has_tools = true;
121117
if (cap.listChanged) {
122-
transport_->subscribe("update_dirty", update_dirty);
118+
transport_->subscribe("notifications/tools/list_changed", update_dirty);
123119
}
124120
break;
125121
}
@@ -167,28 +163,22 @@ std::string toolcall::mcp_impl::tool_list() {
167163
std::unique_lock<std::mutex> lock(tools_mutex_);
168164

169165
mcp::tools_list tools;
170-
on_response set_tools = [this, &tools] (const mcp::tools_list_response & resp) {
166+
on_response set_tools = [this, &tools, &set_tools] (const mcp::tools_list_response & resp) {
171167
std::unique_lock<std::mutex> lock(tools_mutex_);
172168

173169
tools.insert(tools.end(), resp.tools().begin(), resp.tools().end());
174170
auto cursor = resp.next_cursor();
175171
if (! cursor.empty()) {
176-
mcp::tools_list_request req(next_id_++, cursor);
177-
transport_->send(req);
172+
transport_->send(mcp::tools_list_request(next_id_++, cursor), set_tools);
178173
return;
179174
}
180175
tool_list_dirty_ = false;
181176
lock.unlock();
182177
tools_populating_.notify_one();
183178
};
184179

185-
transport_->subscribe("set_tools", set_tools);
186-
187-
mcp::tools_list_request req(next_id_++);
188-
transport_->send(req);
189-
180+
transport_->send(mcp::tools_list_request(next_id_++), set_tools);
190181
tools_populating_.wait_for(lock, std::chrono::seconds(15), [this] { return ! tool_list_dirty_; });
191-
transport_->unsubscribe<mcp::tools_list_response>("set_tools");
192182

193183
tools_ = tools_list_to_oai_json(tools);
194184
}

toolcall/mcp_messages.cpp

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "mcp_messages.h"
22
#include <iostream>
33
#include <log.h>
4+
#include <stdexcept>
45

56
using json = nlohmann::json;
67

@@ -280,6 +281,14 @@ mcp::tools_list_response mcp::tools_list_response::fromJson(const nlohmann::json
280281
return tools_list_response(j["id"], std::move(tools), next_cursor);
281282
}
282283

284+
mcp::tools_list_changed_notification mcp::tools_list_changed_notification::fromJson(const nlohmann::json & j) {
285+
if (! (j.is_object() && j.contains("method") &&
286+
j["method"] == "notifications/tools/list_changed")) {
287+
throw std::invalid_argument("Invalid tools_list_changed message");
288+
}
289+
return tools_list_changed_notification();
290+
}
291+
283292
mcp::tools_call_request::tools_call_request(nlohmann::json id, std::string name, tool_arg_list args)
284293
: request(id, "tools/call"), name_(std::move(name)), args_(std::move(args))
285294
{
@@ -355,27 +364,3 @@ void mcp::tools_call_response::refreshResult() {
355364
result["content"] = content;
356365
this->result(std::move(result));
357366
}
358-
359-
static bool has_initialized_response(const nlohmann::json & data) {
360-
return data["result"].contains("capabilities");
361-
}
362-
363-
static bool has_tools_list_response(const nlohmann::json & data) {
364-
return data["result"].contains("tools");
365-
}
366-
367-
bool mcp::create_message(const std::string & data, mcp::message_variant & message) {
368-
json j = json::parse(data);
369-
370-
if (has_initialized_response(j)) {
371-
message = mcp::initialize_response::fromJson(j);
372-
373-
} else if (has_tools_list_response(j)) {
374-
message = mcp::tools_list_response::fromJson(j);
375-
376-
} else {
377-
message = std::monostate();
378-
return false;
379-
}
380-
return true;
381-
}

toolcall/mcp_messages.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ namespace mcp
220220
public:
221221
tools_list_changed_notification()
222222
: notification("notifications/tools/list_changed") {}
223+
224+
static tools_list_changed_notification fromJson(const nlohmann::json & j);
223225
};
224226

225227
struct tool_arg {
@@ -270,16 +272,4 @@ namespace mcp
270272
bool error_;
271273
};
272274

273-
using message_variant =
274-
std::variant<std::monostate,
275-
initialize_request,
276-
initialize_response,
277-
initialized_notification,
278-
tools_list_request,
279-
tools_list_response,
280-
tools_list_changed_notification,
281-
tools_call_request,
282-
tools_call_response>;
283-
284-
bool create_message(const std::string & data, message_variant & message);
285275
}

toolcall/mcp_sse_transport.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,12 @@ void toolcall::mcp_sse_transport::on_endpoint_event() {
162162
}
163163

164164
void toolcall::mcp_sse_transport::on_message_event() {
165-
mcp::message_variant message;
166-
if (mcp::create_message(event_.data, message)) {
167-
notify_if<mcp::initialize_response>(message);
168-
notify_if<mcp::tools_list_response>(message);
165+
try {
166+
nlohmann::json message = nlohmann::json::parse(event_.data);
167+
notify(message);
168+
169+
} catch (const nlohmann::json::exception & err) {
170+
LOG_WRN("SSE: Invalid message \"%s\" received: \"%s\"\n", event_.data.c_str(), err.what());
169171
}
170172
}
171173

toolcall/mcp_transport.h

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,39 +31,52 @@ namespace toolcall
3131
map.erase(key);
3232
}
3333

34-
template <typename T>
35-
void notify(const T & message) const {
36-
const auto& map =
37-
std::get<std::map<std::string, toolcall::callback<T>>>(
38-
subscribers_);
34+
void notify(const nlohmann::json & message) {
35+
std::string key;
36+
if (message.contains("id")) {
37+
key = message["id"].dump();
3938

40-
for (const auto & pair : map) {
41-
pair.second(message);
42-
}
43-
}
39+
} else if (message.contains("method")) {
40+
key = message["method"].dump();
4441

45-
template <typename T>
46-
void notify_if(const mcp::message_variant & message) {
47-
if (std::holds_alternative<T>(message)) {
48-
notify(std::get<T>(message));
42+
} else {
43+
return;
4944
}
45+
std::apply([&key, &message, this](auto&... maps) {
46+
(..., [&] {
47+
auto it = maps.find(key);
48+
if (it != maps.end()) {
49+
using callback_type = decltype(it->second);
50+
using T = typename std::decay<typename callback_type::argument_type>::type;
51+
52+
it->second(T::fromJson(message));
53+
maps.erase(it);
54+
}
55+
}());
56+
}, subscribers_);
5057
}
5158

5259
private:
5360
std::tuple<std::map<std::string, toolcall::callback<MessageTypes>>...> subscribers_;
5461
};
5562

56-
class mcp_transport : public mcp_message_observer<mcp::initialize_request,
57-
mcp::initialize_response,
58-
mcp::initialized_notification,
59-
mcp::tools_list_request,
63+
class mcp_transport : public mcp_message_observer<mcp::initialize_response,
6064
mcp::tools_list_response,
6165
mcp::tools_list_changed_notification> {
6266
public:
6367
virtual ~mcp_transport() = default;
6468

65-
template <typename T>
66-
bool send(const T & message) {
69+
template <typename Req, typename Resp>
70+
bool send(const Req & message, callback<Resp> on_response) {
71+
if (message.id().has_value()) {
72+
std::string id = message.id().value().dump();
73+
subscribe(id, on_response);
74+
}
75+
return send(message);
76+
}
77+
78+
template <typename Req>
79+
bool send(const Req & message) {
6780
nlohmann::json json = message.toJson();
6881
return send(json.dump(-1));
6982
}

0 commit comments

Comments
 (0)