|
1 | 1 |
|
| 2 | +#include <iostream> |
| 3 | +#include <sstream> |
2 | 4 | #include "mcp_sse_transport.hpp" |
| 5 | +#include <log.h> |
| 6 | +#include <chrono> |
| 7 | + |
| 8 | +toolcall::mcp_sse_transport::~mcp_sse_transport() { |
| 9 | + if (endpoint_headers_) { |
| 10 | + curl_slist_free_all(endpoint_headers_); |
| 11 | + } |
| 12 | + if (endpoint_) { |
| 13 | + curl_easy_cleanup(endpoint_); |
| 14 | + } |
| 15 | +} |
3 | 16 |
|
4 | 17 | toolcall::mcp_sse_transport::mcp_sse_transport(std::string server_uri) |
5 | | - : server_uri_(std::move(server_uri)) |
| 18 | + : server_uri_(std::move(server_uri)), |
| 19 | + running_(false), |
| 20 | + sse_thread_(), |
| 21 | + endpoint_(nullptr), |
| 22 | + endpoint_headers_(nullptr), |
| 23 | + endpoint_errbuf_(CURL_ERROR_SIZE), |
| 24 | + event_{"", "", ""}, |
| 25 | + sse_buffer_(""), |
| 26 | + sse_cursor_(0), |
| 27 | + sse_last_id_(""), |
| 28 | + initializing_mutex_(), |
| 29 | + initializing_() |
6 | 30 | { |
| 31 | + curl_global_init(CURL_GLOBAL_DEFAULT); |
7 | 32 | } |
8 | 33 |
|
9 | 34 | void toolcall::mcp_sse_transport::start() { |
| 35 | + if (running_) return; |
| 36 | + running_ = true; |
| 37 | + |
| 38 | + std::unique_lock<std::mutex> lock(initializing_mutex_); |
| 39 | + sse_thread_ = std::thread(&toolcall::mcp_sse_transport::sse_run, this); |
| 40 | + initializing_.wait(lock); |
| 41 | + |
| 42 | + if (endpoint_ == nullptr) { |
| 43 | + running_ = false; |
| 44 | + LOG_ERR("SSE: Connection to \"%s\" failed", server_uri_.c_str()); |
| 45 | + throw std::runtime_error("Connection to \"" + server_uri_ + "\" failed"); |
| 46 | + } |
10 | 47 | } |
11 | 48 |
|
12 | 49 | void toolcall::mcp_sse_transport::stop() { |
| 50 | + running_ = false; |
| 51 | +} |
| 52 | + |
| 53 | +bool toolcall::mcp_sse_transport::send(const std::string & request_json) { |
| 54 | + if (! running_ || endpoint_ == nullptr) { |
| 55 | + return false; |
| 56 | + } |
| 57 | + |
| 58 | + curl_easy_setopt(endpoint_, CURLOPT_POSTFIELDS, request_json.c_str()); |
| 59 | + |
| 60 | + CURLcode code = curl_easy_perform(endpoint_); |
| 61 | + if (code != CURLE_OK) { |
| 62 | + size_t len = strlen(&endpoint_errbuf_[0]); |
| 63 | + LOG_ERR("%s", (len > 0 ? &endpoint_errbuf_[0] : curl_easy_strerror(code))); |
| 64 | + return false; |
| 65 | + } |
| 66 | + return true; |
| 67 | +} |
| 68 | + |
| 69 | +static size_t sse_callback(char * data, size_t size, size_t nmemb, void * clientp) { |
| 70 | + auto transport = static_cast<toolcall::mcp_sse_transport*>(clientp); |
| 71 | + size_t len = size * nmemb; |
| 72 | + return transport->sse_read(data, len); |
13 | 73 | } |
14 | 74 |
|
15 | | -bool toolcall::mcp_sse_transport::send(const mcp::message_variant & /*request*/) { |
16 | | - return false; |
| 75 | +void toolcall::mcp_sse_transport::parse_field_value(std::string field, std::string value) { |
| 76 | + if (field == "event") { |
| 77 | + // Set the event type buffer to field value. |
| 78 | + event_.type = std::move(value); |
| 79 | + |
| 80 | + } else if (field == "data") { |
| 81 | + // Append the field value to the data buffer, |
| 82 | + // then append a single U+000A LINE FEED (LF) |
| 83 | + // character to the data buffer. |
| 84 | + value += '\n'; |
| 85 | + event_.data.insert(event_.data.end(), value.begin(), value.end()); |
| 86 | + |
| 87 | + } else if (field == "id") { |
| 88 | + // If the field value does not contain U+0000 NULL, |
| 89 | + // then set the last event ID buffer to the field value. |
| 90 | + // Otherwise, ignore the field. |
| 91 | + if (! value.empty()) { |
| 92 | + event_.id = std::move(value); |
| 93 | + } |
| 94 | + |
| 95 | + } else if (field == "retry") { |
| 96 | + // If the field value consists of only ASCII digits, |
| 97 | + // then interpret the field value as an integer in base |
| 98 | + // ten, and set the event stream's reconnection time to |
| 99 | + // that integer. Otherwise, ignore the field. |
| 100 | + |
| 101 | + LOG_INF("SSE: Retry field is not currently implemented"); |
| 102 | + |
| 103 | + } else { |
| 104 | + LOG_WRN("SSE: Unsupported field \"%s\" received", field.c_str()); |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +void toolcall::mcp_sse_transport::on_endpoint_event() { |
| 109 | + endpoint_ = curl_easy_init(); |
| 110 | + if (! endpoint_) { |
| 111 | + LOG_ERR("SSE: Failed to create endpoint handle"); |
| 112 | + running_ = false; |
| 113 | + return; |
| 114 | + } |
| 115 | + |
| 116 | + curl_easy_setopt(endpoint_, CURLOPT_URL, event_.data.c_str()); |
| 117 | + |
| 118 | + endpoint_headers_ = |
| 119 | + curl_slist_append(endpoint_headers_, "Content-Type: application/json"); |
| 120 | + curl_slist_append(endpoint_headers_, "Connection: keep-alive"); |
| 121 | + curl_easy_setopt(endpoint_, CURLOPT_HTTPHEADER, endpoint_headers_); |
| 122 | + curl_easy_setopt(endpoint_, CURLOPT_ERRORBUFFER, &endpoint_errbuf_[0]); |
| 123 | + |
| 124 | + // Later calls to send will reuse the endpoint_ handle |
| 125 | +} |
| 126 | + |
| 127 | +void toolcall::mcp_sse_transport::on_message_event() { |
| 128 | + mcp::message_variant message; |
| 129 | + if (mcp::create_message(event_.data, message)) { |
| 130 | + notify_if<mcp::initialize_response>(message); |
| 131 | + notify_if<mcp::tools_list_response>(message); |
| 132 | + } |
| 133 | +} |
| 134 | + |
| 135 | +size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { |
| 136 | + sse_buffer_.insert(sse_buffer_.end(), data, data + len); |
| 137 | + |
| 138 | + for (; sse_cursor_ < sse_buffer_.length(); ++sse_cursor_) { |
| 139 | + if (sse_buffer_[sse_cursor_] == '\r' || sse_buffer_[sse_cursor_] == '\n') { |
| 140 | + auto last = sse_buffer_.begin() + sse_cursor_; |
| 141 | + |
| 142 | + std::string line(sse_buffer_.begin(), last); |
| 143 | + if (line.empty()) { // Dispatch event |
| 144 | + if (event_.type == "endpoint") { |
| 145 | + on_endpoint_event(); |
| 146 | + |
| 147 | + } else if(event_.type == "message") { |
| 148 | + on_message_event(); |
| 149 | + |
| 150 | + } else { |
| 151 | + LOG_WRN("SSE: Unsupported event \"%s\" received", event_.type.c_str()); |
| 152 | + } |
| 153 | + |
| 154 | + sse_last_id_ = event_.id; |
| 155 | + event_ = {"", "", ""}; |
| 156 | + |
| 157 | + } else if(line[0] != ':') { // : denotes a comment |
| 158 | + // Set field/value |
| 159 | + auto sep_index = line.find(':'); |
| 160 | + if (sep_index != std::string::npos) { |
| 161 | + auto sep_i = line.begin() + sep_index; |
| 162 | + |
| 163 | + std::string field (line.begin(), sep_i); |
| 164 | + std::string value (sep_i + 1, line.end()); |
| 165 | + |
| 166 | + parse_field_value(std::move(field), std::move(value)); |
| 167 | + } |
| 168 | + } |
| 169 | + |
| 170 | + if (last++ != sse_buffer_.end()) { // Consume line-end |
| 171 | + if (*last == '\n') { |
| 172 | + last ++; // In the CRLF case consume one more |
| 173 | + } |
| 174 | + sse_buffer_ = std::string(last, sse_buffer_.end()); |
| 175 | + |
| 176 | + } else { |
| 177 | + sse_buffer_.clear(); |
| 178 | + } |
| 179 | + sse_cursor_ = 0; // Prepare to scan for next line-end |
| 180 | + } |
| 181 | + } |
| 182 | + return len; |
| 183 | +} |
| 184 | + |
| 185 | +void toolcall::mcp_sse_transport::sse_run() { |
| 186 | + std::unique_lock<std::mutex> lock(initializing_mutex_); |
| 187 | + char errbuf[CURL_ERROR_SIZE]; |
| 188 | + size_t errlen; |
| 189 | + CURLMcode mcode; |
| 190 | + int num_handles; |
| 191 | + struct CURLMsg * m; |
| 192 | + int msgs_in_queue = 0; |
| 193 | + CURLM * async_handle = nullptr; |
| 194 | + struct curl_slist * headers = nullptr; |
| 195 | + CURL * sse = nullptr; |
| 196 | + |
| 197 | + sse = curl_easy_init(); |
| 198 | + if (! sse) { |
| 199 | + LOG_ERR("SSE: Failed to initialize handle"); |
| 200 | + goto cleanup; |
| 201 | + } |
| 202 | + |
| 203 | + headers = curl_slist_append(headers, "Connection: keep-alive"); |
| 204 | + |
| 205 | + curl_easy_setopt(sse, CURLOPT_HTTPHEADER, headers); |
| 206 | + curl_easy_setopt(sse, CURLOPT_ERRORBUFFER, errbuf); |
| 207 | + curl_easy_setopt(sse, CURLOPT_URL, server_uri_.c_str()); |
| 208 | + curl_easy_setopt(sse, CURLOPT_TCP_KEEPALIVE, 1L); |
| 209 | + curl_easy_setopt(sse, CURLOPT_WRITEFUNCTION, sse_callback); |
| 210 | + curl_easy_setopt(sse, CURLOPT_WRITEDATA, this); |
| 211 | + |
| 212 | + async_handle = curl_multi_init(); |
| 213 | + if (! async_handle) { |
| 214 | + LOG_ERR("SSE: Failed to initialize async handle"); |
| 215 | + goto cleanup; |
| 216 | + } |
| 217 | + curl_multi_add_handle(async_handle, sse); |
| 218 | + |
| 219 | + do { |
| 220 | + std::this_thread::sleep_for(std::chrono::milliseconds(50)); |
| 221 | + |
| 222 | + mcode = curl_multi_perform(async_handle, &num_handles); |
| 223 | + if (mcode != CURLM_OK) { |
| 224 | + LOG_ERR("SSE: %s", curl_multi_strerror(mcode)); |
| 225 | + break; |
| 226 | + } |
| 227 | + while ((m = curl_multi_info_read(async_handle, &msgs_in_queue)) != nullptr) { |
| 228 | + if (m->msg == CURLMSG_DONE) { |
| 229 | + if (m->data.result != CURLE_OK) { |
| 230 | + errlen = strlen(errbuf); |
| 231 | + if (errlen) { |
| 232 | + LOG_ERR("SSE: %s", errbuf); |
| 233 | + |
| 234 | + } else { |
| 235 | + LOG_ERR("SSE: %s", curl_easy_strerror(m->data.result)); |
| 236 | + } |
| 237 | + running_ = false; |
| 238 | + break; |
| 239 | + } |
| 240 | + } |
| 241 | + } |
| 242 | + if (endpoint_ && lock.owns_lock()) { // TODO: timeout if endpoint not received |
| 243 | + lock.unlock(); |
| 244 | + initializing_.notify_one(); |
| 245 | + } |
| 246 | + |
| 247 | + } while (running_); |
| 248 | + |
| 249 | + cleanup: |
| 250 | + if (headers) { |
| 251 | + curl_slist_free_all(headers); |
| 252 | + } |
| 253 | + if (async_handle) { |
| 254 | + curl_multi_remove_handle(async_handle, sse); |
| 255 | + curl_multi_cleanup(async_handle); |
| 256 | + } |
| 257 | + if (sse) { |
| 258 | + curl_easy_cleanup(sse); |
| 259 | + } |
17 | 260 | } |
0 commit comments