Skip to content

Commit 5cc4193

Browse files
authored
[lldb-mcp] Fix servers accepting more than one client. (#158357)
This fixes an issue where the MCP server would stop the main loop after the first client disconnects. This moves the MainLoop out of the Server instance and lifts the server up into the ProtocolServerMCP object instead. This allows us to register the client with the main loop used to accept and process requests.
1 parent d46715a commit 5cc4193

File tree

7 files changed

+110
-82
lines changed

7 files changed

+110
-82
lines changed

lldb/include/lldb/Host/JSONTransport.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,21 @@ template <typename Req, typename Resp, typename Evt> class Transport {
100100
virtual llvm::Expected<MainLoop::ReadHandleUP>
101101
RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0;
102102

103-
// FIXME: Refactor mcp::Server to not directly access log on the transport.
104-
// protected:
103+
protected:
105104
template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) {
106105
Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str());
107106
}
108107
virtual void Log(llvm::StringRef message) = 0;
109108
};
110109

111-
/// A JSONTransport will encode and decode messages using JSON.
110+
/// An IOTransport sends and receives messages using an IOObject.
112111
template <typename Req, typename Resp, typename Evt>
113-
class JSONTransport : public Transport<Req, Resp, Evt> {
112+
class IOTransport : public Transport<Req, Resp, Evt> {
114113
public:
115114
using Transport<Req, Resp, Evt>::Transport;
116115
using MessageHandler = typename Transport<Req, Resp, Evt>::MessageHandler;
117116

118-
JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out)
117+
IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out)
119118
: m_in(in), m_out(out) {}
120119

121120
llvm::Error Send(const Evt &evt) override { return Write(evt); }
@@ -127,7 +126,7 @@ class JSONTransport : public Transport<Req, Resp, Evt> {
127126
Status status;
128127
MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject(
129128
m_in,
130-
std::bind(&JSONTransport::OnRead, this, std::placeholders::_1,
129+
std::bind(&IOTransport::OnRead, this, std::placeholders::_1,
131130
std::ref(handler)),
132131
status);
133132
if (status.Fail()) {
@@ -203,9 +202,9 @@ class JSONTransport : public Transport<Req, Resp, Evt> {
203202

204203
/// A transport class for JSON with a HTTP header.
205204
template <typename Req, typename Resp, typename Evt>
206-
class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> {
205+
class HTTPDelimitedJSONTransport : public IOTransport<Req, Resp, Evt> {
207206
public:
208-
using JSONTransport<Req, Resp, Evt>::JSONTransport;
207+
using IOTransport<Req, Resp, Evt>::IOTransport;
209208

210209
protected:
211210
/// Encodes messages based on
@@ -270,9 +269,9 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> {
270269

271270
/// A transport class for JSON RPC.
272271
template <typename Req, typename Resp, typename Evt>
273-
class JSONRPCTransport : public JSONTransport<Req, Resp, Evt> {
272+
class JSONRPCTransport : public IOTransport<Req, Resp, Evt> {
274273
public:
275-
using JSONTransport<Req, Resp, Evt>::JSONTransport;
274+
using IOTransport<Req, Resp, Evt>::IOTransport;
276275

277276
protected:
278277
std::string Encode(const llvm::json::Value &message) override {

lldb/include/lldb/Protocol/MCP/Server.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@
2929
namespace lldb_protocol::mcp {
3030

3131
class Server : public MCPTransport::MessageHandler {
32+
using ClosedCallback = llvm::unique_function<void()>;
33+
3234
public:
33-
Server(std::string name, std::string version,
34-
std::unique_ptr<MCPTransport> transport_up,
35-
lldb_private::MainLoop &loop);
35+
Server(std::string name, std::string version, MCPTransport &client,
36+
LogCallback log_callback = {}, ClosedCallback closed_callback = {});
3637
~Server() = default;
3738

3839
using NotificationHandler = std::function<void(const Notification &)>;
@@ -42,8 +43,6 @@ class Server : public MCPTransport::MessageHandler {
4243
void AddNotificationHandler(llvm::StringRef method,
4344
NotificationHandler handler);
4445

45-
llvm::Error Run();
46-
4746
protected:
4847
ServerCapabilities GetCapabilities();
4948

@@ -73,14 +72,16 @@ class Server : public MCPTransport::MessageHandler {
7372
void OnError(llvm::Error) override;
7473
void OnClosed() override;
7574

76-
void TerminateLoop();
75+
protected:
76+
void Log(llvm::StringRef);
7777

7878
private:
7979
const std::string m_name;
8080
const std::string m_version;
8181

82-
std::unique_ptr<MCPTransport> m_transport_up;
83-
lldb_private::MainLoop &m_loop;
82+
MCPTransport &m_client;
83+
LogCallback m_log_callback;
84+
ClosedCallback m_closed_callback;
8485

8586
llvm::StringMap<std::unique_ptr<Tool>> m_tools;
8687
std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers;

lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,24 +66,34 @@ void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const {
6666

6767
void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
6868
Log *log = GetLog(LLDBLog::Host);
69-
std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1);
69+
std::string client_name = llvm::formatv("client_{0}", ++m_client_count);
7070
LLDB_LOG(log, "New MCP client connected: {0}", client_name);
7171

7272
lldb::IOObjectSP io_sp = std::move(socket);
7373
auto transport_up = std::make_unique<lldb_protocol::mcp::Transport>(
7474
io_sp, io_sp, [client_name](llvm::StringRef message) {
7575
LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message);
7676
});
77+
MCPTransport *transport_ptr = transport_up.get();
7778
auto instance_up = std::make_unique<lldb_protocol::mcp::Server>(
78-
std::string(kName), std::string(kVersion), std::move(transport_up),
79-
m_loop);
79+
std::string(kName), std::string(kVersion), *transport_up,
80+
/*log_callback=*/
81+
[client_name](llvm::StringRef message) {
82+
LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name,
83+
message);
84+
},
85+
/*closed_callback=*/
86+
[this, transport_ptr]() { m_instances.erase(transport_ptr); });
8087
Extend(*instance_up);
81-
llvm::Error error = instance_up->Run();
82-
if (error) {
83-
LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}");
88+
llvm::Expected<MainLoop::ReadHandleUP> handle =
89+
transport_up->RegisterMessageHandler(m_loop, *instance_up);
90+
if (!handle) {
91+
LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}");
8492
return;
8593
}
86-
m_instances.push_back(std::move(instance_up));
94+
m_instances[transport_ptr] =
95+
std::make_tuple<ServerUP, ReadHandleUP, TransportUP>(
96+
std::move(instance_up), std::move(*handle), std::move(transport_up));
8797
}
8898

8999
llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {

lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,21 @@
1212
#include "lldb/Core/ProtocolServer.h"
1313
#include "lldb/Host/MainLoop.h"
1414
#include "lldb/Host/Socket.h"
15-
#include "lldb/Protocol/MCP/Protocol.h"
1615
#include "lldb/Protocol/MCP/Server.h"
16+
#include "lldb/Protocol/MCP/Transport.h"
17+
#include <map>
18+
#include <memory>
1719
#include <thread>
20+
#include <tuple>
21+
#include <vector>
1822

1923
namespace lldb_private::mcp {
2024

2125
class ProtocolServerMCP : public ProtocolServer {
26+
using ReadHandleUP = MainLoopBase::ReadHandleUP;
27+
using TransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>;
28+
using ServerUP = std::unique_ptr<lldb_protocol::mcp::Server>;
29+
2230
public:
2331
ProtocolServerMCP();
2432
virtual ~ProtocolServerMCP() override;
@@ -52,11 +60,14 @@ class ProtocolServerMCP : public ProtocolServer {
5260
lldb_private::MainLoop m_loop;
5361
std::thread m_loop_thread;
5462
std::mutex m_mutex;
63+
size_t m_client_count = 0;
5564

5665
std::unique_ptr<Socket> m_listener;
5766

58-
std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers;
59-
std::vector<std::unique_ptr<lldb_protocol::mcp::Server>> m_instances;
67+
std::vector<ReadHandleUP> m_listen_handlers;
68+
std::map<lldb_protocol::mcp::MCPTransport *,
69+
std::tuple<ServerUP, ReadHandleUP, TransportUP>>
70+
m_instances;
6071
};
6172
} // namespace lldb_private::mcp
6273

lldb/source/Protocol/MCP/Server.cpp

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "lldb/Host/File.h"
1111
#include "lldb/Host/FileSystem.h"
1212
#include "lldb/Host/HostInfo.h"
13-
#include "lldb/Host/JSONTransport.h"
1413
#include "lldb/Protocol/MCP/MCPError.h"
1514
#include "lldb/Protocol/MCP/Protocol.h"
1615
#include "llvm/ADT/SmallString.h"
@@ -109,11 +108,11 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() {
109108
return infos;
110109
}
111110

112-
Server::Server(std::string name, std::string version,
113-
std::unique_ptr<MCPTransport> transport_up,
114-
lldb_private::MainLoop &loop)
115-
: m_name(std::move(name)), m_version(std::move(version)),
116-
m_transport_up(std::move(transport_up)), m_loop(loop) {
111+
Server::Server(std::string name, std::string version, MCPTransport &client,
112+
LogCallback log_callback, ClosedCallback closed_callback)
113+
: m_name(std::move(name)), m_version(std::move(version)), m_client(client),
114+
m_log_callback(std::move(log_callback)),
115+
m_closed_callback(std::move(closed_callback)) {
117116
AddRequestHandlers();
118117
}
119118

@@ -287,22 +286,15 @@ ServerCapabilities Server::GetCapabilities() {
287286
return capabilities;
288287
}
289288

290-
llvm::Error Server::Run() {
291-
auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this);
292-
if (!handle)
293-
return handle.takeError();
294-
295-
lldb_private::Status status = m_loop.Run();
296-
if (status.Fail())
297-
return status.takeError();
298-
299-
return llvm::Error::success();
289+
void Server::Log(llvm::StringRef message) {
290+
if (m_log_callback)
291+
m_log_callback(message);
300292
}
301293

302294
void Server::Received(const Request &request) {
303295
auto SendResponse = [this](const Response &response) {
304-
if (llvm::Error error = m_transport_up->Send(response))
305-
m_transport_up->Log(llvm::toString(std::move(error)));
296+
if (llvm::Error error = m_client.Send(response))
297+
Log(llvm::toString(std::move(error)));
306298
};
307299

308300
llvm::Expected<Response> response = Handle(request);
@@ -324,24 +316,19 @@ void Server::Received(const Request &request) {
324316
}
325317

326318
void Server::Received(const Response &response) {
327-
m_transport_up->Log("unexpected MCP message: response");
319+
Log("unexpected MCP message: response");
328320
}
329321

330322
void Server::Received(const Notification &notification) {
331323
Handle(notification);
332324
}
333325

334326
void Server::OnError(llvm::Error error) {
335-
m_transport_up->Log(llvm::toString(std::move(error)));
336-
TerminateLoop();
327+
Log(llvm::toString(std::move(error)));
337328
}
338329

339330
void Server::OnClosed() {
340-
m_transport_up->Log("EOF");
341-
TerminateLoop();
342-
}
343-
344-
void Server::TerminateLoop() {
345-
m_loop.AddPendingCallback(
346-
[](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
331+
Log("EOF");
332+
if (m_closed_callback)
333+
m_closed_callback();
347334
}

lldb/unittests/Host/JSONTransportTest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) {
413413
// Use a string longer than the chunk size to ensure we split the message
414414
// across the chunk boundary.
415415
std::string long_str =
416-
std::string(JSONTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x');
416+
std::string(IOTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x');
417417
Write(Req{long_str});
418418
EXPECT_CALL(message_handler, Received(Req{long_str}));
419419
ASSERT_THAT_ERROR(Run(), Succeeded());

lldb/unittests/Protocol/ProtocolMCPServerTest.cpp

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -122,53 +122,73 @@ class ProtocolServerMCPTest : public PipePairTest {
122122
public:
123123
SubsystemRAII<FileSystem, HostInfo, Socket> subsystems;
124124

125-
std::unique_ptr<lldb_protocol::mcp::Transport> transport_up;
126-
std::unique_ptr<TestServer> server_up;
127125
MainLoop loop;
126+
127+
std::unique_ptr<lldb_protocol::mcp::Transport> from_client;
128+
std::unique_ptr<lldb_protocol::mcp::Transport> to_client;
129+
MainLoopBase::ReadHandleUP handles[2];
130+
131+
std::unique_ptr<TestServer> server_up;
128132
MockMessageHandler<Request, Response, Notification> message_handler;
129133

130134
llvm::Error Write(llvm::StringRef message) {
131135
llvm::Expected<json::Value> value = json::parse(message);
132136
if (!value)
133137
return value.takeError();
134-
return transport_up->Write(*value);
138+
return from_client->Write(*value);
135139
}
136140

137-
llvm::Error Write(json::Value value) { return transport_up->Write(value); }
141+
llvm::Error Write(json::Value value) { return from_client->Write(value); }
138142

139143
/// Run the transport MainLoop and return any messages received.
140-
llvm::Error
141-
Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) {
144+
llvm::Error Run() {
142145
loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); },
143-
timeout);
144-
auto handle = transport_up->RegisterMessageHandler(loop, message_handler);
145-
if (!handle)
146-
return handle.takeError();
147-
148-
return server_up->Run();
146+
std::chrono::milliseconds(10));
147+
return loop.Run().takeError();
149148
}
150149

151150
void SetUp() override {
152151
PipePairTest::SetUp();
153152

154-
transport_up = std::make_unique<lldb_protocol::mcp::Transport>(
153+
from_client = std::make_unique<lldb_protocol::mcp::Transport>(
155154
std::make_shared<NativeFile>(input.GetReadFileDescriptor(),
156155
File::eOpenOptionReadOnly,
157156
NativeFile::Unowned),
158157
std::make_shared<NativeFile>(output.GetWriteFileDescriptor(),
159158
File::eOpenOptionWriteOnly,
160-
NativeFile::Unowned));
161-
162-
server_up = std::make_unique<TestServer>(
163-
"lldb-mcp", "0.1.0",
164-
std::make_unique<lldb_protocol::mcp::Transport>(
165-
std::make_shared<NativeFile>(output.GetReadFileDescriptor(),
166-
File::eOpenOptionReadOnly,
167-
NativeFile::Unowned),
168-
std::make_shared<NativeFile>(input.GetWriteFileDescriptor(),
169-
File::eOpenOptionWriteOnly,
170-
NativeFile::Unowned)),
171-
loop);
159+
NativeFile::Unowned),
160+
[](StringRef message) {
161+
// Uncomment for debugging
162+
// llvm::errs() << "from_client: " << message << '\n';
163+
});
164+
to_client = std::make_unique<lldb_protocol::mcp::Transport>(
165+
std::make_shared<NativeFile>(output.GetReadFileDescriptor(),
166+
File::eOpenOptionReadOnly,
167+
NativeFile::Unowned),
168+
std::make_shared<NativeFile>(input.GetWriteFileDescriptor(),
169+
File::eOpenOptionWriteOnly,
170+
NativeFile::Unowned),
171+
[](StringRef message) {
172+
// Uncomment for debugging
173+
// llvm::errs() << "to_client: " << message << '\n';
174+
});
175+
176+
server_up = std::make_unique<TestServer>("lldb-mcp", "0.1.0", *to_client,
177+
[](StringRef message) {
178+
// Uncomment for debugging
179+
// llvm::errs() << "server: " <<
180+
// message << '\n';
181+
});
182+
183+
auto maybe_from_client_handle =
184+
from_client->RegisterMessageHandler(loop, message_handler);
185+
EXPECT_THAT_EXPECTED(maybe_from_client_handle, Succeeded());
186+
handles[0] = std::move(*maybe_from_client_handle);
187+
188+
auto maybe_to_client_handle =
189+
to_client->RegisterMessageHandler(loop, *server_up);
190+
EXPECT_THAT_EXPECTED(maybe_to_client_handle, Succeeded());
191+
handles[1] = std::move(*maybe_to_client_handle);
172192
}
173193
};
174194

0 commit comments

Comments
 (0)