Skip to content
570 changes: 542 additions & 28 deletions lldb/include/lldb/Host/JSONTransport.h

Large diffs are not rendered by default.

5 changes: 1 addition & 4 deletions lldb/include/lldb/Protocol/MCP/MCPError.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#ifndef LLDB_PROTOCOL_MCP_MCPERROR_H
#define LLDB_PROTOCOL_MCP_MCPERROR_H

#include "lldb/Protocol/MCP/Protocol.h"
#include "llvm/Support/Error.h"
#include <string>

Expand All @@ -26,14 +25,12 @@ class MCPError : public llvm::ErrorInfo<MCPError> {

const std::string &getMessage() const { return m_message; }

lldb_protocol::mcp::Error toProtocolError() const;

static constexpr int64_t kResourceNotFound = -32002;
static constexpr int64_t kInternalError = -32603;

private:
std::string m_message;
int64_t m_error_code;
int m_error_code;
};

class UnsupportedURI : public llvm::ErrorInfo<UnsupportedURI> {
Expand Down
12 changes: 12 additions & 0 deletions lldb/include/lldb/Protocol/MCP/Protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H
#define LLDB_PROTOCOL_MCP_PROTOCOL_H

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/JSON.h"
#include <optional>
#include <string>
Expand Down Expand Up @@ -322,6 +323,17 @@ struct CallToolResult {
llvm::json::Value toJSON(const CallToolResult &);
bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path);

lldb_protocol::mcp::Request
MakeRequest(int64_t id, llvm::StringRef method,
std::optional<llvm::json::Value> params);

} // namespace lldb_protocol::mcp

// namespace llvm::json {
// inline Value toJSON(const lldb_protocol::mcp::Void &) { return Object(); }
// inline bool fromJSON(const Value &, lldb_protocol::mcp::Void &, Path) {
// return true;
// }
// } // namespace llvm::json

#endif
72 changes: 31 additions & 41 deletions lldb/include/lldb/Protocol/MCP/Server.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#ifndef LLDB_PROTOCOL_MCP_SERVER_H
#define LLDB_PROTOCOL_MCP_SERVER_H

#include "lldb/Host/JSONTransport.h"
#include "lldb/Host/MainLoop.h"
#include "lldb/Protocol/MCP/Protocol.h"
#include "lldb/Protocol/MCP/Resource.h"
Expand All @@ -19,75 +18,66 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/Signals.h"
#include <functional>
#include <memory>
#include <string>
#include <vector>

namespace lldb_protocol::mcp {

class Server : public MCPTransport::MessageHandler {
using ClosedCallback = llvm::unique_function<void()>;
class Server {

using MCPTransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>;

using ReadHandleUP = lldb_private::MainLoop::ReadHandleUP;

public:
Server(std::string name, std::string version, MCPTransport &client,
LogCallback log_callback = {}, ClosedCallback closed_callback = {});
Server(std::string name, std::string version, LogCallback log_callback = {});
~Server() = default;

using NotificationHandler = std::function<void(const Notification &)>;

void AddTool(std::unique_ptr<Tool> tool);
void AddResourceProvider(std::unique_ptr<ResourceProvider> resource_provider);
void AddNotificationHandler(llvm::StringRef method,
NotificationHandler handler);

protected:
ServerCapabilities GetCapabilities();

using RequestHandler =
std::function<llvm::Expected<Response>(const Request &)>;

void AddRequestHandlers();
llvm::Error Accept(lldb_private::MainLoop &, MCPTransportUP);

void AddRequestHandler(llvm::StringRef method, RequestHandler handler);

llvm::Expected<std::optional<Message>> HandleData(llvm::StringRef data);

llvm::Expected<Response> Handle(const Request &request);
void Handle(const Notification &notification);
protected:
MCPBinderUP Bind(MCPTransport &);

llvm::Expected<Response> InitializeHandler(const Request &);
ServerCapabilities GetCapabilities();

llvm::Expected<Response> ToolsListHandler(const Request &);
llvm::Expected<Response> ToolsCallHandler(const Request &);
llvm::Expected<InitializeResult> InitializeHandler(const InitializeParams &);

llvm::Expected<Response> ResourcesListHandler(const Request &);
llvm::Expected<Response> ResourcesReadHandler(const Request &);
llvm::Expected<ListToolsResult> ToolsListHandler();
llvm::Expected<CallToolResult> ToolsCallHandler(const CallToolParams &);

void Received(const Request &) override;
void Received(const Response &) override;
void Received(const Notification &) override;
void OnError(llvm::Error) override;
void OnClosed() override;
llvm::Expected<ListResourcesResult> ResourcesListHandler();
llvm::Expected<ReadResourceResult>
ResourcesReadHandler(const ReadResourceParams &);

protected:
void Log(llvm::StringRef);
template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) {
Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str());
}
void Log(llvm::StringRef message) {
if (m_log_callback)
m_log_callback(message);
}

private:
const std::string m_name;
const std::string m_version;

MCPTransport &m_client;
LogCallback m_log_callback;
ClosedCallback m_closed_callback;
struct Client {
ReadHandleUP handle;
MCPTransportUP transport;
MCPBinderUP binder;
};
std::map<MCPTransport *, Client> m_instances;

llvm::StringMap<std::unique_ptr<Tool>> m_tools;
std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers;

llvm::StringMap<RequestHandler> m_request_handlers;
llvm::StringMap<NotificationHandler> m_notification_handlers;
};

class ServerInfoHandle;
Expand Down Expand Up @@ -121,7 +111,7 @@ class ServerInfoHandle {
ServerInfoHandle &operator=(const ServerInfoHandle &) = delete;
/// @}

/// Remove the file.
/// Remove the file on disk, if one is tracked.
void Remove();

private:
Expand Down
60 changes: 58 additions & 2 deletions lldb/include/lldb/Protocol/MCP/Transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,78 @@
#define LLDB_PROTOCOL_MCP_TRANSPORT_H

#include "lldb/Host/JSONTransport.h"
#include "lldb/Protocol/MCP/MCPError.h"
#include "lldb/Protocol/MCP/Protocol.h"
#include "lldb/lldb-forward.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
#include <sys/types.h>

namespace lldb_protocol::mcp {

struct ProtocolDescriptor {
using Id = int64_t;
using Req = Request;
using Resp = Response;
using Evt = Notification;

static inline Id InitialId() { return 0; }
static inline Request Make(Id id, llvm::StringRef method,
std::optional<llvm::json::Value> params) {
return Request{id, method.str(), params};
}
static inline Notification Make(llvm::StringRef method,
std::optional<llvm::json::Value> params) {
return Notification{method.str(), params};
}
static inline Response Make(Req req, llvm::Error error) {
lldb_protocol::mcp::Error protocol_error;
llvm::handleAllErrors(
std::move(error), [&](const llvm::ErrorInfoBase &err) {
std::error_code cerr = err.convertToErrorCode();
protocol_error.code =
cerr == llvm::inconvertibleErrorCode()
? lldb_protocol::mcp::eErrorCodeInternalError
: cerr.value();
protocol_error.message = err.message();
});

return Response{req.id, std::move(protocol_error)};
}
static inline Response Make(Req req,
std::optional<llvm::json::Value> result) {
return Response{req.id, std::move(result)};
}
static inline Id KeyFor(Response r) { return std::get<Id>(r.id); }
static inline llvm::StringRef KeyFor(Request r) { return r.method; }
static inline llvm::StringRef KeyFor(Notification n) { return n.method; }
static inline std::optional<llvm::json::Value> Extract(Request r) {
return r.params;
}
static inline llvm::Expected<llvm::json::Value> Extract(Response r) {
if (const lldb_protocol::mcp::Error *error =
std::get_if<lldb_protocol::mcp::Error>(&r.result))
return llvm::make_error<lldb_protocol::mcp::MCPError>(error->message,
error->code);
return std::get<llvm::json::Value>(r.result);
}
static inline std::optional<llvm::json::Value> Extract(Notification n) {
return n.params;
}
};

/// Generic transport that uses the MCP protocol.
using MCPTransport = lldb_private::Transport<Request, Response, Notification>;
using MCPTransport = lldb_private::transport::JSONTransport<ProtocolDescriptor>;
using MCPBinder = lldb_private::transport::Binder<ProtocolDescriptor>;
using MCPBinderUP = std::unique_ptr<MCPBinder>;

/// Generic logging callback, to allow the MCP server / client / transport layer
/// to be independent of the lldb log implementation.
using LogCallback = llvm::unique_function<void(llvm::StringRef message)>;

class Transport final
: public lldb_private::JSONRPCTransport<Request, Response, Notification> {
: public lldb_private::transport::JSONRPCTransport<ProtocolDescriptor> {
public:
Transport(lldb::IOObjectSP in, lldb::IOObjectSP out,
LogCallback log_callback = {});
Expand Down
26 changes: 23 additions & 3 deletions lldb/source/Host/common/JSONTransport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,39 @@
#include <string>

using namespace llvm;
using namespace lldb;
using namespace lldb_private;
using namespace lldb_private::transport;

char TransportUnhandledContentsError::ID;

TransportUnhandledContentsError::TransportUnhandledContentsError(
std::string unhandled_contents)
: m_unhandled_contents(unhandled_contents) {}

void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const {
void TransportUnhandledContentsError::log(raw_ostream &OS) const {
OS << "transport EOF with unhandled contents: '" << m_unhandled_contents
<< "'";
}
std::error_code TransportUnhandledContentsError::convertToErrorCode() const {
return std::make_error_code(std::errc::bad_message);
}

char InvalidParams::ID;

void InvalidParams::log(raw_ostream &OS) const {
OS << "invalid parameters for method '" << m_method << "': '" << m_context
<< "'";
}
std::error_code InvalidParams::convertToErrorCode() const {
return std::make_error_code(std::errc::invalid_argument);
}

char MethodNotFound::ID;

void MethodNotFound::log(raw_ostream &OS) const {
OS << "method not found: '" << m_method << "'";
}

std::error_code MethodNotFound::convertToErrorCode() const {
// JSON-RPC Method not found
return std::error_code(MethodNotFound::kErrorCode, std::generic_category());
}
52 changes: 19 additions & 33 deletions lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
}

void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const {
server.AddNotificationHandler("notifications/initialized",
[](const lldb_protocol::mcp::Notification &) {
LLDB_LOG(GetLog(LLDBLog::Host),
"MCP initialization complete");
});
server.AddTool(
std::make_unique<CommandTool>("command", "Run an lldb command."));
server.AddTool(std::make_unique<DebuggerListTool>(
Expand All @@ -74,26 +69,9 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
io_sp, io_sp, [client_name](llvm::StringRef message) {
LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message);
});
MCPTransport *transport_ptr = transport_up.get();
auto instance_up = std::make_unique<lldb_protocol::mcp::Server>(
std::string(kName), std::string(kVersion), *transport_up,
/*log_callback=*/
[client_name](llvm::StringRef message) {
LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name,
message);
},
/*closed_callback=*/
[this, transport_ptr]() { m_instances.erase(transport_ptr); });
Extend(*instance_up);
llvm::Expected<MainLoop::ReadHandleUP> handle =
transport_up->RegisterMessageHandler(m_loop, *instance_up);
if (!handle) {
LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}");
return;
}
m_instances[transport_ptr] =
std::make_tuple<ServerUP, ReadHandleUP, TransportUP>(
std::move(instance_up), std::move(*handle), std::move(transport_up));

if (auto error = m_server->Accept(m_loop, std::move(transport_up)))
LLDB_LOG_ERROR(log, std::move(error), "{0}:");
}

llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
Expand Down Expand Up @@ -124,14 +102,21 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
llvm::join(m_listener->GetListeningConnectionURI(), ", ");

ServerInfo info{listening_uris[0]};
llvm::Expected<ServerInfoHandle> handle = ServerInfo::Write(info);
if (!handle)
return handle.takeError();
llvm::Expected<ServerInfoHandle> server_info_handle = ServerInfo::Write(info);
if (!server_info_handle)
return server_info_handle.takeError();

m_client_count = 0;
m_server = std::make_unique<lldb_protocol::mcp::Server>(
std::string(kName), std::string(kVersion), [](StringRef message) {
LLDB_LOG(GetLog(LLDBLog::Host), "MCP Server: {0}", message);
});
Extend(*m_server);

m_running = true;
m_server_info_handle = std::move(*handle);
m_listen_handlers = std::move(*handles);
m_loop_thread = std::thread([=] {
m_server_info_handle = std::move(*server_info_handle);
m_accept_handles = std::move(*handles);
m_loop_thread = std::thread([this] {
llvm::set_thread_name("protocol-server.mcp");
m_loop.Run();
});
Expand All @@ -155,9 +140,10 @@ llvm::Error ProtocolServerMCP::Stop() {
if (m_loop_thread.joinable())
m_loop_thread.join();

m_accept_handles.clear();

m_server.reset(nullptr);
m_server_info_handle.Remove();
m_listen_handlers.clear();
m_instances.clear();

return llvm::Error::success();
}
Loading
Loading