From 48eb7de9b9fd62d9f9d655593a8c1944dc94ec68 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Wed, 10 Sep 2025 10:42:56 -0700 Subject: [PATCH 1/8] [lldb] Adding A new Binding helper for JSONTransport. This adds a new Binding helper class to allow mapping of incoming and outgoing requests / events to specific handlers. This should make it easier to create new protocol implementations and allow us to create a relay in the lldb-mcp binary. --- lldb/include/lldb/Host/JSONTransport.h | 377 +++++++++++++++++- lldb/include/lldb/Protocol/MCP/Protocol.h | 8 + lldb/include/lldb/Protocol/MCP/Server.h | 73 ++-- lldb/include/lldb/Protocol/MCP/Transport.h | 77 +++- lldb/source/Host/common/JSONTransport.cpp | 10 + .../Protocol/MCP/ProtocolServerMCP.cpp | 50 +-- .../Plugins/Protocol/MCP/ProtocolServerMCP.h | 20 +- lldb/source/Protocol/MCP/Server.cpp | 207 +++------- lldb/tools/lldb-dap/DAP.h | 6 +- lldb/tools/lldb-dap/Protocol/ProtocolBase.h | 6 +- lldb/tools/lldb-dap/Transport.h | 6 +- lldb/unittests/DAP/DAPTest.cpp | 20 +- lldb/unittests/DAP/Handler/DisconnectTest.cpp | 4 +- lldb/unittests/DAP/TestBase.cpp | 42 +- lldb/unittests/DAP/TestBase.h | 122 +++--- lldb/unittests/Host/JSONTransportTest.cpp | 338 ++++++++++++---- .../Protocol/ProtocolMCPServerTest.cpp | 304 +++++++------- .../Host/JSONTransportTestUtilities.h | 96 ++++- 18 files changed, 1167 insertions(+), 599 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index c73021d204258..ede8892e6fff4 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -18,6 +18,7 @@ #include "lldb/Utility/IOObject.h" #include "lldb/Utility/Status.h" #include "lldb/lldb-forward.h" +#include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" @@ -25,8 +26,13 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" +#include +#include +#include #include #include +#include +#include #include #include @@ -50,17 +56,70 @@ class TransportUnhandledContentsError std::string m_unhandled_contents; }; +class InvalidParams : public llvm::ErrorInfo { +public: + static char ID; + + explicit InvalidParams(std::string method, std::string context) + : m_method(std::move(method)), m_context(std::move(context)) {} + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + +private: + std::string m_method; + std::string m_context; +}; + +// Value for tracking functions that have a void param or result. +using VoidT = std::monostate; + +template using Callback = llvm::unique_function; + +template +using Reply = typename std::conditional< + std::is_same_v == true, llvm::unique_function, + llvm::unique_function)>>::type; + +template +using OutgoingRequest = typename std::conditional< + std::is_same_v == true, + llvm::unique_function)>, + llvm::unique_function)>>::type; + +template +using OutgoingEvent = typename std::conditional< + std::is_same_v == true, llvm::unique_function, + llvm::unique_function>::type; + +template +Req make_request(Id id, llvm::StringRef method, + std::optional params = std::nullopt); +template +Resp make_response(const Req &req, llvm::Error error); +template +Resp make_response(const Req &req, llvm::json::Value result); +template +Evt make_event(llvm::StringRef method, + std::optional params = std::nullopt); +template +llvm::Expected get_result(const Resp &resp); +template Id get_id(const T &); +template llvm::StringRef get_method(const T &); +template llvm::json::Value get_params(const T &); + /// A transport is responsible for maintaining the connection to a client /// application, and reading/writing structured messages to it. /// /// Transports have limited thread safety requirements: /// - Messages will not be sent concurrently. /// - Messages MAY be sent while Run() is reading, or its callback is active. -template class Transport { +template +class JSONTransport { public: using Message = std::variant; - virtual ~Transport() = default; + virtual ~JSONTransport() = default; /// Sends an event, a message that does not require a response. virtual llvm::Error Send(const Evt &) = 0; @@ -90,8 +149,6 @@ template class Transport { virtual void OnClosed() = 0; }; - using MessageHandlerSP = std::shared_ptr; - /// RegisterMessageHandler registers the Transport with the given MainLoop and /// handles any incoming messages using the given MessageHandler. /// @@ -105,14 +162,294 @@ template class Transport { Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); } virtual void Log(llvm::StringRef message) = 0; + + /// Function object to reply to a call. + /// Each instance must be called exactly once, otherwise: + /// - the bug is logged, and (in debug mode) an assert will fire + /// - if there was no reply, an error reply is sent + /// - if there were multiple replies, only the first is sent + class ReplyOnce { + std::atomic replied = {false}; + const Req req; + JSONTransport *transport; // Null when moved-from. + JSONTransport::MessageHandler *handler; // Null when moved-from. + + public: + ReplyOnce(const Req req, JSONTransport *transport, + JSONTransport::MessageHandler *handler) + : req(req), transport(transport), handler(handler) { + assert(handler); + } + ReplyOnce(ReplyOnce &&other) + : replied(other.replied.load()), req(other.req), + transport(other.transport), handler(other.handler) { + other.transport = nullptr; + other.handler = nullptr; + } + ReplyOnce &operator=(ReplyOnce &&) = delete; + ReplyOnce(const ReplyOnce &) = delete; + ReplyOnce &operator=(const ReplyOnce &) = delete; + + ~ReplyOnce() { + if (transport && handler && !replied) { + assert(false && "must reply to all calls!"); + (*this)(make_response( + req, llvm::createStringError("failed to reply"))); + } + } + + void operator()(const Resp &resp) { + assert(transport && handler && "moved-from!"); + if (replied.exchange(true)) { + assert(false && "must reply to each call only once!"); + return; + } + + if (llvm::Error error = transport->Send(resp)) + handler->OnError(std::move(error)); + } + }; + +public: + class Binder; + using BinderUP = std::unique_ptr; + + /// Binder collects a table of functions that handle calls. + /// + /// The wrapper takes care of parsing/serializing responses. + class Binder : public JSONTransport::MessageHandler { + public: + explicit Binder(JSONTransport &transport) + : m_transport(transport), m_seq(0) {} + + Binder(const Binder &) = delete; + Binder &operator=(const Binder &) = delete; + + /// Bind a handler on transport disconnect. + template + void disconnected(Fn &&fn, Args &&...args) { + m_disconnect_handler = [&, args...]() mutable { + std::invoke(std::forward(fn), std::forward(args)...); + }; + } + + /// Bind a handler on error when communicating with the transport. + template + void error(Fn &&fn, Args &&...args) { + m_error_handler = [&, args...](llvm::Error error) mutable { + std::invoke(std::forward(fn), std::forward(args)..., + std::move(error)); + }; + } + + template + static llvm::Expected parse(const llvm::json::Value &raw, + llvm::StringRef method) { + T result; + llvm::json::Path::Root root; + if (!fromJSON(raw, result, root)) { + // Dump the relevant parts of the broken message. + std::string context; + llvm::raw_string_ostream OS(context); + root.printErrorContext(raw, OS); + return llvm::make_error(method.str(), context); + } + return std::move(result); + } + + /// Bind a handler for a request. + /// e.g. `bind("peek", &ThisModule::peek, this, std::placeholders::_1);`. + /// Handler should be e.g. `Expected peek(const PeekParams&);` + /// PeekParams must be JSON parsable and PeekResult must be serializable. + template + void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { + assert(m_request_handlers.find(method) == m_request_handlers.end() && + "request already bound"); + if constexpr (std::is_void_v || std::is_same_v) { + m_request_handlers[method] = + [fn, + args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Expected result = std::invoke( + std::forward(fn), std::forward(args)...); + if (!result) + return reply(make_response(req, result.takeError())); + reply(make_response(req, toJSON(*result))); + }; + } else { + m_request_handlers[method] = + [method, fn, + args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Expected params = + parse(get_params(req), method); + if (!params) + return reply(make_response(req, params.takeError())); + + llvm::Expected result = std::invoke( + std::forward(fn), std::forward(args)..., *params); + if (!result) + return reply(make_response(req, result.takeError())); + + reply(make_response(req, toJSON(*result))); + }; + } + } + + /// Bind a handler for a event. + /// e.g. `bind("peek", &ThisModule::peek, this);` + /// Handler should be e.g. `void peek(const PeekParams&);` + /// PeekParams must be JSON parsable. + template + void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { + assert(m_event_handlers.find(method) == m_event_handlers.end() && + "event already bound"); + if constexpr (std::is_void_v || std::is_same_v) { + m_event_handlers[method] = [fn, args...](const Evt &) mutable { + std::invoke(std::forward(fn), std::forward(args)...); + }; + } else { + m_event_handlers[method] = [this, method, fn, + args...](const Evt &evt) mutable { + llvm::Expected params = + parse(get_params(evt), method); + if (!params) + return OnError(params.takeError()); + std::invoke(std::forward(fn), std::forward(args)..., + *params); + }; + } + } + + /// Bind a function object to be used for outgoing requests. + /// e.g. `OutgoingRequest Edit = bind("edit");` + /// Params must be JSON-serializable, Result must be parsable. + template + OutgoingRequest bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v || std::is_same_v) { + return [this, method](Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = make_request(id, method, std::nullopt); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected result = get_result(resp); + if (!result) + return fn(result.takeError()); + fn(parse(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params ¶ms, Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = + make_request(id, method, llvm::json::Value(params)); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected result = get_result(resp); + if (llvm::Error err = result.takeError()) + return fn(std::move(err)); + fn(parse(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } + } + + /// Bind a function object to be used for outgoing events. + /// e.g. `OutgoingEvent Log = bind("log");` + /// LogParams must be JSON-serializable. + template + OutgoingEvent bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v || std::is_same_v) { + return [this, method]() { + if (llvm::Error error = + m_transport.Send(make_event(method, std::nullopt))) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params ¶ms) { + if (llvm::Error error = + m_transport.Send(make_event(method, toJSON(params)))) + OnError(std::move(error)); + }; + } + } + + void Received(const Evt &evt) override { + std::scoped_lock guard(m_mutex); + auto it = m_event_handlers.find(get_method(evt)); + if (it == m_event_handlers.end()) { + OnError(llvm::createStringError( + llvm::formatv("no handled for event {0}", toJSON(evt)))); + return; + } + it->second(evt); + } + + void Received(const Req &req) override { + ReplyOnce reply(req, &m_transport, this); + + std::scoped_lock guard(m_mutex); + auto it = m_request_handlers.find(get_method(req)); + if (it == m_request_handlers.end()) { + reply(make_response( + req, llvm::createStringError("method not found"))); + return; + } + + it->second(req, std::move(reply)); + } + + void Received(const Resp &resp) override { + std::scoped_lock guard(m_mutex); + auto it = m_pending_responses.find(get_id(resp)); + if (it == m_pending_responses.end()) { + OnError(llvm::createStringError( + llvm::formatv("no pending request for {0}", toJSON(resp)))); + return; + } + + it->second(resp); + m_pending_responses.erase(it); + } + + void OnError(llvm::Error err) override { + std::scoped_lock guard(m_mutex); + if (m_error_handler) + m_error_handler(std::move(err)); + } + + void OnClosed() override { + std::scoped_lock guard(m_mutex); + if (m_disconnect_handler) + m_disconnect_handler(); + } + + private: + std::recursive_mutex m_mutex; + JSONTransport &m_transport; + Id m_seq; + std::map> m_pending_responses; + llvm::StringMap)>> + m_request_handlers; + llvm::StringMap> m_event_handlers; + Callback m_disconnect_handler; + Callback m_error_handler; + }; }; /// An IOTransport sends and receives messages using an IOObject. -template -class IOTransport : public Transport { +template +class IOTransport : public JSONTransport { public: - using Transport::Transport; - using MessageHandler = typename Transport::MessageHandler; + using Message = typename JSONTransport::Message; + using MessageHandler = + typename JSONTransport::MessageHandler; IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} @@ -139,7 +476,7 @@ class IOTransport : public Transport { /// detail. static constexpr size_t kReadBufferSize = 1024; - // FIXME: Write should be protected. +protected: llvm::Error Write(const llvm::json::Value &message) { this->Logv("<-- {0}", message); std::string output = Encode(message); @@ -147,7 +484,6 @@ class IOTransport : public Transport { return m_out->Write(output.data(), bytes_written).takeError(); } -protected: virtual llvm::Expected> Parse() = 0; virtual std::string Encode(const llvm::json::Value &message) = 0; @@ -174,9 +510,8 @@ class IOTransport : public Transport { } for (const std::string &raw_message : *raw_messages) { - llvm::Expected::Message> message = - llvm::json::parse::Message>( - raw_message); + llvm::Expected message = + llvm::json::parse(raw_message); if (!message) { handler.OnError(message.takeError()); return; @@ -201,10 +536,10 @@ class IOTransport : public Transport { }; /// A transport class for JSON with a HTTP header. -template -class HTTPDelimitedJSONTransport : public IOTransport { +template +class HTTPDelimitedJSONTransport : public IOTransport { public: - using IOTransport::IOTransport; + using IOTransport::IOTransport; protected: /// Encodes messages based on @@ -230,8 +565,8 @@ class HTTPDelimitedJSONTransport : public IOTransport { for (const llvm::StringRef &header : llvm::split(headers, kHeaderSeparator)) { auto [key, value] = header.split(kHeaderFieldSeparator); - // 'Content-Length' is the only meaningful key at the moment. Others are - // ignored. + // 'Content-Length' is the only meaningful key at the moment. Others + // are ignored. if (!key.equals_insensitive(kHeaderContentLength)) continue; @@ -268,10 +603,10 @@ class HTTPDelimitedJSONTransport : public IOTransport { }; /// A transport class for JSON RPC. -template -class JSONRPCTransport : public IOTransport { +template +class JSONRPCTransport : public IOTransport { public: - using IOTransport::IOTransport; + using IOTransport::IOTransport; protected: std::string Encode(const llvm::json::Value &message) override { diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 6e1ffcbe1f3e3..1e0816110b80a 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -14,6 +14,7 @@ #ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H #define LLDB_PROTOCOL_MCP_PROTOCOL_H +#include "llvm/ADT/StringRef.h" #include "llvm/Support/JSON.h" #include #include @@ -324,4 +325,11 @@ bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path); } // namespace lldb_protocol::mcp +namespace llvm::json { +inline Value toJSON(const lldb_protocol::mcp::Void &) { return Object(); } +inline bool fromJSON(const Value &, lldb_protocol::mcp::Void &, Path) { + return true; +} +} // namespace llvm::json + #endif diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 970980d075ea6..94dbba8d5888e 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -9,7 +9,6 @@ #ifndef LLDB_PROTOCOL_MCP_SERVER_H #define LLDB_PROTOCOL_MCP_SERVER_H -#include "lldb/Host/JSONTransport.h" #include "lldb/Host/MainLoop.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" @@ -19,75 +18,67 @@ #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" #include "llvm/Support/Signals.h" -#include #include #include #include namespace lldb_protocol::mcp { -class Server : public MCPTransport::MessageHandler { - using ClosedCallback = llvm::unique_function; + +class Server { + + using MCPTransportUP = std::unique_ptr; + + using ReadHandleUP = lldb_private::MainLoop::ReadHandleUP; public: - Server(std::string name, std::string version, MCPTransport &client, - LogCallback log_callback = {}, ClosedCallback closed_callback = {}); + Server(std::string name, std::string version, LogCallback log_callback = {}); ~Server() = default; - using NotificationHandler = std::function; - void AddTool(std::unique_ptr tool); void AddResourceProvider(std::unique_ptr resource_provider); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); - -protected: - ServerCapabilities GetCapabilities(); - using RequestHandler = - std::function(const Request &)>; + llvm::Error Accept(lldb_private::MainLoop &, MCPTransportUP); - void AddRequestHandlers(); - - void AddRequestHandler(llvm::StringRef method, RequestHandler handler); - - llvm::Expected> HandleData(llvm::StringRef data); - - llvm::Expected Handle(const Request &request); - void Handle(const Notification ¬ification); +protected: + MCPTransport::BinderUP Bind(MCPTransport &); - llvm::Expected InitializeHandler(const Request &); + ServerCapabilities GetCapabilities(); - llvm::Expected ToolsListHandler(const Request &); - llvm::Expected ToolsCallHandler(const Request &); + llvm::Expected InitializeHandler(const InitializeParams &); - llvm::Expected ResourcesListHandler(const Request &); - llvm::Expected ResourcesReadHandler(const Request &); + llvm::Expected ToolsListHandler(); + llvm::Expected ToolsCallHandler(const CallToolParams &); - void Received(const Request &) override; - void Received(const Response &) override; - void Received(const Notification &) override; - void OnError(llvm::Error) override; - void OnClosed() override; + llvm::Expected ResourcesListHandler(); + llvm::Expected + ResourcesReadHandler(const ReadResourceParams &); -protected: - void Log(llvm::StringRef); + template inline auto Logv(const char *Fmt, Ts &&...Vals) { + Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); + } + void Log(llvm::StringRef message) { + if (m_log_callback) + m_log_callback(message); + } private: const std::string m_name; const std::string m_version; - MCPTransport &m_client; LogCallback m_log_callback; - ClosedCallback m_closed_callback; + struct Client { + ReadHandleUP handle; + MCPTransportUP transport; + MCPTransport::BinderUP binder; + }; + std::map m_instances; llvm::StringMap> m_tools; std::vector> m_resource_providers; - - llvm::StringMap m_request_handlers; - llvm::StringMap m_notification_handlers; }; class ServerInfoHandle; @@ -121,7 +112,7 @@ class ServerInfoHandle { ServerInfoHandle &operator=(const ServerInfoHandle &) = delete; /// @} - /// Remove the file. + /// Remove the file on disk, if one is tracked. void Remove(); private: diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h index 47c2ccfc44dfe..55b2e8fa0a7f2 100644 --- a/lldb/include/lldb/Protocol/MCP/Transport.h +++ b/lldb/include/lldb/Protocol/MCP/Transport.h @@ -10,22 +10,95 @@ #define LLDB_PROTOCOL_MCP_TRANSPORT_H #include "lldb/Host/JSONTransport.h" +#include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" + +namespace lldb_private { +/// Specializations of the JSONTransport protocol functions for MCP. +/// @{ +template <> +inline lldb_protocol::mcp::Request +make_request(int64_t id, llvm::StringRef method, + std::optional params) { + return lldb_protocol::mcp::Request{id, method.str(), params}; +} +template <> +inline lldb_protocol::mcp::Response +make_response(const lldb_protocol::mcp::Request &req, llvm::Error error) { + lldb_protocol::mcp::Error protocol_error; + llvm::handleAllErrors( + std::move(error), + [&](const lldb_protocol::mcp::MCPError &err) { + protocol_error = err.toProtocolError(); + }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = lldb_protocol::mcp::MCPError::kInternalError; + protocol_error.message = err.message(); + }); + + return lldb_protocol::mcp::Response{req.id, std::move(protocol_error)}; +} +template <> +inline lldb_protocol::mcp::Response +make_response(const lldb_protocol::mcp::Request &req, + llvm::json::Value result) { + return lldb_protocol::mcp::Response{req.id, std::move(result)}; +} +template <> +inline lldb_protocol::mcp::Notification +make_event(llvm::StringRef method, std::optional params) { + return lldb_protocol::mcp::Notification{method.str(), params}; +} +template <> +inline llvm::Expected +get_result(const lldb_protocol::mcp::Response &resp) { + if (const lldb_protocol::mcp::Error *error = + std::get_if(&resp.result)) + return llvm::make_error(error->message, + error->code); + return std::get(resp.result); +} +template <> inline int64_t get_id(const lldb_protocol::mcp::Response &resp) { + return std::get(resp.id); +} +template <> +inline llvm::StringRef get_method(const lldb_protocol::mcp::Request &req) { + return req.method; +} +template <> +inline llvm::StringRef get_method(const lldb_protocol::mcp::Notification &evt) { + return evt.method; +} +template <> +inline llvm::json::Value get_params(const lldb_protocol::mcp::Request &req) { + return req.params; +} +template <> +inline llvm::json::Value +get_params(const lldb_protocol::mcp::Notification &evt) { + return evt.params; +} +/// @} + +} // end namespace lldb_private namespace lldb_protocol::mcp { /// Generic transport that uses the MCP protocol. -using MCPTransport = lldb_private::Transport; +using MCPTransport = + lldb_private::JSONTransport; /// Generic logging callback, to allow the MCP server / client / transport layer /// to be independent of the lldb log implementation. using LogCallback = llvm::unique_function; class Transport final - : public lldb_private::JSONRPCTransport { + : public lldb_private::JSONRPCTransport { public: Transport(lldb::IOObjectSP in, lldb::IOObjectSP out, LogCallback log_callback = {}); diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index c4b42eafc85d3..f809ef478c8f7 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -30,3 +30,13 @@ void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const { std::error_code TransportUnhandledContentsError::convertToErrorCode() const { return std::make_error_code(std::errc::bad_message); } + +char InvalidParams::ID; + +void InvalidParams::log(llvm::raw_ostream &OS) const { + OS << "invalid parameters for method '" << m_method << "': '" << m_context + << "'"; +} +std::error_code InvalidParams::convertToErrorCode() const { + return std::make_error_code(std::errc::invalid_argument); +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index d7293fc28c524..7500bb1515e01 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -52,11 +52,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { } void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { - server.AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); server.AddTool( std::make_unique("command", "Run an lldb command.")); server.AddTool(std::make_unique( @@ -74,26 +69,9 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { io_sp, io_sp, [client_name](llvm::StringRef message) { LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message); }); - MCPTransport *transport_ptr = transport_up.get(); - auto instance_up = std::make_unique( - std::string(kName), std::string(kVersion), *transport_up, - /*log_callback=*/ - [client_name](llvm::StringRef message) { - LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name, - message); - }, - /*closed_callback=*/ - [this, transport_ptr]() { m_instances.erase(transport_ptr); }); - Extend(*instance_up); - llvm::Expected handle = - transport_up->RegisterMessageHandler(m_loop, *instance_up); - if (!handle) { - LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}"); - return; - } - m_instances[transport_ptr] = - std::make_tuple( - std::move(instance_up), std::move(*handle), std::move(transport_up)); + + if (auto error = m_server->Accept(m_loop, std::move(transport_up))) + LLDB_LOG_ERROR(log, std::move(error), "{0}:"); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { @@ -124,13 +102,20 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { llvm::join(m_listener->GetListeningConnectionURI(), ", "); ServerInfo info{listening_uris[0]}; - llvm::Expected handle = ServerInfo::Write(info); - if (!handle) - return handle.takeError(); + llvm::Expected server_info_handle = ServerInfo::Write(info); + if (!server_info_handle) + return server_info_handle.takeError(); + + m_client_count = 0; + m_server = std::make_unique( + std::string(kName), std::string(kVersion), [](StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "MCP Server: {0}", message); + }); + Extend(*m_server); m_running = true; - m_server_info_handle = std::move(*handle); - m_listen_handlers = std::move(*handles); + m_server_info_handle = std::move(*server_info_handle); + m_accept_handles = std::move(*handles); m_loop_thread = std::thread([=] { llvm::set_thread_name("protocol-server.mcp"); m_loop.Run(); @@ -155,9 +140,10 @@ llvm::Error ProtocolServerMCP::Stop() { if (m_loop_thread.joinable()) m_loop_thread.join(); + m_accept_handles.clear(); + + m_server.reset(nullptr); m_server_info_handle.Remove(); - m_listen_handlers.clear(); - m_instances.clear(); return llvm::Error::success(); } diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index b325a3681bccb..e0f2a6ccea1f5 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -23,16 +23,17 @@ namespace lldb_private::mcp { class ProtocolServerMCP : public ProtocolServer { - using ReadHandleUP = MainLoopBase::ReadHandleUP; - using TransportUP = std::unique_ptr; + using ServerUP = std::unique_ptr; + using ReadHandleUP = MainLoop::ReadHandleUP; + public: ProtocolServerMCP(); - virtual ~ProtocolServerMCP() override; + ~ProtocolServerMCP() override; - virtual llvm::Error Start(ProtocolServer::Connection connection) override; - virtual llvm::Error Stop() override; + llvm::Error Start(ProtocolServer::Connection connection) override; + llvm::Error Stop() override; static void Initialize(); static void Terminate(); @@ -56,19 +57,18 @@ class ProtocolServerMCP : public ProtocolServer { bool m_running = false; - lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; lldb_private::MainLoop m_loop; std::thread m_loop_thread; std::mutex m_mutex; size_t m_client_count = 0; std::unique_ptr m_listener; + std::vector m_accept_handles; - std::vector m_listen_handlers; - std::map> - m_instances; + ServerUP m_server; + lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; }; + } // namespace lldb_private::mcp #endif diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 19030a3a4e5d6..79a2925c6c320 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -12,6 +12,7 @@ #include "lldb/Host/HostInfo.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Transport.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" @@ -108,48 +109,9 @@ Expected> ServerInfo::Load() { return infos; } -Server::Server(std::string name, std::string version, MCPTransport &client, - LogCallback log_callback, ClosedCallback closed_callback) - : m_name(std::move(name)), m_version(std::move(version)), m_client(client), - m_log_callback(std::move(log_callback)), - m_closed_callback(std::move(closed_callback)) { - AddRequestHandlers(); -} - -void Server::AddRequestHandlers() { - AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this, - std::placeholders::_1)); - AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler, - this, std::placeholders::_1)); - AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler, - this, std::placeholders::_1)); -} - -llvm::Expected Server::Handle(const Request &request) { - auto it = m_request_handlers.find(request.method); - if (it != m_request_handlers.end()) { - llvm::Expected response = it->second(request); - if (!response) - return response; - response->id = request.id; - return *response; - } - - return llvm::make_error( - llvm::formatv("no handler for request: {0}", request.method).str()); -} - -void Server::Handle(const Notification ¬ification) { - auto it = m_notification_handlers.find(notification.method); - if (it != m_notification_handlers.end()) { - it->second(notification); - return; - } -} +Server::Server(std::string name, std::string version, LogCallback log_callback) + : m_name(std::move(name)), m_version(std::move(version)), + m_log_callback(std::move(log_callback)) {} void Server::AddTool(std::unique_ptr tool) { if (!tool) @@ -164,48 +126,62 @@ void Server::AddResourceProvider( m_resource_providers.push_back(std::move(resource_provider)); } -void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { - m_request_handlers[method] = std::move(handler); -} - -void Server::AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler) { - m_notification_handlers[method] = std::move(handler); -} - -llvm::Expected Server::InitializeHandler(const Request &request) { - Response response; +MCPTransport::BinderUP Server::Bind(MCPTransport &transport) { + MCPTransport::BinderUP binder = + std::make_unique(transport); + binder->bind( + "initialize", &Server::InitializeHandler, this); + binder->bind("tools/list", &Server::ToolsListHandler, + this); + binder->bind("tools/call", + &Server::ToolsCallHandler, this); + binder->bind("resources/list", + &Server::ResourcesListHandler, this); + binder->bind( + "resources/read", &Server::ResourcesReadHandler, this); + binder->bind("notifications/initialized", + [this]() { Log("MCP initialization complete"); }); + return binder; +} + +llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { + MCPTransport::BinderUP binder = Bind(*transport); + MCPTransport *transport_ptr = transport.get(); + binder->disconnected([this, transport_ptr]() { + assert(m_instances.find(transport_ptr) != m_instances.end() && + "Client not found in m_instances"); + m_instances.erase(transport_ptr); + }); + + auto handle = transport->RegisterMessageHandler(loop, *binder); + if (!handle) + return handle.takeError(); + + m_instances[transport_ptr] = + Client{std::move(*handle), std::move(transport), std::move(binder)}; + return llvm::Error::success(); +} + +Expected +Server::InitializeHandler(const InitializeParams &request) { InitializeResult result; result.protocolVersion = mcp::kProtocolVersion; result.capabilities = GetCapabilities(); result.serverInfo.name = m_name; result.serverInfo.version = m_version; - response.result = std::move(result); - return response; + return result; } -llvm::Expected Server::ToolsListHandler(const Request &request) { - Response response; - +llvm::Expected Server::ToolsListHandler() { ListToolsResult result; for (const auto &tool : m_tools) result.tools.emplace_back(tool.second->GetDefinition()); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected Server::ToolsCallHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no tool parameters"); - CallToolParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - +llvm::Expected +Server::ToolsCallHandler(const CallToolParams ¶ms) { llvm::StringRef tool_name = params.name; if (tool_name.empty()) return llvm::createStringError("no tool name"); @@ -222,113 +198,50 @@ llvm::Expected Server::ToolsCallHandler(const Request &request) { if (!text_result) return text_result.takeError(); - response.result = toJSON(*text_result); - - return response; + return text_result; } -llvm::Expected Server::ResourcesListHandler(const Request &request) { - Response response; - +llvm::Expected Server::ResourcesListHandler() { ListResourcesResult result; for (std::unique_ptr &resource_provider_up : m_resource_providers) for (const Resource &resource : resource_provider_up->GetResources()) result.resources.push_back(resource); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected Server::ResourcesReadHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no resource parameters"); - - ReadResourceParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - - llvm::StringRef uri_str = params.uri; +Expected +Server::ResourcesReadHandler(const ReadResourceParams ¶ms) { + StringRef uri_str = params.uri; if (uri_str.empty()) - return llvm::createStringError("no resource uri"); + return createStringError("no resource uri"); for (std::unique_ptr &resource_provider_up : m_resource_providers) { - llvm::Expected result = + Expected result = resource_provider_up->ReadResource(uri_str); if (result.errorIsA()) { - llvm::consumeError(result.takeError()); + consumeError(result.takeError()); continue; } if (!result) return result.takeError(); - Response response; - response.result = std::move(*result); - return response; + return *result; } return make_error( - llvm::formatv("no resource handler for uri: {0}", uri_str).str(), + formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } ServerCapabilities Server::GetCapabilities() { lldb_protocol::mcp::ServerCapabilities capabilities; capabilities.supportsToolsList = true; + capabilities.supportsResourcesList = true; // FIXME: Support sending notifications when a debugger/target are // added/removed. - capabilities.supportsResourcesList = false; + capabilities.supportsResourcesSubscribe = false; return capabilities; } - -void Server::Log(llvm::StringRef message) { - if (m_log_callback) - m_log_callback(message); -} - -void Server::Received(const Request &request) { - auto SendResponse = [this](const Response &response) { - if (llvm::Error error = m_client.Send(response)) - Log(llvm::toString(std::move(error))); - }; - - llvm::Expected response = Handle(request); - if (response) - return SendResponse(*response); - - lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request.id; - error_response.result = std::move(protocol_error); - SendResponse(error_response); -} - -void Server::Received(const Response &response) { - Log("unexpected MCP message: response"); -} - -void Server::Received(const Notification ¬ification) { - Handle(notification); -} - -void Server::OnError(llvm::Error error) { - Log(llvm::toString(std::move(error))); -} - -void Server::OnClosed() { - Log("EOF"); - if (m_closed_callback) - m_closed_callback(); -} diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index 71681fd4b51ed..0c921e5b72d74 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -79,10 +79,10 @@ enum DAPBroadcasterBits { enum class ReplMode { Variable = 0, Command, Auto }; using DAPTransport = - lldb_private::Transport; + lldb_private::JSONTransport; -struct DAP final : private DAPTransport::MessageHandler { +struct DAP final : public DAPTransport::MessageHandler { /// Path to the lldb-dap binary itself. static llvm::StringRef debug_adapter_path; diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h index 0a9ef538a7398..92e41b1dbf595 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h @@ -30,6 +30,8 @@ namespace lldb_dap::protocol { // MARK: Base Protocol +using Id = int64_t; + /// A client or debug adapter initiated request. struct Request { /// Sequence number of the message (also known as message ID). The `seq` for @@ -39,7 +41,7 @@ struct Request { /// associate requests with their corresponding responses. For protocol /// messages of type `request` the sequence number can be used to cancel the /// request. - int64_t seq; + Id seq; /// The command to execute. std::string command; @@ -76,7 +78,7 @@ enum ResponseMessage : unsigned { /// Response for a request. struct Response { /// Sequence number of the corresponding request. - int64_t request_seq; + Id request_seq; /// The command requested. std::string command; diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index 4a9dd76c2303e..6462c155eb9af 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -24,9 +24,9 @@ namespace lldb_dap { /// A transport class that performs the Debug Adapter Protocol communication /// with the client. -class Transport final - : public lldb_private::HTTPDelimitedJSONTransport< - protocol::Request, protocol::Response, protocol::Event> { +class Transport final : public lldb_private::HTTPDelimitedJSONTransport< + protocol::Id, protocol::Request, protocol::Response, + protocol::Event> { public: Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output); diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index 2090fe6896d6b..4fd6cd546e6fa 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -9,13 +9,10 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" #include "TestBase.h" -#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include -using namespace llvm; -using namespace lldb; using namespace lldb_dap; using namespace lldb_dap_tests; using namespace lldb_dap::protocol; @@ -24,18 +21,7 @@ using namespace testing; class DAPTest : public TransportBase {}; TEST_F(DAPTest, SendProtocolMessages) { - DAP dap{ - /*log=*/nullptr, - /*default_repl_mode=*/ReplMode::Auto, - /*pre_init_commands=*/{}, - /*no_lldbinit=*/false, - /*client_name=*/"test_client", - /*transport=*/*transport, - /*loop=*/loop, - }; - dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); - loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); - EXPECT_CALL(client, Received(IsEvent("my-event", std::nullopt))); - ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded()); + dap->Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); + EXPECT_CALL(client, Received(IsEvent("my-event"))); + Run(); } diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp index c6ff1f90b01d5..88d6e9a69eca3 100644 --- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp +++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp @@ -31,7 +31,7 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) { DisconnectRequestHandler handler(*dap); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); EXPECT_CALL(client, Received(IsEvent("terminated", _))); - RunOnce(); + Run(); } TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { @@ -53,5 +53,5 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { EXPECT_CALL(client, Received(Output("(lldb) script print(2)\n"))); EXPECT_CALL(client, Received(Output("Running terminateCommands:\n"))); EXPECT_CALL(client, Received(IsEvent("terminated", _))); - RunOnce(); + Run(); } diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index ba7baf2103799..3721e09d8b699 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -32,23 +32,9 @@ using lldb_private::FileSystem; using lldb_private::MainLoop; using lldb_private::Pipe; -Expected -TestTransport::RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) { - Expected dummy_file = FileSystem::Instance().Open( - FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite); - if (!dummy_file) - return dummy_file.takeError(); - m_dummy_file = std::move(*dummy_file); - lldb_private::Status status; - auto handle = loop.RegisterReadObject( - m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); - if (status.Fail()) - return status.takeError(); - return handle; -} +void TransportBase::SetUp() { + std::tie(to_client, to_server) = TestDAPTransport::createPair(); -void DAPTestBase::SetUp() { - TransportBase::SetUp(); std::error_code EC; log = std::make_unique("-", EC); dap = std::make_unique( @@ -57,16 +43,30 @@ void DAPTestBase::SetUp() { /*pre_init_commands=*/std::vector(), /*no_lldbinit=*/false, /*client_name=*/"test_client", - /*transport=*/*transport, /*loop=*/loop); + /*transport=*/*to_client, /*loop=*/loop); + + auto server_handle = to_server->RegisterMessageHandler(loop, *dap.get()); + EXPECT_THAT_EXPECTED(server_handle, Succeeded()); + handles[0] = std::move(*server_handle); + + auto client_handle = to_client->RegisterMessageHandler(loop, client); + EXPECT_THAT_EXPECTED(client_handle, Succeeded()); + handles[1] = std::move(*client_handle); } +void TransportBase::Run() { + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); +} + +void DAPTestBase::SetUp() { TransportBase::SetUp(); } + void DAPTestBase::TearDown() { - if (core) { + if (core) ASSERT_THAT_ERROR(core->discard(), Succeeded()); - } - if (binary) { + if (binary) ASSERT_THAT_ERROR(binary->discard(), Succeeded()); - } } void DAPTestBase::SetUpTestSuite() { diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index c19eead4e37e7..aaeab3b3d2cd9 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "DAP.h" +#include "DAPLog.h" #include "Protocol/ProtocolBase.h" #include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" @@ -14,66 +15,41 @@ #include "lldb/Host/HostInfo.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" -#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include +#include + +/// Helpers for gtest printing. +namespace lldb_dap::protocol { + +inline void PrintTo(const Request &req, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(req)).str(); +} + +inline void PrintTo(const Response &resp, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(resp)).str(); +} + +inline void PrintTo(const Event &evt, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(evt)).str(); +} + +inline void PrintTo(const Message &message, std::ostream *os) { + return std::visit([os](auto &&message) { return PrintTo(message, os); }, + message); +} + +} // namespace lldb_dap::protocol namespace lldb_dap_tests { -class TestTransport final - : public lldb_private::Transport { -public: - using Message = lldb_private::Transport::Message; - - TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler) - : m_loop(loop), m_handler(handler) {} - - llvm::Error Send(const lldb_dap::protocol::Event &e) override { - m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) { - this->m_handler.Received(e); - }); - return llvm::Error::success(); - } - - llvm::Error Send(const lldb_dap::protocol::Request &r) override { - m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { - this->m_handler.Received(r); - }); - return llvm::Error::success(); - } - - llvm::Error Send(const lldb_dap::protocol::Response &r) override { - m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { - this->m_handler.Received(r); - }); - return llvm::Error::success(); - } - - llvm::Expected - RegisterMessageHandler(lldb_private::MainLoop &loop, - MessageHandler &handler) override; - - void Log(llvm::StringRef message) override { - log_messages.emplace_back(message); - } - - std::vector log_messages; - -private: - lldb_private::MainLoop &m_loop; - MessageHandler &m_handler; - lldb::FileSP m_dummy_file; -}; +using TestDAPTransport = + TestTransport; /// A base class for tests that need transport configured for communicating DAP /// messages. @@ -82,22 +58,38 @@ class TransportBase : public testing::Test { lldb_private::SubsystemRAII subsystems; lldb_private::MainLoop loop; - std::unique_ptr transport; - MockMessageHandler + lldb_private::MainLoop::ReadHandleUP handles[2]; + + std::unique_ptr log; + + std::unique_ptr to_client; + MockMessageHandler client; - void SetUp() override { - transport = std::make_unique(loop, client); - } + std::unique_ptr to_server; + std::unique_ptr dap; + + void SetUp() override; + + void Run(); }; /// A matcher for a DAP event. -template +template inline testing::Matcher -IsEvent(const M1 &m1, const M2 &m2) { - return testing::AllOf(testing::Field(&lldb_dap::protocol::Event::event, m1), - testing::Field(&lldb_dap::protocol::Event::body, m2)); +IsEvent(const EventMatcher &event_matcher, const BodyMatcher &body_matcher) { + return testing::AllOf( + testing::Field(&lldb_dap::protocol::Event::event, event_matcher), + testing::Field(&lldb_dap::protocol::Event::body, body_matcher)); +} + +template +inline testing::Matcher +IsEvent(const EventMatcher &event_matcher) { + return testing::AllOf( + testing::Field(&lldb_dap::protocol::Event::event, event_matcher), + testing::Field(&lldb_dap::protocol::Event::body, std::nullopt)); } /// Matches an "output" event. @@ -110,8 +102,6 @@ inline auto Output(llvm::StringRef o, llvm::StringRef cat = "console") { /// A base class for tests that interact with a `lldb_dap::DAP` instance. class DAPTestBase : public TransportBase { protected: - std::unique_ptr log; - std::unique_ptr dap; std::optional core; std::optional binary; @@ -126,12 +116,6 @@ class DAPTestBase : public TransportBase { bool GetDebuggerSupportsTarget(llvm::StringRef platform); void CreateDebugger(); void LoadCore(); - - void RunOnce() { - loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); - ASSERT_THAT_ERROR(dap->Loop(), llvm::Succeeded()); - } }; } // namespace lldb_dap_tests diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 3a36bf21f07ff..0228e13b61909 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -9,6 +9,7 @@ #include "lldb/Host/JSONTransport.h" #include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/Host/PipeTestUtilities.h" +#include "TestingSupport/SubsystemRAII.h" #include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" @@ -25,6 +26,7 @@ #include #include #include +#include #include using namespace llvm; @@ -32,20 +34,35 @@ using namespace lldb_private; using testing::_; using testing::HasSubstr; using testing::InSequence; +using testing::Ref; + +namespace llvm::json { +static bool fromJSON(const Value &V, Value &T, Path P) { + T = V; + return true; +} +} // namespace llvm::json namespace { namespace test_protocol { struct Req { + int id = 0; std::string name; + std::optional params; }; -json::Value toJSON(const Req &T) { return json::Object{{"req", T.name}}; } +json::Value toJSON(const Req &T) { + return json::Object{{"name", T.name}, {"id", T.id}, {"params", T.params}}; +} bool fromJSON(const json::Value &V, Req &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("req", T.name); + return O && O.map("name", T.name) && O.map("id", T.id) && + O.map("params", T.params); +} +bool operator==(const Req &a, const Req &b) { + return a.name == b.name && a.id == b.id && a.params == b.params; } -bool operator==(const Req &a, const Req &b) { return a.name == b.name; } inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Req &V) { OS << toJSON(V); return OS; @@ -58,14 +75,19 @@ void PrintTo(const Req &message, std::ostream *os) { } struct Resp { - std::string name; + int id = 0; + std::optional result; }; -json::Value toJSON(const Resp &T) { return json::Object{{"resp", T.name}}; } +json::Value toJSON(const Resp &T) { + return json::Object{{"id", T.id}, {"result", T.result}}; +} bool fromJSON(const json::Value &V, Resp &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("resp", T.name); + return O && O.map("id", T.id) && O.map("result", T.result); +} +bool operator==(const Resp &a, const Resp &b) { + return a.id == b.id && a.result == b.result; } -bool operator==(const Resp &a, const Resp &b) { return a.name == b.name; } inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) { OS << toJSON(V); return OS; @@ -79,11 +101,14 @@ void PrintTo(const Resp &message, std::ostream *os) { struct Evt { std::string name; + std::optional params; }; -json::Value toJSON(const Evt &T) { return json::Object{{"evt", T.name}}; } +json::Value toJSON(const Evt &T) { + return json::Object{{"name", T.name}, {"params", T.params}}; +} bool fromJSON(const json::Value &V, Evt &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("evt", T.name); + return O && O.map("name", T.name) && O.map("params", T.params); } bool operator==(const Evt &a, const Evt &b) { return a.name == b.name; } inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Evt &V) { @@ -107,41 +132,64 @@ bool fromJSON(const json::Value &V, Message &msg, json::Path P) { P.report("expected object"); return false; } - if (O->get("req")) { - Req R; - if (!fromJSON(V, R, P)) + + if (O->find("id") == O->end()) { + Evt E; + if (!fromJSON(V, E, P)) return false; - msg = std::move(R); + msg = std::move(E); return true; } - if (O->get("resp")) { - Resp R; + + if (O->get("name")) { + Req R; if (!fromJSON(V, R, P)) return false; msg = std::move(R); return true; } - if (O->get("evt")) { - Evt E; - if (!fromJSON(V, E, P)) - return false; - msg = std::move(E); - return true; - } - P.report("unknown message type"); - return false; + Resp R; + if (!fromJSON(V, R, P)) + return false; + + msg = std::move(R); + return true; } -} // namespace test_protocol +struct MyFnParams { + int a = 0; + int b = 0; +}; +json::Value toJSON(const MyFnParams &T) { + return json::Object{{"a", T.a}, {"b", T.b}}; +} +bool fromJSON(const json::Value &V, MyFnParams &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("a", T.a) && O.map("b", T.b); +} -template -class JSONTransportTest : public PipePairTest { +struct MyFnResult { + int c = 0; +}; +json::Value toJSON(const MyFnResult &T) { return json::Object{{"c", T.c}}; } +bool fromJSON(const json::Value &V, MyFnResult &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("c", T.c); +} + +using Transport = TestTransport; +using MessageHandler = MockMessageHandler; + +} // namespace test_protocol +template class JSONTransportTest : public PipePairTest { protected: - MockMessageHandler message_handler; + SubsystemRAII subsystems; + + test_protocol::MessageHandler message_handler; std::unique_ptr transport; MainLoop loop; @@ -191,8 +239,8 @@ class JSONTransportTest : public PipePairTest { }; class TestHTTPDelimitedJSONTransport final - : public HTTPDelimitedJSONTransport { + : public HTTPDelimitedJSONTransport< + int, test_protocol::Req, test_protocol::Resp, test_protocol::Evt> { public: using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport; @@ -204,9 +252,7 @@ class TestHTTPDelimitedJSONTransport final }; class HTTPDelimitedJSONTransportTest - : public JSONTransportTest { + : public JSONTransportTest { public: using JSONTransportTest::JSONTransportTest; @@ -222,7 +268,7 @@ class HTTPDelimitedJSONTransportTest }; class TestJSONRPCTransport final - : public JSONRPCTransport { public: using JSONRPCTransport::JSONRPCTransport; @@ -234,9 +280,7 @@ class TestJSONRPCTransport final std::vector log_messages; }; -class JSONRPCTransportTest - : public JSONTransportTest { +class JSONRPCTransportTest : public JSONTransportTest { public: using JSONTransportTest::JSONTransportTest; @@ -248,8 +292,71 @@ class JSONRPCTransportTest } }; +class TestTransportBinder : public testing::Test { +protected: + SubsystemRAII subsystems; + + std::unique_ptr to_remote; + std::unique_ptr from_remote; + std::unique_ptr binder; + test_protocol::MessageHandler remote; + MainLoop loop; + + void SetUp() override { + std::tie(to_remote, from_remote) = test_protocol::Transport::createPair(); + binder = std::make_unique(*to_remote); + + auto binder_handle = to_remote->RegisterMessageHandler(loop, remote); + EXPECT_THAT_EXPECTED(binder_handle, Succeeded()); + + auto remote_handle = from_remote->RegisterMessageHandler(loop, *binder); + EXPECT_THAT_EXPECTED(remote_handle, Succeeded()); + } + + void Run() { + loop.AddPendingCallback([](auto &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); + } +}; + } // namespace +namespace lldb_private { +using namespace test_protocol; +template <> +inline test_protocol::Req make_request(int id, llvm::StringRef method, + std::optional params) { + return test_protocol::Req{id, method.str(), params}; +} +template <> inline Resp make_response(const Req &req, llvm::Error error) { + llvm::consumeError(std::move(error)); + return Resp{req.id, std::nullopt}; +} +template <> inline Resp make_response(const Req &req, json::Value result) { + return Resp{req.id, std::move(result)}; +} +template <> +inline Evt make_event(llvm::StringRef method, + std::optional params) { + return Evt{method.str(), params}; +} + +template <> inline llvm::Expected get_result(const Resp &resp) { + return resp.result; +} + +template <> inline int get_id(const Resp &resp) { return resp.id; } +template <> inline llvm::StringRef get_method(const Req &req) { + return req.name; +} +template <> inline llvm::StringRef get_method(const Evt &evt) { + return evt.name; +} +template <> inline json::Value get_params(const Req &req) { return req.params; } +template <> inline json::Value get_params(const Evt &evt) { return evt.params; } + +} // namespace lldb_private + // Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446. #ifndef _WIN32 using namespace test_protocol; @@ -269,35 +376,47 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { } TEST_F(HTTPDelimitedJSONTransportTest, Read) { - Write(Req{"foo"}); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + Write(Req{6, "foo", std::nullopt}); + EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { InSequence seq; - Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); - EXPECT_CALL(message_handler, Received(Req{"one"})); - EXPECT_CALL(message_handler, Received(Evt{"two"})); - EXPECT_CALL(message_handler, Received(Resp{"three"})); + Write( + Message{ + Req{6, "one", std::nullopt}, + }, + Message{ + Evt{"two", std::nullopt}, + }, + Message{ + Resp{2, std::nullopt}, + }); + EXPECT_CALL(message_handler, Received(Req{6, "one", std::nullopt})); + EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt})); + EXPECT_CALL(message_handler, Received(Resp{2, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { std::string long_str = std::string( - HTTPDelimitedJSONTransport::kReadBufferSize * 2, 'x'); - Write(Req{long_str}); - EXPECT_CALL(message_handler, Received(Req{long_str})); + HTTPDelimitedJSONTransport::kReadBufferSize * + 2, + 'x'); + Write(Req{5, long_str, std::nullopt}); + EXPECT_CALL(message_handler, Received(Req{5, long_str, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { - std::string message = Encode(Req{"foo"}); + std::string message = Encode(Req{5, "foo", std::nullopt}); auto split_at = message.size() / 2; std::string part1 = message.substr(0, split_at); std::string part2 = message.substr(split_at); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + EXPECT_CALL(message_handler, Received(Req{5, "foo", std::nullopt})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); loop.AddPendingCallback( @@ -309,12 +428,12 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { - std::string message = Encode(Req{"foo"}); + std::string message = Encode(Req{6, "foo", std::nullopt}); auto split_at = message.size() / 2; std::string part1 = message.substr(0, split_at); std::string part2 = message.substr(split_at); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); @@ -366,20 +485,21 @@ TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { } TEST_F(HTTPDelimitedJSONTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{7, "foo", std::nullopt}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{5, "bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n" - R"({"req":"foo"})" - "Content-Length: 14\r\n\r\n" - R"({"resp":"bar"})" - "Content-Length: 13\r\n\r\n" - R"({"evt":"baz"})")); + ASSERT_EQ(StringRef(buf, *bytes_read), + StringRef("Content-Length: 35\r\n\r\n" + R"({"id":7,"name":"foo","params":null})" + "Content-Length: 23\r\n\r\n" + R"({"id":5,"result":"bar"})" + "Content-Length: 28\r\n\r\n" + R"({"name":"baz","params":null})")); } TEST_F(JSONRPCTransportTest, MalformedRequests) { @@ -395,17 +515,18 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { } TEST_F(JSONRPCTransportTest, Read) { - Write(Message{Req{"foo"}}); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + Write(Message{Req{1, "foo", std::nullopt}}); + EXPECT_CALL(message_handler, Received(Req{1, "foo", std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadMultipleMessagesInSingleWrite) { InSequence seq; - Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); - EXPECT_CALL(message_handler, Received(Req{"one"})); - EXPECT_CALL(message_handler, Received(Evt{"two"})); - EXPECT_CALL(message_handler, Received(Resp{"three"})); + Write(Message{Req{1, "one", std::nullopt}}, Message{Evt{"two", std::nullopt}}, + Message{Resp{3, "three"}}); + EXPECT_CALL(message_handler, Received(Req{1, "one", std::nullopt})); + EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt})); + EXPECT_CALL(message_handler, Received(Resp{3, "three"})); ASSERT_THAT_ERROR(Run(), Succeeded()); } @@ -413,19 +534,22 @@ TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { // Use a string longer than the chunk size to ensure we split the message // across the chunk boundary. std::string long_str = - std::string(IOTransport::kReadBufferSize * 2, 'x'); - Write(Req{long_str}); - EXPECT_CALL(message_handler, Received(Req{long_str})); + std::string(IOTransport::kReadBufferSize * + 2, + 'x'); + Write(Req{42, long_str, std::nullopt}); + EXPECT_CALL(message_handler, Received(Req{42, long_str, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadPartialMessage) { - std::string message = R"({"req": "foo"})" + std::string message = R"({"id":42,"name":"foo","params":null})" "\n"; std::string part1 = message.substr(0, 7); std::string part2 = message.substr(7); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + EXPECT_CALL(message_handler, Received(Req{42, "foo", std::nullopt})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); loop.AddPendingCallback( @@ -455,20 +579,21 @@ TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { } TEST_F(JSONRPCTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{11, "foo", std::nullopt}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{14, "bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"req":"foo"})" - "\n" - R"({"resp":"bar"})" - "\n" - R"({"evt":"baz"})" - "\n")); + ASSERT_EQ(StringRef(buf, *bytes_read), + StringRef(R"({"id":11,"name":"foo","params":null})" + "\n" + R"({"id":14,"result":"bar"})" + "\n" + R"({"name":"baz","params":null})" + "\n")); } TEST_F(JSONRPCTransportTest, InvalidTransport) { @@ -477,4 +602,59 @@ TEST_F(JSONRPCTransportTest, InvalidTransport) { FailedWithMessage("IO object is not valid.")); } +// Out-bound binding request handler. +TEST_F(TestTransportBinder, OutBoundRequests) { + auto addFn = binder->bind("add"); + addFn(MyFnParams{1, 2}, [](Expected result) { + EXPECT_THAT_EXPECTED(result, Succeeded()); + EXPECT_EQ(result->c, 3); + }); + EXPECT_CALL(remote, Received(Req{1, "add", MyFnParams{1, 2}})); + // Queue a reply that will be sent during 'Run'. + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, toJSON(MyFnResult{3})}), + Succeeded()); + Run(); +} + +// In-bound binding request handler. +TEST_F(TestTransportBinder, InBoundRequests) { + binder->bind( + "add", + [](const int captured_param, + const MyFnParams ¶ms) -> Expected { + return MyFnResult{params.a + params.b + captured_param}; + }, + 2); + EXPECT_THAT_ERROR(from_remote->Send(Req{2, "add", MyFnParams{3, 4}}), + Succeeded()); + EXPECT_CALL(remote, Received(Resp{2, MyFnResult{9}})); + Run(); +} + +// Out-bound binding event handler. +TEST_F(TestTransportBinder, OutBoundEvents) { + auto emitEvent = binder->bind("evt"); + emitEvent(MyFnParams{1, 2}); + EXPECT_CALL(remote, Received(Evt{"evt", MyFnParams{1, 2}})); + Run(); +} + +// In-bound binding event handler. +TEST_F(TestTransportBinder, InBoundEvents) { + bool called = false; + binder->bind( + "evt", + [&](const int captured_arg, const MyFnParams ¶ms) { + EXPECT_EQ(captured_arg, 42); + EXPECT_EQ(params.a, 3); + EXPECT_EQ(params.b, 4); + called = true; + }, + 42); + EXPECT_THAT_ERROR(from_remote->Send(Evt{"evt", MyFnParams{3, 4}}), + Succeeded()); + Run(); + EXPECT_TRUE(called); +} + #endif diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index f3ca4cfc01788..0958af87a9402 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -6,9 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "ProtocolMCPTestUtilities.h" +#include "ProtocolMCPTestUtilities.h" // IWYU pragma: keep #include "TestingSupport/Host/JSONTransportTestUtilities.h" -#include "TestingSupport/Host/PipeTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" @@ -28,20 +27,21 @@ #include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include -#include +#include +#include +#include using namespace llvm; using namespace lldb; using namespace lldb_private; using namespace lldb_protocol::mcp; +using testing::_; namespace { -class TestServer : public Server { -public: - using Server::Server; -}; +template Response make_response(T &&result, Id id = 1) { + return Response{id, std::forward(result)}; +} /// Test tool that returns it argument as text. class TestTool : public Tool { @@ -118,195 +118,209 @@ class FailTool : public Tool { } }; -class ProtocolServerMCPTest : public PipePairTest { +class TestServer : public Server { +public: + using Server::Bind; + using Server::Server; +}; + +using Transport = TestTransport; + +class ProtocolServerMCPTest : public testing::Test { public: SubsystemRAII subsystems; MainLoop loop; + lldb_private::MainLoop::ReadHandleUP handles[2]; - std::unique_ptr from_client; - std::unique_ptr to_client; - MainLoopBase::ReadHandleUP handles[2]; - + std::unique_ptr to_server; + Transport::BinderUP binder; std::unique_ptr server_up; - MockMessageHandler message_handler; - llvm::Error Write(llvm::StringRef message) { - llvm::Expected value = json::parse(message); - if (!value) - return value.takeError(); - return from_client->Write(*value); - } + std::unique_ptr to_client; + MockMessageHandler client; - llvm::Error Write(json::Value value) { return from_client->Write(value); } + std::vector logged_messages; - /// Run the transport MainLoop and return any messages received. - llvm::Error Run() { - loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, - std::chrono::milliseconds(10)); - return loop.Run().takeError(); + /// Runs the MainLoop a single time, executing any pending callbacks. + void Run() { + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); } void SetUp() override { - PipePairTest::SetUp(); - - from_client = std::make_unique( - std::make_shared(input.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(output.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned), - [](StringRef message) { - // Uncomment for debugging - // llvm::errs() << "from_client: " << message << '\n'; - }); - to_client = std::make_unique( - std::make_shared(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned), - [](StringRef message) { - // Uncomment for debugging - // llvm::errs() << "to_client: " << message << '\n'; - }); - - server_up = std::make_unique("lldb-mcp", "0.1.0", *to_client, - [](StringRef message) { - // Uncomment for debugging - // llvm::errs() << "server: " << - // message << '\n'; - }); - - auto maybe_from_client_handle = - from_client->RegisterMessageHandler(loop, message_handler); - EXPECT_THAT_EXPECTED(maybe_from_client_handle, Succeeded()); - handles[0] = std::move(*maybe_from_client_handle); - - auto maybe_to_client_handle = - to_client->RegisterMessageHandler(loop, *server_up); - EXPECT_THAT_EXPECTED(maybe_to_client_handle, Succeeded()); - handles[1] = std::move(*maybe_to_client_handle); + std::tie(to_client, to_server) = Transport::createPair(); + + server_up = std::make_unique( + "lldb-mcp", "0.1.0", + [this](StringRef msg) { logged_messages.push_back(msg.str()); }); + binder = server_up->Bind(*to_client); + auto server_handle = to_server->RegisterMessageHandler(loop, *binder); + EXPECT_THAT_EXPECTED(server_handle, Succeeded()); + binder->error([](llvm::Error error) { + llvm::errs() << formatv("Server transport error: {0}", error); + }); + handles[0] = std::move(*server_handle); + + auto client_handle = to_client->RegisterMessageHandler(loop, client); + EXPECT_THAT_EXPECTED(client_handle, Succeeded()); + handles[1] = std::move(*client_handle); + } + + template + Expected Call(StringRef method, const Params ¶ms) { + std::promise promised_result; + Request req = make_request( + /*id=*/1, method, toJSON(params)); + EXPECT_THAT_ERROR(to_server->Send(req), Succeeded()); + EXPECT_CALL(client, Received(testing::An())) + .WillOnce( + [&](const Response &resp) { promised_result.set_value(resp); }); + Run(); + Response resp = promised_result.get_future().get(); + return toJSON(resp); + } + + template + Expected + Capture(llvm::unique_function)> &fn) { + std::promise> promised_result; + fn([&promised_result](llvm::Expected result) { + promised_result.set_value(std::move(result)); + }); + Run(); + llvm::Expected result = promised_result.get_future().get(); + if (!result) + return result.takeError(); + return toJSON(*result); + } + + template + Expected + Capture(llvm::unique_function)> &fn, + const Params ¶ms) { + std::promise> promised_result; + fn(params, [&promised_result](llvm::Expected result) { + promised_result.set_value(std::move(result)); + }); + Run(); + llvm::Expected result = promised_result.get_future().get(); + if (!result) + return result.takeError(); + return toJSON(*result); } }; template -Request make_request(StringLiteral method, T &¶ms, Id id = 1) { - return Request{id, method.str(), toJSON(std::forward(params))}; -} - -template Response make_response(T &&result, Id id = 1) { - return Response{id, std::forward(result)}; +inline testing::internal::EqMatcher HasJSON(T x) { + return testing::internal::EqMatcher(toJSON(x)); } } // namespace TEST_F(ProtocolServerMCPTest, Initialization) { - Request request = make_request( - "initialize", InitializeParams{/*protocolVersion=*/"2024-11-05", - /*capabilities=*/{}, - /*clientInfo=*/{"lldb-unit", "0.1.0"}}); - Response response = make_response( - InitializeResult{/*protocolVersion=*/"2024-11-05", - /*capabilities=*/{/*supportsToolsList=*/true}, - /*serverInfo=*/{"lldb-mcp", "0.1.0"}}); - - ASSERT_THAT_ERROR(Write(request), Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED( + (Call( + "initialize", + InitializeParams{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/{}, + /*clientInfo=*/{"lldb-unit", "0.1.0"}})), + HasValue(make_response( + InitializeResult{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/ + { + /*supportsToolsList=*/true, + /*supportsResourcesList=*/true, + }, + /*serverInfo=*/{"lldb-mcp", "0.1.0"}}))); } TEST_F(ProtocolServerMCPTest, ToolsList) { server_up->AddTool(std::make_unique("test", "test tool")); - Request request = make_request("tools/list", Void{}, /*id=*/"one"); - ToolDefinition test_tool; test_tool.name = "test"; test_tool.description = "test tool"; test_tool.inputSchema = json::Object{{"type", "object"}}; - Response response = make_response(ListToolsResult{{test_tool}}, /*id=*/"one"); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED(Call("tools/list", Void{}), + HasValue(make_response(ListToolsResult{{test_tool}}))); } TEST_F(ProtocolServerMCPTest, ResourcesList) { server_up->AddResourceProvider(std::make_unique()); - Request request = make_request("resources/list", Void{}); - Response response = make_response(ListResourcesResult{ - {{/*uri=*/"lldb://foo/bar", /*name=*/"name", - /*description=*/"description", /*mimeType=*/"application/json"}}}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED(Call("resources/list", Void{}), + HasValue(make_response(ListResourcesResult{{ + { + /*uri=*/"lldb://foo/bar", + /*name=*/"name", + /*description=*/"description", + /*mimeType=*/"application/json", + }, + }}))); } TEST_F(ProtocolServerMCPTest, ToolsCall) { server_up->AddTool(std::make_unique("test", "test tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"test", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = make_response(CallToolResult{{{/*text=*/"foo"}}}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED( + (Call("tools/call", + CallToolParams{ + /*name=*/"test", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(CallToolResult{{{/*text=*/"foo"}}}))); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { server_up->AddTool(std::make_unique("error", "error tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"error", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = - make_response(lldb_protocol::mcp::Error{eErrorCodeInternalError, - /*message=*/"error"}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED((Call( + "tools/call", CallToolParams{ + /*name=*/"error", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(lldb_protocol::mcp::Error{ + eErrorCodeInternalError, "error"}))); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { server_up->AddTool(std::make_unique("fail", "fail tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"fail", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = - make_response(CallToolResult{{{/*text=*/"failed"}}, /*isError=*/true}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED((Call( + "tools/call", CallToolParams{ + /*name=*/"fail", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(CallToolResult{ + {{/*text=*/"failed"}}, + /*isError=*/true, + }))); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { - bool handler_called = false; - std::condition_variable cv; - - server_up->AddNotificationHandler( - "notifications/initialized", - [&](const Notification ¬ification) { handler_called = true; }); - llvm::StringLiteral request = - R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_THAT_ERROR(Run(), Succeeded()); - EXPECT_TRUE(handler_called); + EXPECT_THAT_ERROR(to_server->Send(lldb_protocol::mcp::Notification{ + "notifications/initialized", + std::nullopt, + }), + Succeeded()); + Run(); + EXPECT_THAT(logged_messages, + testing::Contains("MCP initialization complete")); } diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h index 5a9eb8e59f2b6..4dbcd614e400b 100644 --- a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -6,19 +6,105 @@ // //===----------------------------------------------------------------------===// -#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H -#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H +#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H +#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H +#include "lldb/Host/FileSystem.h" #include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Utility/FileSpec.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" +#include "gtest/gtest.h" +#include +#include +#include -template +template +class TestTransport final + : public lldb_private::JSONTransport { +public: + using MessageHandler = + typename lldb_private::JSONTransport::MessageHandler; + + static std::pair>, + std::unique_ptr>> + createPair() { + std::unique_ptr> transports[2] = { + std::make_unique>(), + std::make_unique>()}; + return std::make_pair(std::move(transports[0]), std::move(transports[1])); + } + + explicit TestTransport() { + llvm::Expected dummy_file = + lldb_private::FileSystem::Instance().Open( + lldb_private::FileSpec(lldb_private::FileSystem::DEV_NULL), + lldb_private::File::eOpenOptionReadWrite); + EXPECT_THAT_EXPECTED(dummy_file, llvm::Succeeded()); + m_dummy_file = std::move(*dummy_file); + } + + llvm::Error Send(const Evt &evt) override { + EXPECT_TRUE(m_loop && m_handler) + << "Send called before RegisterMessageHandler"; + m_loop->AddPendingCallback([this, evt](lldb_private::MainLoopBase &) { + m_handler->Received(evt); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const Req &req) override { + EXPECT_TRUE(m_loop && m_handler) + << "Send called before RegisterMessageHandler"; + m_loop->AddPendingCallback([this, req](lldb_private::MainLoopBase &) { + m_handler->Received(req); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const Resp &resp) override { + EXPECT_TRUE(m_loop && m_handler) + << "Send called before RegisterMessageHandler"; + m_loop->AddPendingCallback([this, resp](lldb_private::MainLoopBase &) { + m_handler->Received(resp); + }); + return llvm::Error::success(); + } + + llvm::Expected + RegisterMessageHandler(lldb_private::MainLoop &loop, + MessageHandler &handler) override { + if (!m_loop) + m_loop = &loop; + if (!m_handler) + m_handler = &handler; + lldb_private::Status status; + auto handle = loop.RegisterReadObject( + m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); + if (status.Fail()) + return status.takeError(); + return handle; + } + +protected: + void Log(llvm::StringRef message) override {}; + +private: + lldb_private::MainLoop *m_loop = nullptr; + MessageHandler *m_handler = nullptr; + // Dummy file for registering with the MainLoop. + lldb::FileSP m_dummy_file = nullptr; +}; + +template class MockMessageHandler final - : public lldb_private::Transport::MessageHandler { + : public lldb_private::JSONTransport::MessageHandler { public: - MOCK_METHOD(void, Received, (const Evt &), (override)); MOCK_METHOD(void, Received, (const Req &), (override)); MOCK_METHOD(void, Received, (const Resp &), (override)); + MOCK_METHOD(void, Received, (const Evt &), (override)); MOCK_METHOD(void, OnError, (llvm::Error), (override)); MOCK_METHOD(void, OnClosed, (), (override)); }; From d861156354ac5f4fd2827d3f0b622ab26bb121a1 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Wed, 17 Sep 2025 15:47:14 -0700 Subject: [PATCH 2/8] Adding some comments and adjusting the naming of the helpers. --- lldb/include/lldb/Host/JSONTransport.h | 89 +++++++++++-------- lldb/include/lldb/Protocol/MCP/Transport.h | 23 +++-- lldb/unittests/Host/JSONTransportTest.cpp | 28 +++--- .../Protocol/ProtocolMCPServerTest.cpp | 2 +- 4 files changed, 79 insertions(+), 63 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index ede8892e6fff4..3b6ab1c3050fb 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -76,37 +76,55 @@ using VoidT = std::monostate; template using Callback = llvm::unique_function; +/// A handler for the response to an outgoing request. template using Reply = typename std::conditional< - std::is_same_v == true, llvm::unique_function, + std::is_same_v || std::is_void_v, + llvm::unique_function, llvm::unique_function)>>::type; +/// A function to send an outgoing request and receive a response. template using OutgoingRequest = typename std::conditional< - std::is_same_v == true, + std::is_same_v || std::is_void_v, llvm::unique_function)>, llvm::unique_function)>>::type; +/// A function to send an outgoing event. template using OutgoingEvent = typename std::conditional< - std::is_same_v == true, llvm::unique_function, + std::is_same_v || std::is_void_v, + llvm::unique_function, llvm::unique_function>::type; +/// Creates a request with the given id, method, and optional params. template -Req make_request(Id id, llvm::StringRef method, - std::optional params = std::nullopt); +Req MakeRequest(Id, llvm::StringRef, std::optional); + +/// Creates an error response for a given request. template -Resp make_response(const Req &req, llvm::Error error); +Resp MakeResponse(const Req &, llvm::Error); + +/// Creates a success response for a given request. template -Resp make_response(const Req &req, llvm::json::Value result); +Resp MakeResponse(const Req &, llvm::json::Value); + +/// Creates an event. template -Evt make_event(llvm::StringRef method, - std::optional params = std::nullopt); +Evt MakeEvent(llvm::StringRef, std::optional); + +/// Extracts the result value from a response. template -llvm::Expected get_result(const Resp &resp); -template Id get_id(const T &); -template llvm::StringRef get_method(const T &); -template llvm::json::Value get_params(const T &); +llvm::Expected GetResult(const Resp &); + +/// Extracts the id from a response. +template Id GetId(const Resp &); + +/// Extracts the method from a request or event. +template llvm::StringRef GetMethod(const T &); + +/// Extracts the parameters from a request or event. +template llvm::json::Value GetParams(const T &); /// A transport is responsible for maintaining the connection to a client /// application, and reading/writing structured messages to it. @@ -193,7 +211,7 @@ class JSONTransport { ~ReplyOnce() { if (transport && handler && !replied) { assert(false && "must reply to all calls!"); - (*this)(make_response( + (*this)(MakeResponse( req, llvm::createStringError("failed to reply"))); } } @@ -257,9 +275,9 @@ class JSONTransport { return std::move(result); } - /// Bind a handler for a request. - /// e.g. `bind("peek", &ThisModule::peek, this, std::placeholders::_1);`. - /// Handler should be e.g. `Expected peek(const PeekParams&);` + /// Bind a handler for an incoming request. + /// e.g. `bind("peek", &ThisModule::peek, this);`. + /// Handler should be e.g. `Expected peek(const PeekParams&);` /// PeekParams must be JSON parsable and PeekResult must be serializable. template void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { @@ -273,8 +291,8 @@ class JSONTransport { llvm::Expected result = std::invoke( std::forward(fn), std::forward(args)...); if (!result) - return reply(make_response(req, result.takeError())); - reply(make_response(req, toJSON(*result))); + return reply(MakeResponse(req, result.takeError())); + reply(MakeResponse(req, toJSON(*result))); }; } else { m_request_handlers[method] = @@ -282,21 +300,21 @@ class JSONTransport { args...](const Req &req, llvm::unique_function reply) mutable { llvm::Expected params = - parse(get_params(req), method); + parse(GetParams(req), method); if (!params) - return reply(make_response(req, params.takeError())); + return reply(MakeResponse(req, params.takeError())); llvm::Expected result = std::invoke( std::forward(fn), std::forward(args)..., *params); if (!result) - return reply(make_response(req, result.takeError())); + return reply(MakeResponse(req, result.takeError())); - reply(make_response(req, toJSON(*result))); + reply(MakeResponse(req, toJSON(*result))); }; } } - /// Bind a handler for a event. + /// Bind a handler for an incoming event. /// e.g. `bind("peek", &ThisModule::peek, this);` /// Handler should be e.g. `void peek(const PeekParams&);` /// PeekParams must be JSON parsable. @@ -312,7 +330,7 @@ class JSONTransport { m_event_handlers[method] = [this, method, fn, args...](const Evt &evt) mutable { llvm::Expected params = - parse(get_params(evt), method); + parse(GetParams(evt), method); if (!params) return OnError(params.takeError()); std::invoke(std::forward(fn), std::forward(args)..., @@ -330,10 +348,10 @@ class JSONTransport { return [this, method](Reply fn) { std::scoped_lock guard(m_mutex); Id id = ++m_seq; - Req req = make_request(id, method, std::nullopt); + Req req = MakeRequest(id, method, std::nullopt); m_pending_responses[id] = [fn = std::move(fn), method](const Resp &resp) mutable { - llvm::Expected result = get_result(resp); + llvm::Expected result = GetResult(resp); if (!result) return fn(result.takeError()); fn(parse(*result, method)); @@ -345,11 +363,10 @@ class JSONTransport { return [this, method](const Params ¶ms, Reply fn) { std::scoped_lock guard(m_mutex); Id id = ++m_seq; - Req req = - make_request(id, method, llvm::json::Value(params)); + Req req = MakeRequest(id, method, llvm::json::Value(params)); m_pending_responses[id] = [fn = std::move(fn), method](const Resp &resp) mutable { - llvm::Expected result = get_result(resp); + llvm::Expected result = GetResult(resp); if (llvm::Error err = result.takeError()) return fn(std::move(err)); fn(parse(*result, method)); @@ -368,13 +385,13 @@ class JSONTransport { if constexpr (std::is_void_v || std::is_same_v) { return [this, method]() { if (llvm::Error error = - m_transport.Send(make_event(method, std::nullopt))) + m_transport.Send(MakeEvent(method, std::nullopt))) OnError(std::move(error)); }; } else { return [this, method](const Params ¶ms) { if (llvm::Error error = - m_transport.Send(make_event(method, toJSON(params)))) + m_transport.Send(MakeEvent(method, toJSON(params)))) OnError(std::move(error)); }; } @@ -382,7 +399,7 @@ class JSONTransport { void Received(const Evt &evt) override { std::scoped_lock guard(m_mutex); - auto it = m_event_handlers.find(get_method(evt)); + auto it = m_event_handlers.find(GetMethod(evt)); if (it == m_event_handlers.end()) { OnError(llvm::createStringError( llvm::formatv("no handled for event {0}", toJSON(evt)))); @@ -395,9 +412,9 @@ class JSONTransport { ReplyOnce reply(req, &m_transport, this); std::scoped_lock guard(m_mutex); - auto it = m_request_handlers.find(get_method(req)); + auto it = m_request_handlers.find(GetMethod(req)); if (it == m_request_handlers.end()) { - reply(make_response( + reply(MakeResponse( req, llvm::createStringError("method not found"))); return; } @@ -407,7 +424,7 @@ class JSONTransport { void Received(const Resp &resp) override { std::scoped_lock guard(m_mutex); - auto it = m_pending_responses.find(get_id(resp)); + auto it = m_pending_responses.find(GetId(resp)); if (it == m_pending_responses.end()) { OnError(llvm::createStringError( llvm::formatv("no pending request for {0}", toJSON(resp)))); diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h index 55b2e8fa0a7f2..2b275b0354fbe 100644 --- a/lldb/include/lldb/Protocol/MCP/Transport.h +++ b/lldb/include/lldb/Protocol/MCP/Transport.h @@ -22,13 +22,13 @@ namespace lldb_private { /// @{ template <> inline lldb_protocol::mcp::Request -make_request(int64_t id, llvm::StringRef method, - std::optional params) { +MakeRequest(int64_t id, llvm::StringRef method, + std::optional params) { return lldb_protocol::mcp::Request{id, method.str(), params}; } template <> inline lldb_protocol::mcp::Response -make_response(const lldb_protocol::mcp::Request &req, llvm::Error error) { +MakeResponse(const lldb_protocol::mcp::Request &req, llvm::Error error) { lldb_protocol::mcp::Error protocol_error; llvm::handleAllErrors( std::move(error), @@ -44,42 +44,41 @@ make_response(const lldb_protocol::mcp::Request &req, llvm::Error error) { } template <> inline lldb_protocol::mcp::Response -make_response(const lldb_protocol::mcp::Request &req, - llvm::json::Value result) { +MakeResponse(const lldb_protocol::mcp::Request &req, llvm::json::Value result) { return lldb_protocol::mcp::Response{req.id, std::move(result)}; } template <> inline lldb_protocol::mcp::Notification -make_event(llvm::StringRef method, std::optional params) { +MakeEvent(llvm::StringRef method, std::optional params) { return lldb_protocol::mcp::Notification{method.str(), params}; } template <> inline llvm::Expected -get_result(const lldb_protocol::mcp::Response &resp) { +GetResult(const lldb_protocol::mcp::Response &resp) { if (const lldb_protocol::mcp::Error *error = std::get_if(&resp.result)) return llvm::make_error(error->message, error->code); return std::get(resp.result); } -template <> inline int64_t get_id(const lldb_protocol::mcp::Response &resp) { +template <> inline int64_t GetId(const lldb_protocol::mcp::Response &resp) { return std::get(resp.id); } template <> -inline llvm::StringRef get_method(const lldb_protocol::mcp::Request &req) { +inline llvm::StringRef GetMethod(const lldb_protocol::mcp::Request &req) { return req.method; } template <> -inline llvm::StringRef get_method(const lldb_protocol::mcp::Notification &evt) { +inline llvm::StringRef GetMethod(const lldb_protocol::mcp::Notification &evt) { return evt.method; } template <> -inline llvm::json::Value get_params(const lldb_protocol::mcp::Request &req) { +inline llvm::json::Value GetParams(const lldb_protocol::mcp::Request &req) { return req.params; } template <> inline llvm::json::Value -get_params(const lldb_protocol::mcp::Notification &evt) { +GetParams(const lldb_protocol::mcp::Notification &evt) { return evt.params; } /// @} diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 0228e13b61909..6477edd5eb666 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -322,38 +322,38 @@ class TestTransportBinder : public testing::Test { } // namespace namespace lldb_private { + using namespace test_protocol; + template <> -inline test_protocol::Req make_request(int id, llvm::StringRef method, - std::optional params) { +inline test_protocol::Req MakeRequest(int id, llvm::StringRef method, + std::optional params) { return test_protocol::Req{id, method.str(), params}; } -template <> inline Resp make_response(const Req &req, llvm::Error error) { +template <> inline Resp MakeResponse(const Req &req, llvm::Error error) { llvm::consumeError(std::move(error)); return Resp{req.id, std::nullopt}; } -template <> inline Resp make_response(const Req &req, json::Value result) { +template <> inline Resp MakeResponse(const Req &req, json::Value result) { return Resp{req.id, std::move(result)}; } template <> -inline Evt make_event(llvm::StringRef method, - std::optional params) { +inline Evt MakeEvent(llvm::StringRef method, + std::optional params) { return Evt{method.str(), params}; } - -template <> inline llvm::Expected get_result(const Resp &resp) { +template <> inline llvm::Expected GetResult(const Resp &resp) { return resp.result; } - -template <> inline int get_id(const Resp &resp) { return resp.id; } -template <> inline llvm::StringRef get_method(const Req &req) { +template <> inline int GetId(const Resp &resp) { return resp.id; } +template <> inline llvm::StringRef GetMethod(const Req &req) { return req.name; } -template <> inline llvm::StringRef get_method(const Evt &evt) { +template <> inline llvm::StringRef GetMethod(const Evt &evt) { return evt.name; } -template <> inline json::Value get_params(const Req &req) { return req.params; } -template <> inline json::Value get_params(const Evt &evt) { return evt.params; } +template <> inline json::Value GetParams(const Req &req) { return req.params; } +template <> inline json::Value GetParams(const Evt &evt) { return evt.params; } } // namespace lldb_private diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index 0958af87a9402..5523b21552400 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -173,7 +173,7 @@ class ProtocolServerMCPTest : public testing::Test { template Expected Call(StringRef method, const Params ¶ms) { std::promise promised_result; - Request req = make_request( + Request req = MakeRequest( /*id=*/1, method, toJSON(params)); EXPECT_THAT_ERROR(to_server->Send(req), Succeeded()); EXPECT_CALL(client, Received(testing::An())) From a2a1c034d3d0768a02c01b74889f13132bdc8e54 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Thu, 18 Sep 2025 18:05:34 -0700 Subject: [PATCH 3/8] Trying to improve the organization of the JSONTransport file. Moved the Binder below the rest of the transport classes and moved some of the implementations below the class definition to improve readability of the Binder interface without seeing a lot of template details. --- lldb/include/lldb/Host/JSONTransport.h | 799 ++++++++++-------- lldb/include/lldb/Protocol/MCP/MCPError.h | 5 +- lldb/include/lldb/Protocol/MCP/Protocol.h | 16 +- lldb/include/lldb/Protocol/MCP/Transport.h | 27 +- lldb/source/Host/common/JSONTransport.cpp | 18 +- lldb/source/Protocol/MCP/MCPError.cpp | 9 +- lldb/source/Protocol/MCP/Server.cpp | 17 +- lldb/tools/lldb-dap/DAP.h | 4 +- lldb/tools/lldb-dap/Transport.h | 7 +- lldb/unittests/Host/JSONTransportTest.cpp | 186 +++- .../Protocol/ProtocolMCPServerTest.cpp | 9 +- .../Host/JSONTransportTestUtilities.h | 8 +- 12 files changed, 686 insertions(+), 419 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 3b6ab1c3050fb..ca207ceddf6f0 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -36,8 +36,10 @@ #include #include -namespace lldb_private { +namespace lldb_private::transport { +/// An error to indicate that the transport reached EOF but there were still +/// unhandled contents in the read buffer. class TransportUnhandledContentsError : public llvm::ErrorInfo { public: @@ -56,6 +58,8 @@ class TransportUnhandledContentsError std::string m_unhandled_contents; }; +/// An error to indicate that the parameters of a Req, Resp or Evt could not be +/// deserialized. class InvalidParams : public llvm::ErrorInfo { public: static char ID; @@ -71,65 +75,26 @@ class InvalidParams : public llvm::ErrorInfo { std::string m_context; }; -// Value for tracking functions that have a void param or result. -using VoidT = std::monostate; - -template using Callback = llvm::unique_function; - -/// A handler for the response to an outgoing request. -template -using Reply = typename std::conditional< - std::is_same_v || std::is_void_v, - llvm::unique_function, - llvm::unique_function)>>::type; - -/// A function to send an outgoing request and receive a response. -template -using OutgoingRequest = typename std::conditional< - std::is_same_v || std::is_void_v, - llvm::unique_function)>, - llvm::unique_function)>>::type; - -/// A function to send an outgoing event. -template -using OutgoingEvent = typename std::conditional< - std::is_same_v || std::is_void_v, - llvm::unique_function, - llvm::unique_function>::type; - -/// Creates a request with the given id, method, and optional params. -template -Req MakeRequest(Id, llvm::StringRef, std::optional); - -/// Creates an error response for a given request. -template -Resp MakeResponse(const Req &, llvm::Error); - -/// Creates a success response for a given request. -template -Resp MakeResponse(const Req &, llvm::json::Value); +/// An error to indicate that no handler was registered for a given method. +class MethodNotFound : public llvm::ErrorInfo { +public: + static char ID; -/// Creates an event. -template -Evt MakeEvent(llvm::StringRef, std::optional); + static constexpr int kErrorCode = -32601; -/// Extracts the result value from a response. -template -llvm::Expected GetResult(const Resp &); + explicit MethodNotFound(std::string method) : m_method(std::move(method)) {} -/// Extracts the id from a response. -template Id GetId(const Resp &); - -/// Extracts the method from a request or event. -template llvm::StringRef GetMethod(const T &); + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; -/// Extracts the parameters from a request or event. -template llvm::json::Value GetParams(const T &); +private: + std::string m_method; +}; /// A transport is responsible for maintaining the connection to a client /// application, and reading/writing structured messages to it. /// -/// Transports have limited thread safety requirements: +/// JSONTransport have limited thread safety requirements: /// - Messages will not be sent concurrently. /// - Messages MAY be sent while Run() is reading, or its callback is active. template @@ -137,6 +102,9 @@ class JSONTransport { public: using Message = std::variant; + class Binder; + using BinderUP = std::unique_ptr; + virtual ~JSONTransport() = default; /// Sends an event, a message that does not require a response. @@ -146,7 +114,8 @@ class JSONTransport { /// Sends a response to a specific request. virtual llvm::Error Send(const Resp &) = 0; - /// Implemented to handle incoming messages. (See Run() below). + /// Implemented to handle incoming messages. (See `RegisterMessageHandler()` + /// below). class MessageHandler { public: virtual ~MessageHandler() = default; @@ -180,284 +149,6 @@ class JSONTransport { Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); } virtual void Log(llvm::StringRef message) = 0; - - /// Function object to reply to a call. - /// Each instance must be called exactly once, otherwise: - /// - the bug is logged, and (in debug mode) an assert will fire - /// - if there was no reply, an error reply is sent - /// - if there were multiple replies, only the first is sent - class ReplyOnce { - std::atomic replied = {false}; - const Req req; - JSONTransport *transport; // Null when moved-from. - JSONTransport::MessageHandler *handler; // Null when moved-from. - - public: - ReplyOnce(const Req req, JSONTransport *transport, - JSONTransport::MessageHandler *handler) - : req(req), transport(transport), handler(handler) { - assert(handler); - } - ReplyOnce(ReplyOnce &&other) - : replied(other.replied.load()), req(other.req), - transport(other.transport), handler(other.handler) { - other.transport = nullptr; - other.handler = nullptr; - } - ReplyOnce &operator=(ReplyOnce &&) = delete; - ReplyOnce(const ReplyOnce &) = delete; - ReplyOnce &operator=(const ReplyOnce &) = delete; - - ~ReplyOnce() { - if (transport && handler && !replied) { - assert(false && "must reply to all calls!"); - (*this)(MakeResponse( - req, llvm::createStringError("failed to reply"))); - } - } - - void operator()(const Resp &resp) { - assert(transport && handler && "moved-from!"); - if (replied.exchange(true)) { - assert(false && "must reply to each call only once!"); - return; - } - - if (llvm::Error error = transport->Send(resp)) - handler->OnError(std::move(error)); - } - }; - -public: - class Binder; - using BinderUP = std::unique_ptr; - - /// Binder collects a table of functions that handle calls. - /// - /// The wrapper takes care of parsing/serializing responses. - class Binder : public JSONTransport::MessageHandler { - public: - explicit Binder(JSONTransport &transport) - : m_transport(transport), m_seq(0) {} - - Binder(const Binder &) = delete; - Binder &operator=(const Binder &) = delete; - - /// Bind a handler on transport disconnect. - template - void disconnected(Fn &&fn, Args &&...args) { - m_disconnect_handler = [&, args...]() mutable { - std::invoke(std::forward(fn), std::forward(args)...); - }; - } - - /// Bind a handler on error when communicating with the transport. - template - void error(Fn &&fn, Args &&...args) { - m_error_handler = [&, args...](llvm::Error error) mutable { - std::invoke(std::forward(fn), std::forward(args)..., - std::move(error)); - }; - } - - template - static llvm::Expected parse(const llvm::json::Value &raw, - llvm::StringRef method) { - T result; - llvm::json::Path::Root root; - if (!fromJSON(raw, result, root)) { - // Dump the relevant parts of the broken message. - std::string context; - llvm::raw_string_ostream OS(context); - root.printErrorContext(raw, OS); - return llvm::make_error(method.str(), context); - } - return std::move(result); - } - - /// Bind a handler for an incoming request. - /// e.g. `bind("peek", &ThisModule::peek, this);`. - /// Handler should be e.g. `Expected peek(const PeekParams&);` - /// PeekParams must be JSON parsable and PeekResult must be serializable. - template - void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { - assert(m_request_handlers.find(method) == m_request_handlers.end() && - "request already bound"); - if constexpr (std::is_void_v || std::is_same_v) { - m_request_handlers[method] = - [fn, - args...](const Req &req, - llvm::unique_function reply) mutable { - llvm::Expected result = std::invoke( - std::forward(fn), std::forward(args)...); - if (!result) - return reply(MakeResponse(req, result.takeError())); - reply(MakeResponse(req, toJSON(*result))); - }; - } else { - m_request_handlers[method] = - [method, fn, - args...](const Req &req, - llvm::unique_function reply) mutable { - llvm::Expected params = - parse(GetParams(req), method); - if (!params) - return reply(MakeResponse(req, params.takeError())); - - llvm::Expected result = std::invoke( - std::forward(fn), std::forward(args)..., *params); - if (!result) - return reply(MakeResponse(req, result.takeError())); - - reply(MakeResponse(req, toJSON(*result))); - }; - } - } - - /// Bind a handler for an incoming event. - /// e.g. `bind("peek", &ThisModule::peek, this);` - /// Handler should be e.g. `void peek(const PeekParams&);` - /// PeekParams must be JSON parsable. - template - void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { - assert(m_event_handlers.find(method) == m_event_handlers.end() && - "event already bound"); - if constexpr (std::is_void_v || std::is_same_v) { - m_event_handlers[method] = [fn, args...](const Evt &) mutable { - std::invoke(std::forward(fn), std::forward(args)...); - }; - } else { - m_event_handlers[method] = [this, method, fn, - args...](const Evt &evt) mutable { - llvm::Expected params = - parse(GetParams(evt), method); - if (!params) - return OnError(params.takeError()); - std::invoke(std::forward(fn), std::forward(args)..., - *params); - }; - } - } - - /// Bind a function object to be used for outgoing requests. - /// e.g. `OutgoingRequest Edit = bind("edit");` - /// Params must be JSON-serializable, Result must be parsable. - template - OutgoingRequest bind(llvm::StringLiteral method) { - if constexpr (std::is_void_v || std::is_same_v) { - return [this, method](Reply fn) { - std::scoped_lock guard(m_mutex); - Id id = ++m_seq; - Req req = MakeRequest(id, method, std::nullopt); - m_pending_responses[id] = [fn = std::move(fn), - method](const Resp &resp) mutable { - llvm::Expected result = GetResult(resp); - if (!result) - return fn(result.takeError()); - fn(parse(*result, method)); - }; - if (llvm::Error error = m_transport.Send(req)) - OnError(std::move(error)); - }; - } else { - return [this, method](const Params ¶ms, Reply fn) { - std::scoped_lock guard(m_mutex); - Id id = ++m_seq; - Req req = MakeRequest(id, method, llvm::json::Value(params)); - m_pending_responses[id] = [fn = std::move(fn), - method](const Resp &resp) mutable { - llvm::Expected result = GetResult(resp); - if (llvm::Error err = result.takeError()) - return fn(std::move(err)); - fn(parse(*result, method)); - }; - if (llvm::Error error = m_transport.Send(req)) - OnError(std::move(error)); - }; - } - } - - /// Bind a function object to be used for outgoing events. - /// e.g. `OutgoingEvent Log = bind("log");` - /// LogParams must be JSON-serializable. - template - OutgoingEvent bind(llvm::StringLiteral method) { - if constexpr (std::is_void_v || std::is_same_v) { - return [this, method]() { - if (llvm::Error error = - m_transport.Send(MakeEvent(method, std::nullopt))) - OnError(std::move(error)); - }; - } else { - return [this, method](const Params ¶ms) { - if (llvm::Error error = - m_transport.Send(MakeEvent(method, toJSON(params)))) - OnError(std::move(error)); - }; - } - } - - void Received(const Evt &evt) override { - std::scoped_lock guard(m_mutex); - auto it = m_event_handlers.find(GetMethod(evt)); - if (it == m_event_handlers.end()) { - OnError(llvm::createStringError( - llvm::formatv("no handled for event {0}", toJSON(evt)))); - return; - } - it->second(evt); - } - - void Received(const Req &req) override { - ReplyOnce reply(req, &m_transport, this); - - std::scoped_lock guard(m_mutex); - auto it = m_request_handlers.find(GetMethod(req)); - if (it == m_request_handlers.end()) { - reply(MakeResponse( - req, llvm::createStringError("method not found"))); - return; - } - - it->second(req, std::move(reply)); - } - - void Received(const Resp &resp) override { - std::scoped_lock guard(m_mutex); - auto it = m_pending_responses.find(GetId(resp)); - if (it == m_pending_responses.end()) { - OnError(llvm::createStringError( - llvm::formatv("no pending request for {0}", toJSON(resp)))); - return; - } - - it->second(resp); - m_pending_responses.erase(it); - } - - void OnError(llvm::Error err) override { - std::scoped_lock guard(m_mutex); - if (m_error_handler) - m_error_handler(std::move(err)); - } - - void OnClosed() override { - std::scoped_lock guard(m_mutex); - if (m_disconnect_handler) - m_disconnect_handler(); - } - - private: - std::recursive_mutex m_mutex; - JSONTransport &m_transport; - Id m_seq; - std::map> m_pending_responses; - llvm::StringMap)>> - m_request_handlers; - llvm::StringMap> m_event_handlers; - Callback m_disconnect_handler; - Callback m_error_handler; - }; }; /// An IOTransport sends and receives messages using an IOObject. @@ -649,6 +340,450 @@ class JSONRPCTransport : public IOTransport { static constexpr llvm::StringLiteral kMessageSeparator = "\n"; }; -} // namespace lldb_private +/// A handler for the response to an outgoing request. +template +using Reply = + std::conditional_t, + llvm::unique_function, + llvm::unique_function)>>; + +namespace detail { +template struct request_t final { + using type = llvm::unique_function)>; +}; +template struct request_t final { + using type = llvm::unique_function)>; +}; +template struct event_t final { + using type = llvm::unique_function; +}; +template <> struct event_t final { + using type = llvm::unique_function; +}; +} // namespace detail + +template +using OutgoingRequest = typename detail::request_t::type; + +/// A function to send an outgoing event. +template using OutgoingEvent = typename detail::event_t

::type; + +/// Creates a request with the given id, method, and optional params. +template +Req MakeRequest(Id, llvm::StringRef, std::optional); + +/// Creates an error response for a given request. +template +Resp MakeResponse(const Req &, llvm::Error); + +/// Creates a success response for a given request. +template +Resp MakeResponse(const Req &, llvm::json::Value); + +/// Creates an event. +template +Evt MakeEvent(llvm::StringRef, std::optional); + +/// Extracts the result value from a response. +template +llvm::Expected GetResult(const Resp &); + +/// Extracts the id from a response. +template Id GetId(const Resp &); + +/// Extracts the method from a request or event. +template llvm::StringRef GetMethod(const T &); + +/// Extracts the parameters from a request or event. +template llvm::json::Value GetParams(const T &); + +/// Binder collects a table of functions that handle calls. +/// +/// The wrapper takes care of parsing/serializing responses. +/// +/// This allows a JSONTransport to handle incoming and outgoing requests and +/// events. +/// +/// A simple example could be to a method to a lambda like: +/// +/// \code{cpp} +/// Binder binder{transport}; +/// // Binds an incoming request handler. +/// binder.bind>("adder", [](const vector ¶ms) { +/// int sum = 0; +/// for (int v : params) +/// sum += v; +/// return sum; +/// }); +/// // Binds an outgoing request handler. +/// OutgoingRequest> call_add = +/// binder.bind>("add"); +/// call_add({1,2,3}, [](Expected result) { +/// cout << *result << "\n"; +/// }); +/// \endcode +template +class JSONTransport::Binder + : public JSONTransport::MessageHandler { + using Transport = JSONTransport; + using MessageHandler = typename Transport::MessageHandler; + +public: + explicit Binder(Transport &transport) : m_transport(transport), m_seq(0) {} + + Binder(const Binder &) = delete; + Binder &operator=(const Binder &) = delete; + + /// Bind a handler on transport disconnect. + template + void OnDisconnect(Fn &&fn, Args &&...args); + + /// Bind a handler on error when communicating with the transport. + template + void OnError(Fn &&fn, Args &&...args); + + /// Bind a handler for an incoming request. + /// e.g. `bind("peek", &ThisModule::peek, this);`. + /// Handler should be e.g. `Expected peek(const PeekParams&);` + /// PeekParams must be JSON parsable and PeekResult must be serializable. + template + void Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args); + + /// Bind a handler for an incoming event. + /// e.g. `bind("peek", &ThisModule::peek, this);` + /// Handler should be e.g. `void peek(const PeekParams&);` + /// PeekParams must be JSON parsable. + template + void Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args); + + /// Bind a function object to be used for outgoing requests. + /// e.g. `OutgoingRequest Edit = bind("edit");` + /// Params must be JSON-serializable, Result must be parsable. + template + OutgoingRequest Bind(llvm::StringLiteral method); + + /// Bind a function object to be used for outgoing events. + /// e.g. `OutgoingEvent Log = bind("log");` + /// LogParams must be JSON-serializable. + template + OutgoingEvent Bind(llvm::StringLiteral method); + + void Received(const Evt &evt) override { + std::scoped_lock guard(m_mutex); + auto it = m_event_handlers.find(GetMethod(evt)); + if (it == m_event_handlers.end()) { + OnError(llvm::createStringError( + llvm::formatv("no handled for event {0}", toJSON(evt)))); + return; + } + it->second(evt); + } + + void Received(const Req &req) override { + ReplyOnce reply(req, &m_transport, this); + + std::scoped_lock guard(m_mutex); + auto it = m_request_handlers.find(GetMethod(req)); + if (it == m_request_handlers.end()) { + reply(MakeResponse( + req, llvm::createStringError("method not found"))); + return; + } + + it->second(req, std::move(reply)); + } + + void Received(const Resp &resp) override { + std::scoped_lock guard(m_mutex); + + Id id = GetId(resp); + auto it = m_pending_responses.find(id); + if (it == m_pending_responses.end()) { + OnError(llvm::createStringError( + llvm::formatv("no pending request for {0}", toJSON(resp)))); + return; + } + + it->second(resp); + m_pending_responses.erase(it); + } + + void OnError(llvm::Error err) override { + std::scoped_lock guard(m_mutex); + if (m_error_handler) + m_error_handler(std::move(err)); + } + + void OnClosed() override { + std::scoped_lock guard(m_mutex); + if (m_disconnect_handler) + m_disconnect_handler(); + } + +private: + template + llvm::Expected static Parse(const llvm::json::Value &raw, + llvm::StringRef method); + + template using Callback = llvm::unique_function; + + std::recursive_mutex m_mutex; + Transport &m_transport; + Id m_seq; + std::map> m_pending_responses; + llvm::StringMap)>> + m_request_handlers; + llvm::StringMap> m_event_handlers; + Callback m_disconnect_handler; + Callback m_error_handler; + + /// Function object to reply to a call. + /// Each instance must be called exactly once, otherwise: + /// - the bug is logged, and (in debug mode) an assert will fire + /// - if there was no reply, an error reply is sent + /// - if there were multiple replies, only the first is sent + class ReplyOnce { + std::atomic replied = {false}; + const Req req; + Transport *transport; // Null when moved-from. + MessageHandler *handler; // Null when moved-from. + + public: + ReplyOnce(const Req req, Transport *transport, MessageHandler *handler) + : req(req), transport(transport), handler(handler) { + assert(handler); + } + ReplyOnce(ReplyOnce &&other) + : replied(other.replied.load()), req(other.req), + transport(other.transport), handler(other.handler) { + other.transport = nullptr; + other.handler = nullptr; + } + ReplyOnce &operator=(ReplyOnce &&) = delete; + ReplyOnce(const ReplyOnce &) = delete; + ReplyOnce &operator=(const ReplyOnce &) = delete; + + ~ReplyOnce() { + if (transport && handler && !replied) { + assert(false && "must reply to all calls!"); + (*this)(MakeResponse( + req, llvm::createStringError("failed to reply"))); + } + } + + void operator()(const Resp &resp) { + assert(transport && handler && "moved-from!"); + if (replied.exchange(true)) { + assert(false && "must reply to each call only once!"); + return; + } + + if (llvm::Error error = transport->Send(resp)) + handler->OnError(std::move(error)); + } + }; +}; + +template +template +void JSONTransport::Binder::OnDisconnect(Fn &&fn, + Args &&...args) { + m_disconnect_handler = [fn, args...]() mutable { + std::invoke(std::forward(fn), std::forward(args)...); + }; +} + +template +template +void JSONTransport::Binder::OnError(Fn &&fn, + Args &&...args) { + m_error_handler = [fn, args...](llvm::Error error) mutable { + std::invoke(std::forward(fn), std::forward(args)..., + std::move(error)); + }; +} + +template +template +void JSONTransport::Binder::Bind(llvm::StringLiteral method, + Fn &&fn, Args &&...args) { + assert(m_request_handlers.find(method) == m_request_handlers.end() && + "request already bound"); + if constexpr (std::is_void_v && std::is_void_v) { + m_request_handlers[method] = + [fn, args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Error result = + std::invoke(std::forward(fn), std::forward(args)...); + reply(MakeResponse(req, std::move(result))); + }; + } else if constexpr (std::is_void_v) { + m_request_handlers[method] = + [fn, args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Expected result = + std::invoke(std::forward(fn), std::forward(args)...); + if (!result) + return reply(MakeResponse(req, result.takeError())); + reply(MakeResponse(req, toJSON(*result))); + }; + } else if constexpr (std::is_void_v) { + m_request_handlers[method] = + [method, fn, + args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Expected params = + Parse(GetParams(req), method); + if (!params) + return reply(MakeResponse(req, params.takeError())); + + llvm::Error result = std::invoke( + std::forward(fn), std::forward(args)..., *params); + reply(MakeResponse(req, std::move(result))); + }; + } else { + m_request_handlers[method] = + [method, fn, + args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Expected params = + Parse(GetParams(req), method); + if (!params) + return reply(MakeResponse(req, params.takeError())); + + llvm::Expected result = std::invoke( + std::forward(fn), std::forward(args)..., *params); + if (!result) + return reply(MakeResponse(req, result.takeError())); + + reply(MakeResponse(req, toJSON(*result))); + }; + } +} + +template +template +void JSONTransport::Binder::Bind(llvm::StringLiteral method, + Fn &&fn, Args &&...args) { + assert(m_event_handlers.find(method) == m_event_handlers.end() && + "event already bound"); + if constexpr (std::is_void_v) { + m_event_handlers[method] = [fn, args...](const Evt &) mutable { + std::invoke(std::forward(fn), std::forward(args)...); + }; + } else { + m_event_handlers[method] = [this, method, fn, + args...](const Evt &evt) mutable { + llvm::Expected params = + Parse(GetParams(evt), method); + if (!params) + return OnError(params.takeError()); + std::invoke(std::forward(fn), std::forward(args)..., *params); + }; + } +} + +template +template +OutgoingRequest +JSONTransport::Binder::Bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v && std::is_void_v) { + return [this, method](Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = MakeRequest(id, method, std::nullopt); + m_pending_responses[id] = [fn = std::move(fn)](const Resp &resp) mutable { + llvm::Expected result = GetResult(resp); + if (!result) + return fn(result.takeError()); + fn(llvm::Error::success()); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } else if constexpr (std::is_void_v) { + return [this, method](Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = MakeRequest(id, method, std::nullopt); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected result = GetResult(resp); + if (!result) + return fn(result.takeError()); + fn(Parse(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } else if constexpr (std::is_void_v) { + return [this, method](const Params ¶ms, Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = MakeRequest(id, method, llvm::json::Value(params)); + m_pending_responses[id] = [fn = std::move(fn)](const Resp &resp) mutable { + llvm::Expected result = GetResult(resp); + if (!result) + return fn(result.takeError()); + fn(llvm::Error::success()); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params ¶ms, Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = MakeRequest(id, method, llvm::json::Value(params)); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected result = GetResult(resp); + if (llvm::Error err = result.takeError()) + return fn(std::move(err)); + fn(Parse(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } +} + +template +template +OutgoingEvent +JSONTransport::Binder::Bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v) { + return [this, method]() { + if (llvm::Error error = + m_transport.Send(MakeEvent(method, std::nullopt))) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params ¶ms) { + if (llvm::Error error = + m_transport.Send(MakeEvent(method, toJSON(params)))) + OnError(std::move(error)); + }; + } +} + +template +template +llvm::Expected +JSONTransport::Binder::Parse(const llvm::json::Value &raw, + llvm::StringRef method) { + T result; + llvm::json::Path::Root root; + if (!fromJSON(raw, result, root)) { + // Dump the relevant parts of the broken message. + std::string context; + llvm::raw_string_ostream OS(context); + root.printErrorContext(raw, OS); + return llvm::make_error(method.str(), context); + } + return std::move(result); +} + +} // namespace lldb_private::transport #endif diff --git a/lldb/include/lldb/Protocol/MCP/MCPError.h b/lldb/include/lldb/Protocol/MCP/MCPError.h index 55dd40f124a15..609a1733197d4 100644 --- a/lldb/include/lldb/Protocol/MCP/MCPError.h +++ b/lldb/include/lldb/Protocol/MCP/MCPError.h @@ -9,7 +9,6 @@ #ifndef LLDB_PROTOCOL_MCP_MCPERROR_H #define LLDB_PROTOCOL_MCP_MCPERROR_H -#include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Support/Error.h" #include @@ -26,14 +25,12 @@ class MCPError : public llvm::ErrorInfo { const std::string &getMessage() const { return m_message; } - lldb_protocol::mcp::Error toProtocolError() const; - static constexpr int64_t kResourceNotFound = -32002; static constexpr int64_t kInternalError = -32603; private: std::string m_message; - int64_t m_error_code; + int m_error_code; }; class UnsupportedURI : public llvm::ErrorInfo { diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 1e0816110b80a..9c5df65a6bfcb 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -323,13 +323,17 @@ struct CallToolResult { llvm::json::Value toJSON(const CallToolResult &); bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path); +lldb_protocol::mcp::Request +MakeRequest(int64_t id, llvm::StringRef method, + std::optional params); + } // namespace lldb_protocol::mcp -namespace llvm::json { -inline Value toJSON(const lldb_protocol::mcp::Void &) { return Object(); } -inline bool fromJSON(const Value &, lldb_protocol::mcp::Void &, Path) { - return true; -} -} // namespace llvm::json +// namespace llvm::json { +// inline Value toJSON(const lldb_protocol::mcp::Void &) { return Object(); } +// inline bool fromJSON(const Value &, lldb_protocol::mcp::Void &, Path) { +// return true; +// } +// } // namespace llvm::json #endif diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h index 2b275b0354fbe..28ddeec5275e2 100644 --- a/lldb/include/lldb/Protocol/MCP/Transport.h +++ b/lldb/include/lldb/Protocol/MCP/Transport.h @@ -17,7 +17,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" -namespace lldb_private { +namespace lldb_private::transport { /// Specializations of the JSONTransport protocol functions for MCP. /// @{ template <> @@ -30,15 +30,13 @@ template <> inline lldb_protocol::mcp::Response MakeResponse(const lldb_protocol::mcp::Request &req, llvm::Error error) { lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors( - std::move(error), - [&](const lldb_protocol::mcp::MCPError &err) { - protocol_error = err.toProtocolError(); - }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = lldb_protocol::mcp::MCPError::kInternalError; - protocol_error.message = err.message(); - }); + llvm::handleAllErrors(std::move(error), [&](const llvm::ErrorInfoBase &err) { + std::error_code cerr = err.convertToErrorCode(); + protocol_error.code = cerr == llvm::inconvertibleErrorCode() + ? lldb_protocol::mcp::eErrorCodeInternalError + : cerr.value(); + protocol_error.message = err.message(); + }); return lldb_protocol::mcp::Response{req.id, std::move(protocol_error)}; } @@ -83,21 +81,22 @@ GetParams(const lldb_protocol::mcp::Notification &evt) { } /// @} -} // end namespace lldb_private +} // namespace lldb_private::transport namespace lldb_protocol::mcp { /// Generic transport that uses the MCP protocol. using MCPTransport = - lldb_private::JSONTransport; + lldb_private::transport::JSONTransport; /// Generic logging callback, to allow the MCP server / client / transport layer /// to be independent of the lldb log implementation. using LogCallback = llvm::unique_function; class Transport final - : public lldb_private::JSONRPCTransport { + : public lldb_private::transport::JSONRPCTransport { public: Transport(lldb::IOObjectSP in, lldb::IOObjectSP out, LogCallback log_callback = {}); diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index f809ef478c8f7..22de7fa8cbead 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -14,8 +14,7 @@ #include using namespace llvm; -using namespace lldb; -using namespace lldb_private; +using namespace lldb_private::transport; char TransportUnhandledContentsError::ID; @@ -23,7 +22,7 @@ TransportUnhandledContentsError::TransportUnhandledContentsError( std::string unhandled_contents) : m_unhandled_contents(unhandled_contents) {} -void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const { +void TransportUnhandledContentsError::log(raw_ostream &OS) const { OS << "transport EOF with unhandled contents: '" << m_unhandled_contents << "'"; } @@ -33,10 +32,21 @@ std::error_code TransportUnhandledContentsError::convertToErrorCode() const { char InvalidParams::ID; -void InvalidParams::log(llvm::raw_ostream &OS) const { +void InvalidParams::log(raw_ostream &OS) const { OS << "invalid parameters for method '" << m_method << "': '" << m_context << "'"; } std::error_code InvalidParams::convertToErrorCode() const { return std::make_error_code(std::errc::invalid_argument); } + +char MethodNotFound::ID; + +void MethodNotFound::log(raw_ostream &OS) const { + OS << "method not found: '" << m_method << "'"; +} + +std::error_code MethodNotFound::convertToErrorCode() const { + // JSON-RPC Method not found + return std::error_code(MethodNotFound::kErrorCode, std::generic_category()); +} diff --git a/lldb/source/Protocol/MCP/MCPError.cpp b/lldb/source/Protocol/MCP/MCPError.cpp index e140d11e12cfe..cfac055ba5f11 100644 --- a/lldb/source/Protocol/MCP/MCPError.cpp +++ b/lldb/source/Protocol/MCP/MCPError.cpp @@ -22,14 +22,7 @@ MCPError::MCPError(std::string message, int64_t error_code) void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; } std::error_code MCPError::convertToErrorCode() const { - return llvm::inconvertibleErrorCode(); -} - -lldb_protocol::mcp::Error MCPError::toProtocolError() const { - lldb_protocol::mcp::Error error; - error.code = m_error_code; - error.message = m_message; - return error; + return std::error_code(m_error_code, std::generic_category()); } UnsupportedURI::UnsupportedURI(std::string uri) : m_uri(uri) {} diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 79a2925c6c320..70a9f0cad427b 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -129,17 +129,17 @@ void Server::AddResourceProvider( MCPTransport::BinderUP Server::Bind(MCPTransport &transport) { MCPTransport::BinderUP binder = std::make_unique(transport); - binder->bind( + binder->Bind( "initialize", &Server::InitializeHandler, this); - binder->bind("tools/list", &Server::ToolsListHandler, + binder->Bind("tools/list", &Server::ToolsListHandler, this); - binder->bind("tools/call", + binder->Bind("tools/call", &Server::ToolsCallHandler, this); - binder->bind("resources/list", + binder->Bind("resources/list", &Server::ResourcesListHandler, this); - binder->bind( + binder->Bind( "resources/read", &Server::ResourcesReadHandler, this); - binder->bind("notifications/initialized", + binder->Bind("notifications/initialized", [this]() { Log("MCP initialization complete"); }); return binder; } @@ -147,11 +147,14 @@ MCPTransport::BinderUP Server::Bind(MCPTransport &transport) { llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { MCPTransport::BinderUP binder = Bind(*transport); MCPTransport *transport_ptr = transport.get(); - binder->disconnected([this, transport_ptr]() { + binder->OnDisconnect([this, transport_ptr]() { assert(m_instances.find(transport_ptr) != m_instances.end() && "Client not found in m_instances"); m_instances.erase(transport_ptr); }); + binder->OnError([this](llvm::Error err) { + Logv("Transport error: {0}", llvm::toString(std::move(err))); + }); auto handle = transport->RegisterMessageHandler(loop, *binder); if (!handle) diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index 0c921e5b72d74..fee638258310a 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -79,8 +79,8 @@ enum DAPBroadcasterBits { enum class ReplMode { Variable = 0, Command, Auto }; using DAPTransport = - lldb_private::JSONTransport; + lldb_private::transport::JSONTransport; struct DAP final : public DAPTransport::MessageHandler { /// Path to the lldb-dap binary itself. diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index 6462c155eb9af..dbcd65c6add64 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -24,9 +24,10 @@ namespace lldb_dap { /// A transport class that performs the Debug Adapter Protocol communication /// with the client. -class Transport final : public lldb_private::HTTPDelimitedJSONTransport< - protocol::Id, protocol::Request, protocol::Response, - protocol::Event> { +class Transport final + : public lldb_private::transport::HTTPDelimitedJSONTransport< + protocol::Id, protocol::Request, protocol::Response, + protocol::Event> { public: Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output); diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 6477edd5eb666..c7654006a1397 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -31,6 +31,7 @@ using namespace llvm; using namespace lldb_private; +using namespace lldb_private::transport; using testing::_; using testing::HasSubstr; using testing::InSequence; @@ -292,7 +293,7 @@ class JSONRPCTransportTest : public JSONTransportTest { } }; -class TestTransportBinder : public testing::Test { +class TransportBinderTest : public testing::Test { protected: SubsystemRAII subsystems; @@ -321,41 +322,36 @@ class TestTransportBinder : public testing::Test { } // namespace -namespace lldb_private { +namespace lldb_private::transport { using namespace test_protocol; template <> -inline test_protocol::Req MakeRequest(int id, llvm::StringRef method, - std::optional params) { +test_protocol::Req MakeRequest(int id, llvm::StringRef method, + std::optional params) { return test_protocol::Req{id, method.str(), params}; } -template <> inline Resp MakeResponse(const Req &req, llvm::Error error) { +template <> Resp MakeResponse(const Req &req, llvm::Error error) { llvm::consumeError(std::move(error)); return Resp{req.id, std::nullopt}; } -template <> inline Resp MakeResponse(const Req &req, json::Value result) { +template <> Resp MakeResponse(const Req &req, json::Value result) { return Resp{req.id, std::move(result)}; } template <> -inline Evt MakeEvent(llvm::StringRef method, - std::optional params) { +Evt MakeEvent(llvm::StringRef method, std::optional params) { return Evt{method.str(), params}; } -template <> inline llvm::Expected GetResult(const Resp &resp) { +template <> llvm::Expected GetResult(const Resp &resp) { return resp.result; } -template <> inline int GetId(const Resp &resp) { return resp.id; } -template <> inline llvm::StringRef GetMethod(const Req &req) { - return req.name; -} -template <> inline llvm::StringRef GetMethod(const Evt &evt) { - return evt.name; -} -template <> inline json::Value GetParams(const Req &req) { return req.params; } -template <> inline json::Value GetParams(const Evt &evt) { return evt.params; } +template <> int GetId(const Resp &resp) { return resp.id; } +template <> llvm::StringRef GetMethod(const Req &req) { return req.name; } +template <> llvm::StringRef GetMethod(const Evt &evt) { return evt.name; } +template <> json::Value GetParams(const Req &req) { return req.params; } +template <> json::Value GetParams(const Evt &evt) { return evt.params; } -} // namespace lldb_private +} // namespace lldb_private::transport // Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446. #ifndef _WIN32 @@ -603,46 +599,156 @@ TEST_F(JSONRPCTransportTest, InvalidTransport) { } // Out-bound binding request handler. -TEST_F(TestTransportBinder, OutBoundRequests) { - auto addFn = binder->bind("add"); - addFn(MyFnParams{1, 2}, [](Expected result) { +TEST_F(TransportBinderTest, OutBoundRequests) { + OutgoingRequest addFn = + binder->Bind("add"); + bool replied = false; + addFn(MyFnParams{1, 2}, [&](Expected result) { EXPECT_THAT_EXPECTED(result, Succeeded()); EXPECT_EQ(result->c, 3); + replied = true; }); EXPECT_CALL(remote, Received(Req{1, "add", MyFnParams{1, 2}})); - // Queue a reply that will be sent during 'Run'. EXPECT_THAT_ERROR(from_remote->Send(Resp{1, toJSON(MyFnResult{3})}), Succeeded()); Run(); + EXPECT_TRUE(replied); +} + +TEST_F(TransportBinderTest, OutBoundRequestsVoidParams) { + OutgoingRequest voidParamFn = + binder->Bind("voidParam"); + bool replied = false; + voidParamFn([&](Expected result) { + EXPECT_THAT_EXPECTED(result, Succeeded()); + EXPECT_EQ(result->c, 3); + replied = true; + }); + EXPECT_CALL(remote, Received(Req{1, "voidParam", std::nullopt})); + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, toJSON(MyFnResult{3})}), + Succeeded()); + Run(); + EXPECT_TRUE(replied); +} + +TEST_F(TransportBinderTest, OutBoundRequestsVoidResult) { + OutgoingRequest voidResultFn = + binder->Bind("voidResult"); + bool replied = false; + voidResultFn(MyFnParams{4, 5}, [&](llvm::Error error) { + EXPECT_THAT_ERROR(std::move(error), Succeeded()); + replied = true; + }); + EXPECT_CALL(remote, Received(Req{1, "voidResult", MyFnParams{4, 5}})); + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, std::nullopt}), Succeeded()); + Run(); + EXPECT_TRUE(replied); +} + +TEST_F(TransportBinderTest, OutBoundRequestsVoidParamsAndVoidResult) { + OutgoingRequest voidParamAndResultFn = + binder->Bind("voidParamAndResult"); + bool replied = false; + voidParamAndResultFn([&](llvm::Error error) { + EXPECT_THAT_ERROR(std::move(error), Succeeded()); + replied = true; + }); + EXPECT_CALL(remote, Received(Req{1, "voidParamAndResult", std::nullopt})); + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, std::nullopt}), Succeeded()); + Run(); + EXPECT_TRUE(replied); } // In-bound binding request handler. -TEST_F(TestTransportBinder, InBoundRequests) { - binder->bind( +TEST_F(TransportBinderTest, InBoundRequests) { + bool called = false; + binder->Bind( "add", - [](const int captured_param, - const MyFnParams ¶ms) -> Expected { + [&](const int captured_param, + const MyFnParams ¶ms) -> Expected { + called = true; return MyFnResult{params.a + params.b + captured_param}; }, 2); - EXPECT_THAT_ERROR(from_remote->Send(Req{2, "add", MyFnParams{3, 4}}), + EXPECT_THAT_ERROR(from_remote->Send(Req{1, "add", MyFnParams{3, 4}}), + Succeeded()); + + EXPECT_CALL(remote, Received(Resp{1, MyFnResult{9}})); + Run(); + EXPECT_TRUE(called); +} + +TEST_F(TransportBinderTest, InBoundRequestsVoidParams) { + bool called = false; + binder->Bind( + "voidParam", + [&](const int captured_param) -> Expected { + called = true; + return MyFnResult{captured_param}; + }, + 2); + EXPECT_THAT_ERROR(from_remote->Send(Req{2, "voidParam", std::nullopt}), Succeeded()); - EXPECT_CALL(remote, Received(Resp{2, MyFnResult{9}})); + EXPECT_CALL(remote, Received(Resp{2, MyFnResult{2}})); Run(); + EXPECT_TRUE(called); +} + +TEST_F(TransportBinderTest, InBoundRequestsVoidResult) { + bool called = false; + binder->Bind( + "voidResult", + [&](const int captured_param, const MyFnParams ¶ms) -> llvm::Error { + called = true; + EXPECT_EQ(captured_param, 2); + EXPECT_EQ(params.a, 3); + EXPECT_EQ(params.b, 4); + return llvm::Error::success(); + }, + 2); + EXPECT_THAT_ERROR(from_remote->Send(Req{3, "voidResult", MyFnParams{3, 4}}), + Succeeded()); + EXPECT_CALL(remote, Received(Resp{3, std::nullopt})); + Run(); + EXPECT_TRUE(called); +} +TEST_F(TransportBinderTest, InBoundRequestsVoidParamsAndResult) { + bool called = false; + binder->Bind( + "voidParamAndResult", + [&](const int captured_param) -> llvm::Error { + called = true; + EXPECT_EQ(captured_param, 2); + return llvm::Error::success(); + }, + 2); + EXPECT_THAT_ERROR( + from_remote->Send(Req{4, "voidParamAndResult", std::nullopt}), + Succeeded()); + EXPECT_CALL(remote, Received(Resp{4, std::nullopt})); + Run(); + EXPECT_TRUE(called); } // Out-bound binding event handler. -TEST_F(TestTransportBinder, OutBoundEvents) { - auto emitEvent = binder->bind("evt"); +TEST_F(TransportBinderTest, OutBoundEvents) { + OutgoingEvent emitEvent = binder->Bind("evt"); emitEvent(MyFnParams{1, 2}); EXPECT_CALL(remote, Received(Evt{"evt", MyFnParams{1, 2}})); Run(); } +TEST_F(TransportBinderTest, OutBoundEventsVoidParams) { + OutgoingEvent emitEvent = binder->Bind("evt"); + emitEvent(); + EXPECT_CALL(remote, Received(Evt{"evt", std::nullopt})); + Run(); +} + // In-bound binding event handler. -TEST_F(TestTransportBinder, InBoundEvents) { +TEST_F(TransportBinderTest, InBoundEvents) { bool called = false; - binder->bind( + binder->Bind( "evt", [&](const int captured_arg, const MyFnParams ¶ms) { EXPECT_EQ(captured_arg, 42); @@ -657,4 +763,18 @@ TEST_F(TestTransportBinder, InBoundEvents) { EXPECT_TRUE(called); } +TEST_F(TransportBinderTest, InBoundEventsVoidParams) { + bool called = false; + binder->Bind( + "evt", + [&](const int captured_arg) { + EXPECT_EQ(captured_arg, 42); + called = true; + }, + 42); + EXPECT_THAT_ERROR(from_remote->Send(Evt{"evt", std::nullopt}), Succeeded()); + Run(); + EXPECT_TRUE(called); +} + #endif diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index 5523b21552400..bc3f62a2a2db4 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -30,12 +30,13 @@ #include #include #include +#include using namespace llvm; using namespace lldb; using namespace lldb_private; +using namespace lldb_private::transport; using namespace lldb_protocol::mcp; -using testing::_; namespace { @@ -101,7 +102,9 @@ class ErrorTool : public Tool { using Tool::Tool; llvm::Expected Call(const ToolArguments &args) override { - return llvm::createStringError("error"); + return llvm::createStringError( + std::error_code(eErrorCodeInternalError, std::generic_category()), + "error"); } }; @@ -160,7 +163,7 @@ class ProtocolServerMCPTest : public testing::Test { binder = server_up->Bind(*to_client); auto server_handle = to_server->RegisterMessageHandler(loop, *binder); EXPECT_THAT_EXPECTED(server_handle, Succeeded()); - binder->error([](llvm::Error error) { + binder->OnError([](llvm::Error error) { llvm::errs() << formatv("Server transport error: {0}", error); }); handles[0] = std::move(*server_handle); diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h index 4dbcd614e400b..73918ace2c521 100644 --- a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -23,10 +23,11 @@ template class TestTransport final - : public lldb_private::JSONTransport { + : public lldb_private::transport::JSONTransport { public: using MessageHandler = - typename lldb_private::JSONTransport::MessageHandler; + typename lldb_private::transport::JSONTransport::MessageHandler; static std::pair>, std::unique_ptr>> @@ -100,7 +101,8 @@ class TestTransport final template class MockMessageHandler final - : public lldb_private::JSONTransport::MessageHandler { + : public lldb_private::transport::JSONTransport::MessageHandler { public: MOCK_METHOD(void, Received, (const Req &), (override)); MOCK_METHOD(void, Received, (const Resp &), (override)); From 193826be9ceffb3b7ec50d474f0b4a6d81baf20c Mon Sep 17 00:00:00 2001 From: John Harrison Date: Thu, 18 Sep 2025 18:07:26 -0700 Subject: [PATCH 4/8] Fixing clang-format. --- lldb/include/lldb/Protocol/MCP/Server.h | 1 - 1 file changed, 1 deletion(-) diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 94dbba8d5888e..df2a4810ce620 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -27,7 +27,6 @@ namespace lldb_protocol::mcp { - class Server { using MCPTransportUP = std::unique_ptr; From 04b8f028e23a1b85c4b5f41278d99f88fdb8c897 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Tue, 23 Sep 2025 12:50:54 -0700 Subject: [PATCH 5/8] Consolidating protocol descriptions into a single 'concept'. The JSONTransport having to carry around 4 template parameters was a lot of repeated information. To try to simplify this, I created a 'concept' for `ProtocolDescriptor` that describes the single type definition used to describes a protocol. The `BindingBuilder` concept is an extension on `ProtocolDescriptor` that has additional helpers for using the Binder. This helps relate the concepts together into a single place, which I find helpful. Previously, there were a few free standing template functions that needed to be implemented for the Binder to work. --- lldb/include/lldb/Host/JSONTransport.h | 225 ++++++++++-------- lldb/include/lldb/Protocol/MCP/Server.h | 4 +- lldb/include/lldb/Protocol/MCP/Transport.h | 123 +++++----- .../Protocol/MCP/ProtocolServerMCP.cpp | 2 +- lldb/source/Protocol/MCP/Server.cpp | 7 +- lldb/tools/lldb-dap/DAP.h | 4 +- lldb/tools/lldb-dap/Transport.h | 10 +- lldb/unittests/DAP/TestBase.h | 9 +- lldb/unittests/Host/JSONTransportTest.cpp | 150 ++++++------ .../Protocol/ProtocolMCPServerTest.cpp | 12 +- .../Host/JSONTransportTestUtilities.h | 34 ++- 11 files changed, 302 insertions(+), 278 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index ca207ceddf6f0..f04cf87954639 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -71,7 +71,11 @@ class InvalidParams : public llvm::ErrorInfo { std::error_code convertToErrorCode() const override; private: + /// The JSONRPC remote method call. std::string m_method; + + /// Additional context from the parsing failure, e.g. "missing value at + /// (root)[1].str". std::string m_context; }; @@ -91,19 +95,32 @@ class MethodNotFound : public llvm::ErrorInfo { std::string m_method; }; +// FIXME: Once we upgrade to c++20, use this concept for JSONTransport. +// template +// concept ProtocolDescriptor = requires { +// typename T::Id; +// typename T::Req; +// typename T::Resp; +// typename T::Evt; +// }; + /// A transport is responsible for maintaining the connection to a client /// application, and reading/writing structured messages to it. /// /// JSONTransport have limited thread safety requirements: /// - Messages will not be sent concurrently. /// - Messages MAY be sent while Run() is reading, or its callback is active. -template -class JSONTransport { +/// +/// FIXME: Once we upgrade to c++20, use template +template class JSONTransport { public: + using Req = typename Proto::Req; + using Resp = typename Proto::Resp; + using Evt = typename Proto::Evt; using Message = std::variant; - class Binder; - using BinderUP = std::unique_ptr; + // class Binder; + // using BinderUP = std::unique_ptr; virtual ~JSONTransport() = default; @@ -152,19 +169,17 @@ class JSONTransport { }; /// An IOTransport sends and receives messages using an IOObject. -template -class IOTransport : public JSONTransport { +template class IOTransport : public JSONTransport { public: - using Message = typename JSONTransport::Message; - using MessageHandler = - typename JSONTransport::MessageHandler; + using Message = typename JSONTransport::Message; + using MessageHandler = typename JSONTransport::MessageHandler; IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} - llvm::Error Send(const Evt &evt) override { return Write(evt); } - llvm::Error Send(const Req &req) override { return Write(req); } - llvm::Error Send(const Resp &resp) override { return Write(resp); } + llvm::Error Send(const Proto::Evt &evt) override { return Write(evt); } + llvm::Error Send(const Proto::Req &req) override { return Write(req); } + llvm::Error Send(const Proto::Resp &resp) override { return Write(resp); } llvm::Expected RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) override { @@ -244,10 +259,10 @@ class IOTransport : public JSONTransport { }; /// A transport class for JSON with a HTTP header. -template -class HTTPDelimitedJSONTransport : public IOTransport { +template +class HTTPDelimitedJSONTransport : public IOTransport { public: - using IOTransport::IOTransport; + using IOTransport::IOTransport; protected: /// Encodes messages based on @@ -311,10 +326,9 @@ class HTTPDelimitedJSONTransport : public IOTransport { }; /// A transport class for JSON RPC. -template -class JSONRPCTransport : public IOTransport { +template class JSONRPCTransport : public IOTransport { public: - using IOTransport::IOTransport; + using IOTransport::IOTransport; protected: std::string Encode(const llvm::json::Value &message) override { @@ -368,34 +382,51 @@ using OutgoingRequest = typename detail::request_t::type; /// A function to send an outgoing event. template using OutgoingEvent = typename detail::event_t

::type; -/// Creates a request with the given id, method, and optional params. -template -Req MakeRequest(Id, llvm::StringRef, std::optional); - -/// Creates an error response for a given request. -template -Resp MakeResponse(const Req &, llvm::Error); - -/// Creates a success response for a given request. -template -Resp MakeResponse(const Req &, llvm::json::Value); - -/// Creates an event. -template -Evt MakeEvent(llvm::StringRef, std::optional); - -/// Extracts the result value from a response. -template -llvm::Expected GetResult(const Resp &); - -/// Extracts the id from a response. -template Id GetId(const Resp &); - -/// Extracts the method from a request or event. -template llvm::StringRef GetMethod(const T &); - -/// Extracts the parameters from a request or event. -template llvm::json::Value GetParams(const T &); +// FIXME: With c++20, we should use this concept: +// template +// concept BindingBuilder = +// ProtocolDescriptor && +// requires(T::Id id, T::Req req, T::Resp resp, T::Evt evt, +// llvm::StringRef method, std::optional params, +// std::optional result, llvm::Error err) { +// // For initializing the unique sequence identifier; +// { T::InitialId() } -> std::same_as; +// // Incrementing the sequence identifier. +// { id++ } -> std::same_as; +// +// // Constructing protocol types +// // @{ +// // Construct a new request. +// { T::Make(id, method, params) } -> std::same_as; +// // Construct a new error response. +// { T::Make(req, std::move(err)) } -> std::same_as; +// // Construct a new success response. +// { T::Make(req, result) } -> std::same_as; +// // Construct a new event. +// { T::Make(method, params) } -> std::same_as; +// // @} +// +// // Keys for associated types. +// // @{ +// // Looking up in flight responses. +// { T::KeyFor(resp) } -> std::same_as; +// // Extract method from request. +// { T::KeyFor(req) } -> std::same_as; +// // Extract method from event. +// { T::KeyFor(evt) } -> std::same_as; +// // @} +// +// // Extracting information from associated types. +// // @{ +// // Extract parameters from a request. +// { T::Extract(req) } -> std::same_as>; +// // Extract result from a response. +// { T::Extract(resp) } -> +// std::same_as>; +// // Extract parameters from an event. +// { T::Extract(evt) } -> std::same_as>; +// // @} +// }; /// Binder collects a table of functions that handle calls. /// @@ -422,10 +453,14 @@ template llvm::json::Value GetParams(const T &); /// cout << *result << "\n"; /// }); /// \endcode -template -class JSONTransport::Binder - : public JSONTransport::MessageHandler { - using Transport = JSONTransport; +// FIXME: In c++20 use: template +template +class Binder : public JSONTransport::MessageHandler { + using Req = Proto::Req; + using Resp = Proto::Resp; + using Evt = Proto::Evt; + using Id = Proto::Id; + using Transport = JSONTransport; using MessageHandler = typename Transport::MessageHandler; public: @@ -470,7 +505,7 @@ class JSONTransport::Binder void Received(const Evt &evt) override { std::scoped_lock guard(m_mutex); - auto it = m_event_handlers.find(GetMethod(evt)); + auto it = m_event_handlers.find(Proto::KeyFor(evt)); if (it == m_event_handlers.end()) { OnError(llvm::createStringError( llvm::formatv("no handled for event {0}", toJSON(evt)))); @@ -483,10 +518,9 @@ class JSONTransport::Binder ReplyOnce reply(req, &m_transport, this); std::scoped_lock guard(m_mutex); - auto it = m_request_handlers.find(GetMethod(req)); + auto it = m_request_handlers.find(Proto::KeyFor(req)); if (it == m_request_handlers.end()) { - reply(MakeResponse( - req, llvm::createStringError("method not found"))); + reply(Proto::Make(req, llvm::createStringError("method not found"))); return; } @@ -496,7 +530,7 @@ class JSONTransport::Binder void Received(const Resp &resp) override { std::scoped_lock guard(m_mutex); - Id id = GetId(resp); + Id id = Proto::KeyFor(resp); auto it = m_pending_responses.find(id); if (it == m_pending_responses.end()) { OnError(llvm::createStringError( @@ -566,8 +600,7 @@ class JSONTransport::Binder ~ReplyOnce() { if (transport && handler && !replied) { assert(false && "must reply to all calls!"); - (*this)(MakeResponse( - req, llvm::createStringError("failed to reply"))); + (*this)(Proto::Make(req, llvm::createStringError("failed to reply"))); } } @@ -584,29 +617,26 @@ class JSONTransport::Binder }; }; -template +template template -void JSONTransport::Binder::OnDisconnect(Fn &&fn, - Args &&...args) { +void Binder::OnDisconnect(Fn &&fn, Args &&...args) { m_disconnect_handler = [fn, args...]() mutable { std::invoke(std::forward(fn), std::forward(args)...); }; } -template +template template -void JSONTransport::Binder::OnError(Fn &&fn, - Args &&...args) { +void Binder::OnError(Fn &&fn, Args &&...args) { m_error_handler = [fn, args...](llvm::Error error) mutable { std::invoke(std::forward(fn), std::forward(args)..., std::move(error)); }; } -template +template template -void JSONTransport::Binder::Bind(llvm::StringLiteral method, - Fn &&fn, Args &&...args) { +void Binder::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { assert(m_request_handlers.find(method) == m_request_handlers.end() && "request already bound"); if constexpr (std::is_void_v && std::is_void_v) { @@ -615,7 +645,7 @@ void JSONTransport::Binder::Bind(llvm::StringLiteral method, llvm::unique_function reply) mutable { llvm::Error result = std::invoke(std::forward(fn), std::forward(args)...); - reply(MakeResponse(req, std::move(result))); + reply(Proto::Make(req, std::move(result))); }; } else if constexpr (std::is_void_v) { m_request_handlers[method] = @@ -624,8 +654,8 @@ void JSONTransport::Binder::Bind(llvm::StringLiteral method, llvm::Expected result = std::invoke(std::forward(fn), std::forward(args)...); if (!result) - return reply(MakeResponse(req, result.takeError())); - reply(MakeResponse(req, toJSON(*result))); + return reply(Proto::Make(req, result.takeError())); + reply(Proto::Make(req, toJSON(*result))); }; } else if constexpr (std::is_void_v) { m_request_handlers[method] = @@ -633,13 +663,13 @@ void JSONTransport::Binder::Bind(llvm::StringLiteral method, args...](const Req &req, llvm::unique_function reply) mutable { llvm::Expected params = - Parse(GetParams(req), method); + Parse(Proto::Extract(req), method); if (!params) - return reply(MakeResponse(req, params.takeError())); + return reply(Proto::Make(req, params.takeError())); llvm::Error result = std::invoke( std::forward(fn), std::forward(args)..., *params); - reply(MakeResponse(req, std::move(result))); + reply(Proto::Make(req, std::move(result))); }; } else { m_request_handlers[method] = @@ -647,24 +677,23 @@ void JSONTransport::Binder::Bind(llvm::StringLiteral method, args...](const Req &req, llvm::unique_function reply) mutable { llvm::Expected params = - Parse(GetParams(req), method); + Parse(Proto::Extract(req), method); if (!params) - return reply(MakeResponse(req, params.takeError())); + return reply(Proto::Make(req, params.takeError())); llvm::Expected result = std::invoke( std::forward(fn), std::forward(args)..., *params); if (!result) - return reply(MakeResponse(req, result.takeError())); + return reply(Proto::Make(req, result.takeError())); - reply(MakeResponse(req, toJSON(*result))); + reply(Proto::Make(req, toJSON(*result))); }; } } -template +template template -void JSONTransport::Binder::Bind(llvm::StringLiteral method, - Fn &&fn, Args &&...args) { +void Binder::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { assert(m_event_handlers.find(method) == m_event_handlers.end() && "event already bound"); if constexpr (std::is_void_v) { @@ -675,7 +704,7 @@ void JSONTransport::Binder::Bind(llvm::StringLiteral method, m_event_handlers[method] = [this, method, fn, args...](const Evt &evt) mutable { llvm::Expected params = - Parse(GetParams(evt), method); + Parse(Proto::Extract(evt), method); if (!params) return OnError(params.takeError()); std::invoke(std::forward(fn), std::forward(args)..., *params); @@ -683,17 +712,17 @@ void JSONTransport::Binder::Bind(llvm::StringLiteral method, } } -template +template template OutgoingRequest -JSONTransport::Binder::Bind(llvm::StringLiteral method) { +Binder::Bind(llvm::StringLiteral method) { if constexpr (std::is_void_v && std::is_void_v) { return [this, method](Reply fn) { std::scoped_lock guard(m_mutex); Id id = ++m_seq; - Req req = MakeRequest(id, method, std::nullopt); + Req req = Proto::Make(id, method, std::nullopt); m_pending_responses[id] = [fn = std::move(fn)](const Resp &resp) mutable { - llvm::Expected result = GetResult(resp); + llvm::Expected result = Proto::Extract(resp); if (!result) return fn(result.takeError()); fn(llvm::Error::success()); @@ -705,10 +734,10 @@ JSONTransport::Binder::Bind(llvm::StringLiteral method) { return [this, method](Reply fn) { std::scoped_lock guard(m_mutex); Id id = ++m_seq; - Req req = MakeRequest(id, method, std::nullopt); + Req req = Proto::Make(id, method, std::nullopt); m_pending_responses[id] = [fn = std::move(fn), method](const Resp &resp) mutable { - llvm::Expected result = GetResult(resp); + llvm::Expected result = Proto::Extract(resp); if (!result) return fn(result.takeError()); fn(Parse(*result, method)); @@ -720,9 +749,9 @@ JSONTransport::Binder::Bind(llvm::StringLiteral method) { return [this, method](const Params ¶ms, Reply fn) { std::scoped_lock guard(m_mutex); Id id = ++m_seq; - Req req = MakeRequest(id, method, llvm::json::Value(params)); + Req req = Proto::Make(id, method, llvm::json::Value(params)); m_pending_responses[id] = [fn = std::move(fn)](const Resp &resp) mutable { - llvm::Expected result = GetResult(resp); + llvm::Expected result = Proto::Extract(resp); if (!result) return fn(result.takeError()); fn(llvm::Error::success()); @@ -734,10 +763,10 @@ JSONTransport::Binder::Bind(llvm::StringLiteral method) { return [this, method](const Params ¶ms, Reply fn) { std::scoped_lock guard(m_mutex); Id id = ++m_seq; - Req req = MakeRequest(id, method, llvm::json::Value(params)); + Req req = Proto::Make(id, method, llvm::json::Value(params)); m_pending_responses[id] = [fn = std::move(fn), method](const Resp &resp) mutable { - llvm::Expected result = GetResult(resp); + llvm::Expected result = Proto::Extract(resp); if (llvm::Error err = result.takeError()) return fn(std::move(err)); fn(Parse(*result, method)); @@ -748,30 +777,28 @@ JSONTransport::Binder::Bind(llvm::StringLiteral method) { } } -template +template template -OutgoingEvent -JSONTransport::Binder::Bind(llvm::StringLiteral method) { +OutgoingEvent Binder::Bind(llvm::StringLiteral method) { if constexpr (std::is_void_v) { return [this, method]() { if (llvm::Error error = - m_transport.Send(MakeEvent(method, std::nullopt))) + m_transport.Send(Proto::Make(method, std::nullopt))) OnError(std::move(error)); }; } else { return [this, method](const Params ¶ms) { if (llvm::Error error = - m_transport.Send(MakeEvent(method, toJSON(params)))) + m_transport.Send(Proto::Make(method, toJSON(params)))) OnError(std::move(error)); }; } } -template +template template -llvm::Expected -JSONTransport::Binder::Parse(const llvm::json::Value &raw, - llvm::StringRef method) { +llvm::Expected Binder::Parse(const llvm::json::Value &raw, + llvm::StringRef method) { T result; llvm::json::Path::Root root; if (!fromJSON(raw, result, root)) { diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index df2a4810ce620..f185d51f41192 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -43,7 +43,7 @@ class Server { llvm::Error Accept(lldb_private::MainLoop &, MCPTransportUP); protected: - MCPTransport::BinderUP Bind(MCPTransport &); + MCPBinderUP Bind(MCPTransport &); ServerCapabilities GetCapabilities(); @@ -72,7 +72,7 @@ class Server { struct Client { ReadHandleUP handle; MCPTransportUP transport; - MCPTransport::BinderUP binder; + MCPBinderUP binder; }; std::map m_instances; diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h index 28ddeec5275e2..cf7cb100459de 100644 --- a/lldb/include/lldb/Protocol/MCP/Transport.h +++ b/lldb/include/lldb/Protocol/MCP/Transport.h @@ -16,87 +16,72 @@ #include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include -namespace lldb_private::transport { -/// Specializations of the JSONTransport protocol functions for MCP. -/// @{ -template <> -inline lldb_protocol::mcp::Request -MakeRequest(int64_t id, llvm::StringRef method, - std::optional params) { - return lldb_protocol::mcp::Request{id, method.str(), params}; -} -template <> -inline lldb_protocol::mcp::Response -MakeResponse(const lldb_protocol::mcp::Request &req, llvm::Error error) { - lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors(std::move(error), [&](const llvm::ErrorInfoBase &err) { - std::error_code cerr = err.convertToErrorCode(); - protocol_error.code = cerr == llvm::inconvertibleErrorCode() - ? lldb_protocol::mcp::eErrorCodeInternalError - : cerr.value(); - protocol_error.message = err.message(); - }); +namespace lldb_protocol::mcp { - return lldb_protocol::mcp::Response{req.id, std::move(protocol_error)}; -} -template <> -inline lldb_protocol::mcp::Response -MakeResponse(const lldb_protocol::mcp::Request &req, llvm::json::Value result) { - return lldb_protocol::mcp::Response{req.id, std::move(result)}; -} -template <> -inline lldb_protocol::mcp::Notification -MakeEvent(llvm::StringRef method, std::optional params) { - return lldb_protocol::mcp::Notification{method.str(), params}; -} -template <> -inline llvm::Expected -GetResult(const lldb_protocol::mcp::Response &resp) { - if (const lldb_protocol::mcp::Error *error = - std::get_if(&resp.result)) - return llvm::make_error(error->message, - error->code); - return std::get(resp.result); -} -template <> inline int64_t GetId(const lldb_protocol::mcp::Response &resp) { - return std::get(resp.id); -} -template <> -inline llvm::StringRef GetMethod(const lldb_protocol::mcp::Request &req) { - return req.method; -} -template <> -inline llvm::StringRef GetMethod(const lldb_protocol::mcp::Notification &evt) { - return evt.method; -} -template <> -inline llvm::json::Value GetParams(const lldb_protocol::mcp::Request &req) { - return req.params; -} -template <> -inline llvm::json::Value -GetParams(const lldb_protocol::mcp::Notification &evt) { - return evt.params; -} -/// @} +struct ProtocolDescriptor { + using Id = int64_t; + using Req = Request; + using Resp = Response; + using Evt = Notification; -} // namespace lldb_private::transport + static inline Id InitialId() { return 0; } + static inline Request Make(Id id, llvm::StringRef method, + std::optional params) { + return Request{id, method.str(), params}; + } + static inline Notification Make(llvm::StringRef method, + std::optional params) { + return Notification{method.str(), params}; + } + static inline Response Make(Req req, llvm::Error error) { + lldb_protocol::mcp::Error protocol_error; + llvm::handleAllErrors( + std::move(error), [&](const llvm::ErrorInfoBase &err) { + std::error_code cerr = err.convertToErrorCode(); + protocol_error.code = + cerr == llvm::inconvertibleErrorCode() + ? lldb_protocol::mcp::eErrorCodeInternalError + : cerr.value(); + protocol_error.message = err.message(); + }); -namespace lldb_protocol::mcp { + return Response{req.id, std::move(protocol_error)}; + } + static inline Response Make(Req req, + std::optional result) { + return Response{req.id, std::move(result)}; + } + static inline Id KeyFor(Response r) { return std::get(r.id); } + static inline llvm::StringRef KeyFor(Request r) { return r.method; } + static inline llvm::StringRef KeyFor(Notification n) { return n.method; } + static inline std::optional Extract(Request r) { + return r.params; + } + static inline llvm::Expected Extract(Response r) { + if (const lldb_protocol::mcp::Error *error = + std::get_if(&r.result)) + return llvm::make_error(error->message, + error->code); + return std::get(r.result); + } + static inline std::optional Extract(Notification n) { + return n.params; + } +}; /// Generic transport that uses the MCP protocol. -using MCPTransport = - lldb_private::transport::JSONTransport; +using MCPTransport = lldb_private::transport::JSONTransport; +using MCPBinder = lldb_private::transport::Binder; +using MCPBinderUP = std::unique_ptr; /// Generic logging callback, to allow the MCP server / client / transport layer /// to be independent of the lldb log implementation. using LogCallback = llvm::unique_function; class Transport final - : public lldb_private::transport::JSONRPCTransport { + : public lldb_private::transport::JSONRPCTransport { public: Transport(lldb::IOObjectSP in, lldb::IOObjectSP out, LogCallback log_callback = {}); diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index 7500bb1515e01..33bdd5eec3644 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -116,7 +116,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { m_running = true; m_server_info_handle = std::move(*server_info_handle); m_accept_handles = std::move(*handles); - m_loop_thread = std::thread([=] { + m_loop_thread = std::thread([this] { llvm::set_thread_name("protocol-server.mcp"); m_loop.Run(); }); diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 70a9f0cad427b..cd0f151ba1466 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -126,9 +126,8 @@ void Server::AddResourceProvider( m_resource_providers.push_back(std::move(resource_provider)); } -MCPTransport::BinderUP Server::Bind(MCPTransport &transport) { - MCPTransport::BinderUP binder = - std::make_unique(transport); +MCPBinderUP Server::Bind(MCPTransport &transport) { + MCPBinderUP binder = std::make_unique(transport); binder->Bind( "initialize", &Server::InitializeHandler, this); binder->Bind("tools/list", &Server::ToolsListHandler, @@ -145,7 +144,7 @@ MCPTransport::BinderUP Server::Bind(MCPTransport &transport) { } llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { - MCPTransport::BinderUP binder = Bind(*transport); + MCPBinderUP binder = Bind(*transport); MCPTransport *transport_ptr = transport.get(); binder->OnDisconnect([this, transport_ptr]() { assert(m_instances.find(transport_ptr) != m_instances.end() && diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index fee638258310a..a90ddf59671ee 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -78,9 +78,7 @@ enum DAPBroadcasterBits { enum class ReplMode { Variable = 0, Command, Auto }; -using DAPTransport = - lldb_private::transport::JSONTransport; +using DAPTransport = lldb_private::transport::JSONTransport; struct DAP final : public DAPTransport::MessageHandler { /// Path to the lldb-dap binary itself. diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index dbcd65c6add64..58c48c133f9cb 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -22,12 +22,18 @@ namespace lldb_dap { +struct ProtocolDescriptor { + using Id = protocol::Id; + using Req = protocol::Request; + using Resp = protocol::Response; + using Evt = protocol::Event; +}; + /// A transport class that performs the Debug Adapter Protocol communication /// with the client. class Transport final : public lldb_private::transport::HTTPDelimitedJSONTransport< - protocol::Id, protocol::Request, protocol::Response, - protocol::Event> { + ProtocolDescriptor> { public: Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output); diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index aaeab3b3d2cd9..c32f3a769c737 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -11,6 +11,7 @@ #include "Protocol/ProtocolBase.h" #include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" +#include "Transport.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" #include "lldb/Host/MainLoop.h" @@ -47,9 +48,7 @@ inline void PrintTo(const Message &message, std::ostream *os) { namespace lldb_dap_tests { -using TestDAPTransport = - TestTransport; +using TestDAPTransport = TestTransport; /// A base class for tests that need transport configured for communicating DAP /// messages. @@ -63,9 +62,7 @@ class TransportBase : public testing::Test { std::unique_ptr log; std::unique_ptr to_client; - MockMessageHandler - client; + MockMessageHandler client; std::unique_ptr to_server; std::unique_ptr dap; diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index c7654006a1397..8450909038537 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -28,6 +28,7 @@ #include #include #include +#include using namespace llvm; using namespace lldb_private; @@ -77,17 +78,20 @@ void PrintTo(const Req &message, std::ostream *os) { struct Resp { int id = 0; + int errorCode = 0; std::optional result; }; json::Value toJSON(const Resp &T) { - return json::Object{{"id", T.id}, {"result", T.result}}; + return json::Object{ + {"id", T.id}, {"errorCode", T.errorCode}, {"result", T.result}}; } bool fromJSON(const json::Value &V, Resp &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("id", T.id) && O.map("result", T.result); + return O && O.map("id", T.id) && O.mapOptional("errorCode", T.errorCode) && + O.map("result", T.result); } bool operator==(const Resp &a, const Resp &b) { - return a.id == b.id && a.result == b.result; + return a.id == b.id && a.errorCode == b.errorCode && a.result == b.result; } inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) { OS << toJSON(V); @@ -181,8 +185,58 @@ bool fromJSON(const json::Value &V, MyFnResult &T, json::Path P) { return O && O.map("c", T.c); } -using Transport = TestTransport; -using MessageHandler = MockMessageHandler; +struct ProtoDesc { + using Id = int; + using Req = Req; + using Resp = Resp; + using Evt = Evt; + + static inline Id InitialId() { return 0; } + static inline Req Make(Id id, llvm::StringRef method, + std::optional params) { + return Req{id, method.str(), params}; + } + static inline Evt Make(llvm::StringRef method, + std::optional params) { + return Evt{method.str(), params}; + } + static inline Resp Make(Req req, llvm::Error error) { + Resp resp; + resp.id = req.id; + llvm::handleAllErrors( + std::move(error), [&](const llvm::ErrorInfoBase &err) { + std::error_code cerr = err.convertToErrorCode(); + resp.errorCode = + cerr == llvm::inconvertibleErrorCode() ? 1 : cerr.value(); + resp.result = err.message(); + }); + return resp; + } + static inline Resp Make(Req req, std::optional result) { + return Resp{req.id, 0, std::move(result)}; + } + static inline Id KeyFor(Resp r) { return r.id; } + static inline llvm::StringRef KeyFor(Req r) { return r.name; } + static inline llvm::StringRef KeyFor(Evt e) { return e.name; } + static inline std::optional Extract(Req r) { + return r.params; + } + static inline llvm::Expected Extract(Resp r) { + if (r.errorCode != 0) + return llvm::createStringError( + std::error_code(r.errorCode, std::generic_category()), + r.result && r.result->getAsString() ? *r.result->getAsString() + : "no-message"); + return r.result; + } + static inline std::optional Extract(Evt e) { + return e.params; + } +}; + +using Transport = TestTransport; +using Binder = lldb_private::transport::Binder; +using MessageHandler = MockMessageHandler; } // namespace test_protocol @@ -240,8 +294,7 @@ template class JSONTransportTest : public PipePairTest { }; class TestHTTPDelimitedJSONTransport final - : public HTTPDelimitedJSONTransport< - int, test_protocol::Req, test_protocol::Resp, test_protocol::Evt> { + : public HTTPDelimitedJSONTransport { public: using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport; @@ -269,8 +322,7 @@ class HTTPDelimitedJSONTransportTest }; class TestJSONRPCTransport final - : public JSONRPCTransport { + : public JSONRPCTransport { public: using JSONRPCTransport::JSONRPCTransport; @@ -299,13 +351,13 @@ class TransportBinderTest : public testing::Test { std::unique_ptr to_remote; std::unique_ptr from_remote; - std::unique_ptr binder; + std::unique_ptr binder; test_protocol::MessageHandler remote; MainLoop loop; void SetUp() override { std::tie(to_remote, from_remote) = test_protocol::Transport::createPair(); - binder = std::make_unique(*to_remote); + binder = std::make_unique(*to_remote); auto binder_handle = to_remote->RegisterMessageHandler(loop, remote); EXPECT_THAT_EXPECTED(binder_handle, Succeeded()); @@ -322,37 +374,6 @@ class TransportBinderTest : public testing::Test { } // namespace -namespace lldb_private::transport { - -using namespace test_protocol; - -template <> -test_protocol::Req MakeRequest(int id, llvm::StringRef method, - std::optional params) { - return test_protocol::Req{id, method.str(), params}; -} -template <> Resp MakeResponse(const Req &req, llvm::Error error) { - llvm::consumeError(std::move(error)); - return Resp{req.id, std::nullopt}; -} -template <> Resp MakeResponse(const Req &req, json::Value result) { - return Resp{req.id, std::move(result)}; -} -template <> -Evt MakeEvent(llvm::StringRef method, std::optional params) { - return Evt{method.str(), params}; -} -template <> llvm::Expected GetResult(const Resp &resp) { - return resp.result; -} -template <> int GetId(const Resp &resp) { return resp.id; } -template <> llvm::StringRef GetMethod(const Req &req) { return req.name; } -template <> llvm::StringRef GetMethod(const Evt &evt) { return evt.name; } -template <> json::Value GetParams(const Req &req) { return req.params; } -template <> json::Value GetParams(const Evt &evt) { return evt.params; } - -} // namespace lldb_private::transport - // Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446. #ifndef _WIN32 using namespace test_protocol; @@ -387,19 +408,17 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { Evt{"two", std::nullopt}, }, Message{ - Resp{2, std::nullopt}, + Resp{2, 0, std::nullopt}, }); EXPECT_CALL(message_handler, Received(Req{6, "one", std::nullopt})); EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt})); - EXPECT_CALL(message_handler, Received(Resp{2, std::nullopt})); + EXPECT_CALL(message_handler, Received(Resp{2, 0, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { std::string long_str = std::string( - HTTPDelimitedJSONTransport::kReadBufferSize * - 2, + HTTPDelimitedJSONTransport::kReadBufferSize * 2, 'x'); Write(Req{5, long_str, std::nullopt}); EXPECT_CALL(message_handler, Received(Req{5, long_str, std::nullopt})); @@ -482,7 +501,7 @@ TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { TEST_F(HTTPDelimitedJSONTransportTest, Write) { ASSERT_THAT_ERROR(transport->Send(Req{7, "foo", std::nullopt}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Resp{5, "bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{5, 0, "bar"}), Succeeded()); ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; @@ -492,8 +511,8 @@ TEST_F(HTTPDelimitedJSONTransportTest, Write) { ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 35\r\n\r\n" R"({"id":7,"name":"foo","params":null})" - "Content-Length: 23\r\n\r\n" - R"({"id":5,"result":"bar"})" + "Content-Length: 37\r\n\r\n" + R"({"errorCode":0,"id":5,"result":"bar"})" "Content-Length: 28\r\n\r\n" R"({"name":"baz","params":null})")); } @@ -519,21 +538,18 @@ TEST_F(JSONRPCTransportTest, Read) { TEST_F(JSONRPCTransportTest, ReadMultipleMessagesInSingleWrite) { InSequence seq; Write(Message{Req{1, "one", std::nullopt}}, Message{Evt{"two", std::nullopt}}, - Message{Resp{3, "three"}}); + Message{Resp{3, 0, "three"}}); EXPECT_CALL(message_handler, Received(Req{1, "one", std::nullopt})); EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt})); - EXPECT_CALL(message_handler, Received(Resp{3, "three"})); + EXPECT_CALL(message_handler, Received(Resp{3, 0, "three"})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { // Use a string longer than the chunk size to ensure we split the message // across the chunk boundary. - std::string long_str = - std::string(IOTransport::kReadBufferSize * - 2, - 'x'); + std::string long_str = std::string( + IOTransport::kReadBufferSize * 2, 'x'); Write(Req{42, long_str, std::nullopt}); EXPECT_CALL(message_handler, Received(Req{42, long_str, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); @@ -576,7 +592,7 @@ TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { TEST_F(JSONRPCTransportTest, Write) { ASSERT_THAT_ERROR(transport->Send(Req{11, "foo", std::nullopt}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Resp{14, "bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{14, 0, "bar"}), Succeeded()); ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; @@ -586,7 +602,7 @@ TEST_F(JSONRPCTransportTest, Write) { ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"id":11,"name":"foo","params":null})" "\n" - R"({"id":14,"result":"bar"})" + R"({"errorCode":0,"id":14,"result":"bar"})" "\n" R"({"name":"baz","params":null})" "\n")); @@ -609,7 +625,7 @@ TEST_F(TransportBinderTest, OutBoundRequests) { replied = true; }); EXPECT_CALL(remote, Received(Req{1, "add", MyFnParams{1, 2}})); - EXPECT_THAT_ERROR(from_remote->Send(Resp{1, toJSON(MyFnResult{3})}), + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, toJSON(MyFnResult{3})}), Succeeded()); Run(); EXPECT_TRUE(replied); @@ -625,7 +641,7 @@ TEST_F(TransportBinderTest, OutBoundRequestsVoidParams) { replied = true; }); EXPECT_CALL(remote, Received(Req{1, "voidParam", std::nullopt})); - EXPECT_THAT_ERROR(from_remote->Send(Resp{1, toJSON(MyFnResult{3})}), + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, toJSON(MyFnResult{3})}), Succeeded()); Run(); EXPECT_TRUE(replied); @@ -640,7 +656,7 @@ TEST_F(TransportBinderTest, OutBoundRequestsVoidResult) { replied = true; }); EXPECT_CALL(remote, Received(Req{1, "voidResult", MyFnParams{4, 5}})); - EXPECT_THAT_ERROR(from_remote->Send(Resp{1, std::nullopt}), Succeeded()); + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, std::nullopt}), Succeeded()); Run(); EXPECT_TRUE(replied); } @@ -654,7 +670,7 @@ TEST_F(TransportBinderTest, OutBoundRequestsVoidParamsAndVoidResult) { replied = true; }); EXPECT_CALL(remote, Received(Req{1, "voidParamAndResult", std::nullopt})); - EXPECT_THAT_ERROR(from_remote->Send(Resp{1, std::nullopt}), Succeeded()); + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, std::nullopt}), Succeeded()); Run(); EXPECT_TRUE(replied); } @@ -673,7 +689,7 @@ TEST_F(TransportBinderTest, InBoundRequests) { EXPECT_THAT_ERROR(from_remote->Send(Req{1, "add", MyFnParams{3, 4}}), Succeeded()); - EXPECT_CALL(remote, Received(Resp{1, MyFnResult{9}})); + EXPECT_CALL(remote, Received(Resp{1, 0, MyFnResult{9}})); Run(); EXPECT_TRUE(called); } @@ -689,7 +705,7 @@ TEST_F(TransportBinderTest, InBoundRequestsVoidParams) { 2); EXPECT_THAT_ERROR(from_remote->Send(Req{2, "voidParam", std::nullopt}), Succeeded()); - EXPECT_CALL(remote, Received(Resp{2, MyFnResult{2}})); + EXPECT_CALL(remote, Received(Resp{2, 0, MyFnResult{2}})); Run(); EXPECT_TRUE(called); } @@ -708,7 +724,7 @@ TEST_F(TransportBinderTest, InBoundRequestsVoidResult) { 2); EXPECT_THAT_ERROR(from_remote->Send(Req{3, "voidResult", MyFnParams{3, 4}}), Succeeded()); - EXPECT_CALL(remote, Received(Resp{3, std::nullopt})); + EXPECT_CALL(remote, Received(Resp{3, 0, std::nullopt})); Run(); EXPECT_TRUE(called); } @@ -725,7 +741,7 @@ TEST_F(TransportBinderTest, InBoundRequestsVoidParamsAndResult) { EXPECT_THAT_ERROR( from_remote->Send(Req{4, "voidParamAndResult", std::nullopt}), Succeeded()); - EXPECT_CALL(remote, Received(Resp{4, std::nullopt})); + EXPECT_CALL(remote, Received(Resp{4, 0, std::nullopt})); Run(); EXPECT_TRUE(called); } diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index bc3f62a2a2db4..9628cbd91e9ce 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -127,9 +127,7 @@ class TestServer : public Server { using Server::Server; }; -using Transport = TestTransport; +using Transport = TestTransport; class ProtocolServerMCPTest : public testing::Test { public: @@ -139,11 +137,11 @@ class ProtocolServerMCPTest : public testing::Test { lldb_private::MainLoop::ReadHandleUP handles[2]; std::unique_ptr to_server; - Transport::BinderUP binder; + MCPBinderUP binder; std::unique_ptr server_up; std::unique_ptr to_client; - MockMessageHandler client; + MockMessageHandler client; std::vector logged_messages; @@ -176,8 +174,8 @@ class ProtocolServerMCPTest : public testing::Test { template Expected Call(StringRef method, const Params ¶ms) { std::promise promised_result; - Request req = MakeRequest( - /*id=*/1, method, toJSON(params)); + Request req = + lldb_protocol::mcp::Request{/*id=*/1, method.str(), toJSON(params)}; EXPECT_THAT_ERROR(to_server->Send(req), Succeeded()); EXPECT_CALL(client, Received(testing::An())) .WillOnce( diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h index 73918ace2c521..ca3e95911d863 100644 --- a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -21,20 +21,19 @@ #include #include -template +template class TestTransport final - : public lldb_private::transport::JSONTransport { + : public lldb_private::transport::JSONTransport { public: using MessageHandler = - typename lldb_private::transport::JSONTransport::MessageHandler; + typename lldb_private::transport::JSONTransport::MessageHandler; - static std::pair>, - std::unique_ptr>> + static std::pair>, + std::unique_ptr>> createPair() { - std::unique_ptr> transports[2] = { - std::make_unique>(), - std::make_unique>()}; + std::unique_ptr> transports[2] = { + std::make_unique>(), + std::make_unique>()}; return std::make_pair(std::move(transports[0]), std::move(transports[1])); } @@ -47,7 +46,7 @@ class TestTransport final m_dummy_file = std::move(*dummy_file); } - llvm::Error Send(const Evt &evt) override { + llvm::Error Send(const Proto::Evt &evt) override { EXPECT_TRUE(m_loop && m_handler) << "Send called before RegisterMessageHandler"; m_loop->AddPendingCallback([this, evt](lldb_private::MainLoopBase &) { @@ -56,7 +55,7 @@ class TestTransport final return llvm::Error::success(); } - llvm::Error Send(const Req &req) override { + llvm::Error Send(const Proto::Req &req) override { EXPECT_TRUE(m_loop && m_handler) << "Send called before RegisterMessageHandler"; m_loop->AddPendingCallback([this, req](lldb_private::MainLoopBase &) { @@ -65,7 +64,7 @@ class TestTransport final return llvm::Error::success(); } - llvm::Error Send(const Resp &resp) override { + llvm::Error Send(const Proto::Resp &resp) override { EXPECT_TRUE(m_loop && m_handler) << "Send called before RegisterMessageHandler"; m_loop->AddPendingCallback([this, resp](lldb_private::MainLoopBase &) { @@ -99,14 +98,13 @@ class TestTransport final lldb::FileSP m_dummy_file = nullptr; }; -template +template class MockMessageHandler final - : public lldb_private::transport::JSONTransport::MessageHandler { + : public lldb_private::transport::JSONTransport::MessageHandler { public: - MOCK_METHOD(void, Received, (const Req &), (override)); - MOCK_METHOD(void, Received, (const Resp &), (override)); - MOCK_METHOD(void, Received, (const Evt &), (override)); + MOCK_METHOD(void, Received, (const typename Proto::Req &), (override)); + MOCK_METHOD(void, Received, (const typename Proto::Resp &), (override)); + MOCK_METHOD(void, Received, (const typename Proto::Evt &), (override)); MOCK_METHOD(void, OnError, (llvm::Error), (override)); MOCK_METHOD(void, OnClosed, (), (override)); }; From 49c146204162b40f4d5780265d8d2266c70333be Mon Sep 17 00:00:00 2001 From: John Harrison Date: Tue, 23 Sep 2025 13:08:09 -0700 Subject: [PATCH 6/8] Improving comment formatting. --- lldb/include/lldb/Host/JSONTransport.h | 132 ++++++++++++---------- lldb/include/lldb/Protocol/MCP/Protocol.h | 7 -- lldb/source/Protocol/MCP/Server.cpp | 24 ++-- 3 files changed, 82 insertions(+), 81 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index f04cf87954639..6efb0da4dcfbb 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -95,14 +95,19 @@ class MethodNotFound : public llvm::ErrorInfo { std::string m_method; }; -// FIXME: Once we upgrade to c++20, use this concept for JSONTransport. -// template -// concept ProtocolDescriptor = requires { -// typename T::Id; -// typename T::Req; -// typename T::Resp; -// typename T::Evt; -// }; +/* +FIXME: Once we upgrade to c++20, use this concept for JSONTransport. + +/// A ProtocolDescriptor details the types used in a JSONTransport for handling +/// transport communication. +template +concept ProtocolDescriptor = requires { + typename T::Id; + typename T::Req; + typename T::Resp; + typename T::Evt; +}; +*/ /// A transport is responsible for maintaining the connection to a client /// application, and reading/writing structured messages to it. @@ -111,7 +116,7 @@ class MethodNotFound : public llvm::ErrorInfo { /// - Messages will not be sent concurrently. /// - Messages MAY be sent while Run() is reading, or its callback is active. /// -/// FIXME: Once we upgrade to c++20, use template +/// FIXME: Once we upgrade to c++20, use `template ` template class JSONTransport { public: using Req = typename Proto::Req; @@ -119,9 +124,6 @@ template class JSONTransport { using Evt = typename Proto::Evt; using Message = std::variant; - // class Binder; - // using BinderUP = std::unique_ptr; - virtual ~JSONTransport() = default; /// Sends an event, a message that does not require a response. @@ -382,51 +384,56 @@ using OutgoingRequest = typename detail::request_t::type; /// A function to send an outgoing event. template using OutgoingEvent = typename detail::event_t

::type; -// FIXME: With c++20, we should use this concept: -// template -// concept BindingBuilder = -// ProtocolDescriptor && -// requires(T::Id id, T::Req req, T::Resp resp, T::Evt evt, -// llvm::StringRef method, std::optional params, -// std::optional result, llvm::Error err) { -// // For initializing the unique sequence identifier; -// { T::InitialId() } -> std::same_as; -// // Incrementing the sequence identifier. -// { id++ } -> std::same_as; -// -// // Constructing protocol types -// // @{ -// // Construct a new request. -// { T::Make(id, method, params) } -> std::same_as; -// // Construct a new error response. -// { T::Make(req, std::move(err)) } -> std::same_as; -// // Construct a new success response. -// { T::Make(req, result) } -> std::same_as; -// // Construct a new event. -// { T::Make(method, params) } -> std::same_as; -// // @} -// -// // Keys for associated types. -// // @{ -// // Looking up in flight responses. -// { T::KeyFor(resp) } -> std::same_as; -// // Extract method from request. -// { T::KeyFor(req) } -> std::same_as; -// // Extract method from event. -// { T::KeyFor(evt) } -> std::same_as; -// // @} -// -// // Extracting information from associated types. -// // @{ -// // Extract parameters from a request. -// { T::Extract(req) } -> std::same_as>; -// // Extract result from a response. -// { T::Extract(resp) } -> -// std::same_as>; -// // Extract parameters from an event. -// { T::Extract(evt) } -> std::same_as>; -// // @} -// }; +/* +FIXME: With c++20, we should use this concept: + +/// This represents a protocol description that includes additional helpers +/// for constructing requests, responses and events to work with `Binder`. +template +concept BindingBuilder = + ProtocolDescriptor && + requires(T::Id id, T::Req req, T::Resp resp, T::Evt evt, + llvm::StringRef method, std::optional params, + std::optional result, llvm::Error err) { + /// For initializing the unique sequence identifier; + { T::InitialId() } -> std::same_as; + /// Incrementing the sequence identifier. + { id++ } -> std::same_as; + + /// Constructing protocol types + /// @{ + /// Construct a new request. + { T::Make(id, method, params) } -> std::same_as; + /// Construct a new error response. + { T::Make(req, std::move(err)) } -> std::same_as; + /// Construct a new success response. + { T::Make(req, result) } -> std::same_as; + /// Construct a new event. + { T::Make(method, params) } -> std::same_as; + /// @} + + /// Keys for associated types. + /// @{ + /// Looking up in flight responses. + { T::KeyFor(resp) } -> std::same_as; + /// Extract method from request. + { T::KeyFor(req) } -> std::same_as; + /// Extract method from event. + { T::KeyFor(evt) } -> std::same_as; + /// @} + + /// Extracting information from associated types. + /// @{ + /// Extract parameters from a request. + { T::Extract(req) } -> std::same_as>; + /// Extract result from a response. + { T::Extract(resp) } -> + std::same_as>; + /// Extract parameters from an event. + { T::Extract(evt) } -> std::same_as>; + /// @} + }; +*/ /// Binder collects a table of functions that handle calls. /// @@ -435,25 +442,26 @@ template using OutgoingEvent = typename detail::event_t

::type; /// This allows a JSONTransport to handle incoming and outgoing requests and /// events. /// -/// A simple example could be to a method to a lambda like: -/// +/// A bind of an incoming request to a lambda. /// \code{cpp} /// Binder binder{transport}; -/// // Binds an incoming request handler. /// binder.bind>("adder", [](const vector ¶ms) { /// int sum = 0; /// for (int v : params) /// sum += v; /// return sum; /// }); -/// // Binds an outgoing request handler. +/// \endcode +/// +/// A bind of an outgoing request. +/// \code{cpp} /// OutgoingRequest> call_add = /// binder.bind>("add"); /// call_add({1,2,3}, [](Expected result) { /// cout << *result << "\n"; /// }); /// \endcode -// FIXME: In c++20 use: template +/// FIXME: In c++20 use `template `. template class Binder : public JSONTransport::MessageHandler { using Req = Proto::Req; diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 9c5df65a6bfcb..a0ba8659ffe24 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -329,11 +329,4 @@ MakeRequest(int64_t id, llvm::StringRef method, } // namespace lldb_protocol::mcp -// namespace llvm::json { -// inline Value toJSON(const lldb_protocol::mcp::Void &) { return Object(); } -// inline bool fromJSON(const Value &, lldb_protocol::mcp::Void &, Path) { -// return true; -// } -// } // namespace llvm::json - #endif diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index cd0f151ba1466..71323adbac5f5 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -127,20 +127,20 @@ void Server::AddResourceProvider( } MCPBinderUP Server::Bind(MCPTransport &transport) { - MCPBinderUP binder = std::make_unique(transport); - binder->Bind( + MCPBinderUP binder_up = std::make_unique(transport); + binder_up->Bind( "initialize", &Server::InitializeHandler, this); - binder->Bind("tools/list", &Server::ToolsListHandler, - this); - binder->Bind("tools/call", - &Server::ToolsCallHandler, this); - binder->Bind("resources/list", - &Server::ResourcesListHandler, this); - binder->Bind( + binder_up->Bind("tools/list", + &Server::ToolsListHandler, this); + binder_up->Bind( + "tools/call", &Server::ToolsCallHandler, this); + binder_up->Bind( + "resources/list", &Server::ResourcesListHandler, this); + binder_up->Bind( "resources/read", &Server::ResourcesReadHandler, this); - binder->Bind("notifications/initialized", - [this]() { Log("MCP initialization complete"); }); - return binder; + binder_up->Bind("notifications/initialized", + [this]() { Log("MCP initialization complete"); }); + return binder_up; } llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { From 1c1e98690cb10e82a5eaadc3bde595ea3e348636 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Tue, 23 Sep 2025 14:01:26 -0700 Subject: [PATCH 7/8] Wrapping concepts in an #ifdef instead of commenting them out. --- lldb/include/lldb/Host/JSONTransport.h | 94 ++++++++++++++++------ lldb/include/lldb/Protocol/MCP/Transport.h | 4 +- lldb/unittests/Host/JSONTransportTest.cpp | 4 +- 3 files changed, 75 insertions(+), 27 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 6efb0da4dcfbb..1453316e96fb4 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -35,6 +35,9 @@ #include #include #include +#if __cplusplus >= 202002L +#include +#endif namespace lldb_private::transport { @@ -95,9 +98,7 @@ class MethodNotFound : public llvm::ErrorInfo { std::string m_method; }; -/* -FIXME: Once we upgrade to c++20, use this concept for JSONTransport. - +#if __cplusplus >= 202002L /// A ProtocolDescriptor details the types used in a JSONTransport for handling /// transport communication. template @@ -107,7 +108,7 @@ concept ProtocolDescriptor = requires { typename T::Resp; typename T::Evt; }; -*/ +#endif /// A transport is responsible for maintaining the connection to a client /// application, and reading/writing structured messages to it. @@ -116,8 +117,12 @@ concept ProtocolDescriptor = requires { /// - Messages will not be sent concurrently. /// - Messages MAY be sent while Run() is reading, or its callback is active. /// -/// FIXME: Once we upgrade to c++20, use `template ` -template class JSONTransport { +#if __cplusplus >= 202002L +template +#else +template +#endif +class JSONTransport { public: using Req = typename Proto::Req; using Resp = typename Proto::Resp; @@ -179,9 +184,15 @@ template class IOTransport : public JSONTransport { IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} - llvm::Error Send(const Proto::Evt &evt) override { return Write(evt); } - llvm::Error Send(const Proto::Req &req) override { return Write(req); } - llvm::Error Send(const Proto::Resp &resp) override { return Write(resp); } + llvm::Error Send(const typename Proto::Evt &evt) override { + return Write(evt); + } + llvm::Error Send(const typename Proto::Req &req) override { + return Write(req); + } + llvm::Error Send(const typename Proto::Resp &resp) override { + return Write(resp); + } llvm::Expected RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) override { @@ -261,7 +272,11 @@ template class IOTransport : public JSONTransport { }; /// A transport class for JSON with a HTTP header. +#if __cplusplus >= 202002L +template +#else template +#endif class HTTPDelimitedJSONTransport : public IOTransport { public: using IOTransport::IOTransport; @@ -328,7 +343,12 @@ class HTTPDelimitedJSONTransport : public IOTransport { }; /// A transport class for JSON RPC. -template class JSONRPCTransport : public IOTransport { +#if __cplusplus >= 202002L +template +#else +template +#endif +class JSONRPCTransport : public IOTransport { public: using IOTransport::IOTransport; @@ -384,9 +404,7 @@ using OutgoingRequest = typename detail::request_t::type; /// A function to send an outgoing event. template using OutgoingEvent = typename detail::event_t

::type; -/* -FIXME: With c++20, we should use this concept: - +#if __cplusplus >= 202002L /// This represents a protocol description that includes additional helpers /// for constructing requests, responses and events to work with `Binder`. template @@ -417,9 +435,9 @@ concept BindingBuilder = /// Looking up in flight responses. { T::KeyFor(resp) } -> std::same_as; /// Extract method from request. - { T::KeyFor(req) } -> std::same_as; + { T::KeyFor(req) } -> std::same_as; /// Extract method from event. - { T::KeyFor(evt) } -> std::same_as; + { T::KeyFor(evt) } -> std::same_as; /// @} /// Extracting information from associated types. @@ -427,13 +445,12 @@ concept BindingBuilder = /// Extract parameters from a request. { T::Extract(req) } -> std::same_as>; /// Extract result from a response. - { T::Extract(resp) } -> - std::same_as>; + { T::Extract(resp) } -> std::same_as>; /// Extract parameters from an event. { T::Extract(evt) } -> std::same_as>; /// @} }; -*/ +#endif /// Binder collects a table of functions that handle calls. /// @@ -461,13 +478,16 @@ concept BindingBuilder = /// cout << *result << "\n"; /// }); /// \endcode -/// FIXME: In c++20 use `template `. +#if __cplusplus >= 202002L +template +#else template +#endif class Binder : public JSONTransport::MessageHandler { - using Req = Proto::Req; - using Resp = Proto::Resp; - using Evt = Proto::Evt; - using Id = Proto::Id; + using Req = typename Proto::Req; + using Resp = typename Proto::Resp; + using Evt = typename Proto::Evt; + using Id = typename Proto::Id; using Transport = JSONTransport; using MessageHandler = typename Transport::MessageHandler; @@ -625,7 +645,11 @@ class Binder : public JSONTransport::MessageHandler { }; }; +#if __cplusplus >= 202002L +template +#else template +#endif template void Binder::OnDisconnect(Fn &&fn, Args &&...args) { m_disconnect_handler = [fn, args...]() mutable { @@ -633,7 +657,11 @@ void Binder::OnDisconnect(Fn &&fn, Args &&...args) { }; } +#if __cplusplus >= 202002L +template +#else template +#endif template void Binder::OnError(Fn &&fn, Args &&...args) { m_error_handler = [fn, args...](llvm::Error error) mutable { @@ -642,7 +670,11 @@ void Binder::OnError(Fn &&fn, Args &&...args) { }; } +#if __cplusplus >= 202002L +template +#else template +#endif template void Binder::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { assert(m_request_handlers.find(method) == m_request_handlers.end() && @@ -699,7 +731,11 @@ void Binder::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { } } +#if __cplusplus >= 202002L +template +#else template +#endif template void Binder::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { assert(m_event_handlers.find(method) == m_event_handlers.end() && @@ -720,7 +756,11 @@ void Binder::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { } } +#if __cplusplus >= 202002L +template +#else template +#endif template OutgoingRequest Binder::Bind(llvm::StringLiteral method) { @@ -785,7 +825,11 @@ Binder::Bind(llvm::StringLiteral method) { } } +#if __cplusplus >= 202002L +template +#else template +#endif template OutgoingEvent Binder::Bind(llvm::StringLiteral method) { if constexpr (std::is_void_v) { @@ -803,7 +847,11 @@ OutgoingEvent Binder::Bind(llvm::StringLiteral method) { } } +#if __cplusplus >= 202002L +template +#else template +#endif template llvm::Expected Binder::Parse(const llvm::json::Value &raw, llvm::StringRef method) { diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h index cf7cb100459de..b7a1eb778d660 100644 --- a/lldb/include/lldb/Protocol/MCP/Transport.h +++ b/lldb/include/lldb/Protocol/MCP/Transport.h @@ -54,8 +54,8 @@ struct ProtocolDescriptor { return Response{req.id, std::move(result)}; } static inline Id KeyFor(Response r) { return std::get(r.id); } - static inline llvm::StringRef KeyFor(Request r) { return r.method; } - static inline llvm::StringRef KeyFor(Notification n) { return n.method; } + static inline std::string KeyFor(Request r) { return r.method; } + static inline std::string KeyFor(Notification n) { return n.method; } static inline std::optional Extract(Request r) { return r.params; } diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 8450909038537..7db6508aa4a3c 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -216,8 +216,8 @@ struct ProtoDesc { return Resp{req.id, 0, std::move(result)}; } static inline Id KeyFor(Resp r) { return r.id; } - static inline llvm::StringRef KeyFor(Req r) { return r.name; } - static inline llvm::StringRef KeyFor(Evt e) { return e.name; } + static inline std::string KeyFor(Req r) { return r.name; } + static inline std::string KeyFor(Evt e) { return e.name; } static inline std::optional Extract(Req r) { return r.params; } From b316f473f15d38b02e7d6698f1f1b2fb924e90f0 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Tue, 23 Sep 2025 14:13:33 -0700 Subject: [PATCH 8/8] Fixing a c++17 compatibility. --- .../TestingSupport/Host/JSONTransportTestUtilities.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h index ca3e95911d863..bacf8ca36aa07 100644 --- a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -46,7 +46,7 @@ class TestTransport final m_dummy_file = std::move(*dummy_file); } - llvm::Error Send(const Proto::Evt &evt) override { + llvm::Error Send(const typename Proto::Evt &evt) override { EXPECT_TRUE(m_loop && m_handler) << "Send called before RegisterMessageHandler"; m_loop->AddPendingCallback([this, evt](lldb_private::MainLoopBase &) { @@ -55,7 +55,7 @@ class TestTransport final return llvm::Error::success(); } - llvm::Error Send(const Proto::Req &req) override { + llvm::Error Send(const typename Proto::Req &req) override { EXPECT_TRUE(m_loop && m_handler) << "Send called before RegisterMessageHandler"; m_loop->AddPendingCallback([this, req](lldb_private::MainLoopBase &) { @@ -64,7 +64,7 @@ class TestTransport final return llvm::Error::success(); } - llvm::Error Send(const Proto::Resp &resp) override { + llvm::Error Send(const typename Proto::Resp &resp) override { EXPECT_TRUE(m_loop && m_handler) << "Send called before RegisterMessageHandler"; m_loop->AddPendingCallback([this, resp](lldb_private::MainLoopBase &) {