Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions lldb/include/lldb/Host/JSONTransport.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,21 @@ template <typename Req, typename Resp, typename Evt> class Transport {
virtual llvm::Expected<MainLoop::ReadHandleUP>
RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0;

// FIXME: Refactor mcp::Server to not directly access log on the transport.
// protected:
protected:
template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) {
Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str());
}
virtual void Log(llvm::StringRef message) = 0;
};

/// A JSONTransport will encode and decode messages using JSON.
/// An IOTransport sends and receives messages using an IOObject.
template <typename Req, typename Resp, typename Evt>
class JSONTransport : public Transport<Req, Resp, Evt> {
class IOTransport : public Transport<Req, Resp, Evt> {
public:
using Transport<Req, Resp, Evt>::Transport;
using MessageHandler = typename Transport<Req, Resp, Evt>::MessageHandler;

JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out)
IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out)
: m_in(in), m_out(out) {}

llvm::Error Send(const Evt &evt) override { return Write(evt); }
Expand All @@ -127,7 +126,7 @@ class JSONTransport : public Transport<Req, Resp, Evt> {
Status status;
MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject(
m_in,
std::bind(&JSONTransport::OnRead, this, std::placeholders::_1,
std::bind(&IOTransport::OnRead, this, std::placeholders::_1,
std::ref(handler)),
status);
if (status.Fail()) {
Expand Down Expand Up @@ -203,9 +202,9 @@ class JSONTransport : public Transport<Req, Resp, Evt> {

/// A transport class for JSON with a HTTP header.
template <typename Req, typename Resp, typename Evt>
class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> {
class HTTPDelimitedJSONTransport : public IOTransport<Req, Resp, Evt> {
public:
using JSONTransport<Req, Resp, Evt>::JSONTransport;
using IOTransport<Req, Resp, Evt>::IOTransport;

protected:
/// Encodes messages based on
Expand Down Expand Up @@ -270,9 +269,9 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> {

/// A transport class for JSON RPC.
template <typename Req, typename Resp, typename Evt>
class JSONRPCTransport : public JSONTransport<Req, Resp, Evt> {
class JSONRPCTransport : public IOTransport<Req, Resp, Evt> {
public:
using JSONTransport<Req, Resp, Evt>::JSONTransport;
using IOTransport<Req, Resp, Evt>::IOTransport;

protected:
std::string Encode(const llvm::json::Value &message) override {
Expand Down
17 changes: 9 additions & 8 deletions lldb/include/lldb/Protocol/MCP/Server.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@
namespace lldb_protocol::mcp {

class Server : public MCPTransport::MessageHandler {
using ClosedCallback = llvm::unique_function<void()>;

public:
Server(std::string name, std::string version,
std::unique_ptr<MCPTransport> transport_up,
lldb_private::MainLoop &loop);
Server(std::string name, std::string version, MCPTransport &client,
LogCallback log_callback = {}, ClosedCallback closed_callback = {});
~Server() = default;

using NotificationHandler = std::function<void(const Notification &)>;
Expand All @@ -42,8 +43,6 @@ class Server : public MCPTransport::MessageHandler {
void AddNotificationHandler(llvm::StringRef method,
NotificationHandler handler);

llvm::Error Run();

protected:
ServerCapabilities GetCapabilities();

Expand Down Expand Up @@ -73,14 +72,16 @@ class Server : public MCPTransport::MessageHandler {
void OnError(llvm::Error) override;
void OnClosed() override;

void TerminateLoop();
protected:
void Log(llvm::StringRef);

private:
const std::string m_name;
const std::string m_version;

std::unique_ptr<MCPTransport> m_transport_up;
lldb_private::MainLoop &m_loop;
MCPTransport &m_client;
LogCallback m_log_callback;
ClosedCallback m_closed_callback;

llvm::StringMap<std::unique_ptr<Tool>> m_tools;
std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers;
Expand Down
24 changes: 17 additions & 7 deletions lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,34 @@ void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const {

void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
Log *log = GetLog(LLDBLog::Host);
std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1);
std::string client_name = llvm::formatv("client_{0}", ++m_client_count);
LLDB_LOG(log, "New MCP client connected: {0}", client_name);

lldb::IOObjectSP io_sp = std::move(socket);
auto transport_up = std::make_unique<lldb_protocol::mcp::Transport>(
io_sp, io_sp, [client_name](llvm::StringRef message) {
LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message);
});
MCPTransport *transport_ptr = transport_up.get();
auto instance_up = std::make_unique<lldb_protocol::mcp::Server>(
std::string(kName), std::string(kVersion), std::move(transport_up),
m_loop);
std::string(kName), std::string(kVersion), *transport_up,
/*log_callback=*/
[client_name](llvm::StringRef message) {
LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name,
message);
},
/*closed_callback=*/
[this, transport_ptr]() { m_instances.erase(transport_ptr); });
Extend(*instance_up);
llvm::Error error = instance_up->Run();
if (error) {
LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}");
llvm::Expected<MainLoop::ReadHandleUP> handle =
transport_up->RegisterMessageHandler(m_loop, *instance_up);
if (!handle) {
LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}");
return;
}
m_instances.push_back(std::move(instance_up));
m_instances[transport_ptr] =
std::make_tuple<ServerUP, ReadHandleUP, TransportUP>(
std::move(instance_up), std::move(*handle), std::move(transport_up));
}

llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
Expand Down
17 changes: 14 additions & 3 deletions lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,21 @@
#include "lldb/Core/ProtocolServer.h"
#include "lldb/Host/MainLoop.h"
#include "lldb/Host/Socket.h"
#include "lldb/Protocol/MCP/Protocol.h"
#include "lldb/Protocol/MCP/Server.h"
#include "lldb/Protocol/MCP/Transport.h"
#include <map>
#include <memory>
#include <thread>
#include <tuple>
#include <vector>

namespace lldb_private::mcp {

class ProtocolServerMCP : public ProtocolServer {
using ReadHandleUP = MainLoopBase::ReadHandleUP;
using TransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>;
using ServerUP = std::unique_ptr<lldb_protocol::mcp::Server>;

public:
ProtocolServerMCP();
virtual ~ProtocolServerMCP() override;
Expand Down Expand Up @@ -52,11 +60,14 @@ class ProtocolServerMCP : public ProtocolServer {
lldb_private::MainLoop m_loop;
std::thread m_loop_thread;
std::mutex m_mutex;
uint32_t m_client_count = 0;

std::unique_ptr<Socket> m_listener;

std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers;
std::vector<std::unique_ptr<lldb_protocol::mcp::Server>> m_instances;
std::vector<ReadHandleUP> m_listen_handlers;
std::map<lldb_protocol::mcp::MCPTransport *,
std::tuple<ServerUP, ReadHandleUP, TransportUP>>
m_instances;
};
} // namespace lldb_private::mcp

Expand Down
43 changes: 15 additions & 28 deletions lldb/source/Protocol/MCP/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "lldb/Host/File.h"
#include "lldb/Host/FileSystem.h"
#include "lldb/Host/HostInfo.h"
#include "lldb/Host/JSONTransport.h"
#include "lldb/Protocol/MCP/MCPError.h"
#include "lldb/Protocol/MCP/Protocol.h"
#include "llvm/ADT/SmallString.h"
Expand Down Expand Up @@ -109,11 +108,11 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() {
return infos;
}

Server::Server(std::string name, std::string version,
std::unique_ptr<MCPTransport> transport_up,
lldb_private::MainLoop &loop)
: m_name(std::move(name)), m_version(std::move(version)),
m_transport_up(std::move(transport_up)), m_loop(loop) {
Server::Server(std::string name, std::string version, MCPTransport &client,
LogCallback log_callback, ClosedCallback closed_callback)
: m_name(std::move(name)), m_version(std::move(version)), m_client(client),
m_log_callback(std::move(log_callback)),
m_closed_callback(std::move(closed_callback)) {
AddRequestHandlers();
}

Expand Down Expand Up @@ -287,22 +286,15 @@ ServerCapabilities Server::GetCapabilities() {
return capabilities;
}

llvm::Error Server::Run() {
auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this);
if (!handle)
return handle.takeError();

lldb_private::Status status = m_loop.Run();
if (status.Fail())
return status.takeError();

return llvm::Error::success();
void Server::Log(llvm::StringRef message) {
if (m_log_callback)
m_log_callback(message);
}

void Server::Received(const Request &request) {
auto SendResponse = [this](const Response &response) {
if (llvm::Error error = m_transport_up->Send(response))
m_transport_up->Log(llvm::toString(std::move(error)));
if (llvm::Error error = m_client.Send(response))
Log(llvm::toString(std::move(error)));
};

llvm::Expected<Response> response = Handle(request);
Expand All @@ -324,24 +316,19 @@ void Server::Received(const Request &request) {
}

void Server::Received(const Response &response) {
m_transport_up->Log("unexpected MCP message: response");
Log("unexpected MCP message: response");
}

void Server::Received(const Notification &notification) {
Handle(notification);
}

void Server::OnError(llvm::Error error) {
m_transport_up->Log(llvm::toString(std::move(error)));
TerminateLoop();
Log(llvm::toString(std::move(error)));
}

void Server::OnClosed() {
m_transport_up->Log("EOF");
TerminateLoop();
}

void Server::TerminateLoop() {
m_loop.AddPendingCallback(
[](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
Log("EOF");
if (m_closed_callback)
m_closed_callback();
}
2 changes: 1 addition & 1 deletion lldb/unittests/Host/JSONTransportTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) {
// Use a string longer than the chunk size to ensure we split the message
// across the chunk boundary.
std::string long_str =
std::string(JSONTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x');
std::string(IOTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x');
Write(Req{long_str});
EXPECT_CALL(message_handler, Received(Req{long_str}));
ASSERT_THAT_ERROR(Run(), Succeeded());
Expand Down
70 changes: 45 additions & 25 deletions lldb/unittests/Protocol/ProtocolMCPServerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,53 +122,73 @@ class ProtocolServerMCPTest : public PipePairTest {
public:
SubsystemRAII<FileSystem, HostInfo, Socket> subsystems;

std::unique_ptr<lldb_protocol::mcp::Transport> transport_up;
std::unique_ptr<TestServer> server_up;
MainLoop loop;

std::unique_ptr<lldb_protocol::mcp::Transport> from_client;
std::unique_ptr<lldb_protocol::mcp::Transport> to_client;
MainLoopBase::ReadHandleUP handles[2];

std::unique_ptr<TestServer> server_up;
MockMessageHandler<Request, Response, Notification> message_handler;

llvm::Error Write(llvm::StringRef message) {
llvm::Expected<json::Value> value = json::parse(message);
if (!value)
return value.takeError();
return transport_up->Write(*value);
return from_client->Write(*value);
}

llvm::Error Write(json::Value value) { return transport_up->Write(value); }
llvm::Error Write(json::Value value) { return from_client->Write(value); }

/// Run the transport MainLoop and return any messages received.
llvm::Error
Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) {
llvm::Error Run() {
loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); },
timeout);
auto handle = transport_up->RegisterMessageHandler(loop, message_handler);
if (!handle)
return handle.takeError();

return server_up->Run();
std::chrono::milliseconds(10));
return loop.Run().takeError();
}

void SetUp() override {
PipePairTest::SetUp();

transport_up = std::make_unique<lldb_protocol::mcp::Transport>(
from_client = std::make_unique<lldb_protocol::mcp::Transport>(
std::make_shared<NativeFile>(input.GetReadFileDescriptor(),
File::eOpenOptionReadOnly,
NativeFile::Unowned),
std::make_shared<NativeFile>(output.GetWriteFileDescriptor(),
File::eOpenOptionWriteOnly,
NativeFile::Unowned));

server_up = std::make_unique<TestServer>(
"lldb-mcp", "0.1.0",
std::make_unique<lldb_protocol::mcp::Transport>(
std::make_shared<NativeFile>(output.GetReadFileDescriptor(),
File::eOpenOptionReadOnly,
NativeFile::Unowned),
std::make_shared<NativeFile>(input.GetWriteFileDescriptor(),
File::eOpenOptionWriteOnly,
NativeFile::Unowned)),
loop);
NativeFile::Unowned),
[](StringRef message) {
// Uncomment for debugging
// llvm::errs() << "from_client: " << message << '\n';
});
to_client = std::make_unique<lldb_protocol::mcp::Transport>(
std::make_shared<NativeFile>(output.GetReadFileDescriptor(),
File::eOpenOptionReadOnly,
NativeFile::Unowned),
std::make_shared<NativeFile>(input.GetWriteFileDescriptor(),
File::eOpenOptionWriteOnly,
NativeFile::Unowned),
[](StringRef message) {
// Uncomment for debugging
// llvm::errs() << "to_client: " << message << '\n';
});

server_up = std::make_unique<TestServer>("lldb-mcp", "0.1.0", *to_client,
[](StringRef message) {
// Uncomment for debugging
// llvm::errs() << "server: " <<
// message << '\n';
});

auto maybe_from_client_handle =
from_client->RegisterMessageHandler(loop, message_handler);
EXPECT_THAT_EXPECTED(maybe_from_client_handle, Succeeded());
handles[0] = std::move(*maybe_from_client_handle);

auto maybe_to_client_handle =
to_client->RegisterMessageHandler(loop, *server_up);
EXPECT_THAT_EXPECTED(maybe_to_client_handle, Succeeded());
handles[1] = std::move(*maybe_to_client_handle);
}
};

Expand Down
Loading