diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index c73021d204258..1453316e96fb4 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -18,6 +18,7 @@ #include "lldb/Utility/IOObject.h" #include "lldb/Utility/Status.h" #include "lldb/lldb-forward.h" +#include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" @@ -25,13 +26,23 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" +#include +#include +#include #include #include +#include +#include #include #include +#if __cplusplus >= 202002L +#include +#endif -namespace lldb_private { +namespace lldb_private::transport { +/// An error to indicate that the transport reached EOF but there were still +/// unhandled contents in the read buffer. class TransportUnhandledContentsError : public llvm::ErrorInfo { public: @@ -50,17 +61,75 @@ class TransportUnhandledContentsError std::string m_unhandled_contents; }; +/// An error to indicate that the parameters of a Req, Resp or Evt could not be +/// deserialized. +class InvalidParams : public llvm::ErrorInfo { +public: + static char ID; + + explicit InvalidParams(std::string method, std::string context) + : m_method(std::move(method)), m_context(std::move(context)) {} + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + +private: + /// The JSONRPC remote method call. + std::string m_method; + + /// Additional context from the parsing failure, e.g. "missing value at + /// (root)[1].str". + std::string m_context; +}; + +/// An error to indicate that no handler was registered for a given method. +class MethodNotFound : public llvm::ErrorInfo { +public: + static char ID; + + static constexpr int kErrorCode = -32601; + + explicit MethodNotFound(std::string method) : m_method(std::move(method)) {} + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + +private: + std::string m_method; +}; + +#if __cplusplus >= 202002L +/// A ProtocolDescriptor details the types used in a JSONTransport for handling +/// transport communication. +template +concept ProtocolDescriptor = requires { + typename T::Id; + typename T::Req; + typename T::Resp; + typename T::Evt; +}; +#endif + /// A transport is responsible for maintaining the connection to a client /// application, and reading/writing structured messages to it. /// -/// Transports have limited thread safety requirements: +/// JSONTransport have limited thread safety requirements: /// - Messages will not be sent concurrently. /// - Messages MAY be sent while Run() is reading, or its callback is active. -template class Transport { +/// +#if __cplusplus >= 202002L +template +#else +template +#endif +class JSONTransport { public: + using Req = typename Proto::Req; + using Resp = typename Proto::Resp; + using Evt = typename Proto::Evt; using Message = std::variant; - virtual ~Transport() = default; + virtual ~JSONTransport() = default; /// Sends an event, a message that does not require a response. virtual llvm::Error Send(const Evt &) = 0; @@ -69,7 +138,8 @@ template class Transport { /// Sends a response to a specific request. virtual llvm::Error Send(const Resp &) = 0; - /// Implemented to handle incoming messages. (See Run() below). + /// Implemented to handle incoming messages. (See `RegisterMessageHandler()` + /// below). class MessageHandler { public: virtual ~MessageHandler() = default; @@ -90,8 +160,6 @@ template class Transport { virtual void OnClosed() = 0; }; - using MessageHandlerSP = std::shared_ptr; - /// RegisterMessageHandler registers the Transport with the given MainLoop and /// handles any incoming messages using the given MessageHandler. /// @@ -108,18 +176,23 @@ template class Transport { }; /// An IOTransport sends and receives messages using an IOObject. -template -class IOTransport : public Transport { +template class IOTransport : public JSONTransport { public: - using Transport::Transport; - using MessageHandler = typename Transport::MessageHandler; + using Message = typename JSONTransport::Message; + using MessageHandler = typename JSONTransport::MessageHandler; IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} - llvm::Error Send(const Evt &evt) override { return Write(evt); } - llvm::Error Send(const Req &req) override { return Write(req); } - llvm::Error Send(const Resp &resp) override { return Write(resp); } + llvm::Error Send(const typename Proto::Evt &evt) override { + return Write(evt); + } + llvm::Error Send(const typename Proto::Req &req) override { + return Write(req); + } + llvm::Error Send(const typename Proto::Resp &resp) override { + return Write(resp); + } llvm::Expected RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) override { @@ -139,7 +212,7 @@ class IOTransport : public Transport { /// detail. static constexpr size_t kReadBufferSize = 1024; - // FIXME: Write should be protected. +protected: llvm::Error Write(const llvm::json::Value &message) { this->Logv("<-- {0}", message); std::string output = Encode(message); @@ -147,7 +220,6 @@ class IOTransport : public Transport { return m_out->Write(output.data(), bytes_written).takeError(); } -protected: virtual llvm::Expected> Parse() = 0; virtual std::string Encode(const llvm::json::Value &message) = 0; @@ -174,9 +246,8 @@ class IOTransport : public Transport { } for (const std::string &raw_message : *raw_messages) { - llvm::Expected::Message> message = - llvm::json::parse::Message>( - raw_message); + llvm::Expected message = + llvm::json::parse(raw_message); if (!message) { handler.OnError(message.takeError()); return; @@ -201,10 +272,14 @@ class IOTransport : public Transport { }; /// A transport class for JSON with a HTTP header. -template -class HTTPDelimitedJSONTransport : public IOTransport { +#if __cplusplus >= 202002L +template +#else +template +#endif +class HTTPDelimitedJSONTransport : public IOTransport { public: - using IOTransport::IOTransport; + using IOTransport::IOTransport; protected: /// Encodes messages based on @@ -230,8 +305,8 @@ class HTTPDelimitedJSONTransport : public IOTransport { for (const llvm::StringRef &header : llvm::split(headers, kHeaderSeparator)) { auto [key, value] = header.split(kHeaderFieldSeparator); - // 'Content-Length' is the only meaningful key at the moment. Others are - // ignored. + // 'Content-Length' is the only meaningful key at the moment. Others + // are ignored. if (!key.equals_insensitive(kHeaderContentLength)) continue; @@ -268,10 +343,14 @@ class HTTPDelimitedJSONTransport : public IOTransport { }; /// A transport class for JSON RPC. -template -class JSONRPCTransport : public IOTransport { +#if __cplusplus >= 202002L +template +#else +template +#endif +class JSONRPCTransport : public IOTransport { public: - using IOTransport::IOTransport; + using IOTransport::IOTransport; protected: std::string Encode(const llvm::json::Value &message) override { @@ -297,6 +376,497 @@ class JSONRPCTransport : public IOTransport { static constexpr llvm::StringLiteral kMessageSeparator = "\n"; }; -} // namespace lldb_private +/// A handler for the response to an outgoing request. +template +using Reply = + std::conditional_t, + llvm::unique_function, + llvm::unique_function)>>; + +namespace detail { +template struct request_t final { + using type = llvm::unique_function)>; +}; +template struct request_t final { + using type = llvm::unique_function)>; +}; +template struct event_t final { + using type = llvm::unique_function; +}; +template <> struct event_t final { + using type = llvm::unique_function; +}; +} // namespace detail + +template +using OutgoingRequest = typename detail::request_t::type; + +/// A function to send an outgoing event. +template using OutgoingEvent = typename detail::event_t

::type; + +#if __cplusplus >= 202002L +/// This represents a protocol description that includes additional helpers +/// for constructing requests, responses and events to work with `Binder`. +template +concept BindingBuilder = + ProtocolDescriptor && + requires(T::Id id, T::Req req, T::Resp resp, T::Evt evt, + llvm::StringRef method, std::optional params, + std::optional result, llvm::Error err) { + /// For initializing the unique sequence identifier; + { T::InitialId() } -> std::same_as; + /// Incrementing the sequence identifier. + { id++ } -> std::same_as; + + /// Constructing protocol types + /// @{ + /// Construct a new request. + { T::Make(id, method, params) } -> std::same_as; + /// Construct a new error response. + { T::Make(req, std::move(err)) } -> std::same_as; + /// Construct a new success response. + { T::Make(req, result) } -> std::same_as; + /// Construct a new event. + { T::Make(method, params) } -> std::same_as; + /// @} + + /// Keys for associated types. + /// @{ + /// Looking up in flight responses. + { T::KeyFor(resp) } -> std::same_as; + /// Extract method from request. + { T::KeyFor(req) } -> std::same_as; + /// Extract method from event. + { T::KeyFor(evt) } -> std::same_as; + /// @} + + /// Extracting information from associated types. + /// @{ + /// Extract parameters from a request. + { T::Extract(req) } -> std::same_as>; + /// Extract result from a response. + { T::Extract(resp) } -> std::same_as>; + /// Extract parameters from an event. + { T::Extract(evt) } -> std::same_as>; + /// @} + }; +#endif + +/// Binder collects a table of functions that handle calls. +/// +/// The wrapper takes care of parsing/serializing responses. +/// +/// This allows a JSONTransport to handle incoming and outgoing requests and +/// events. +/// +/// A bind of an incoming request to a lambda. +/// \code{cpp} +/// Binder binder{transport}; +/// binder.bind>("adder", [](const vector ¶ms) { +/// int sum = 0; +/// for (int v : params) +/// sum += v; +/// return sum; +/// }); +/// \endcode +/// +/// A bind of an outgoing request. +/// \code{cpp} +/// OutgoingRequest> call_add = +/// binder.bind>("add"); +/// call_add({1,2,3}, [](Expected result) { +/// cout << *result << "\n"; +/// }); +/// \endcode +#if __cplusplus >= 202002L +template +#else +template +#endif +class Binder : public JSONTransport::MessageHandler { + using Req = typename Proto::Req; + using Resp = typename Proto::Resp; + using Evt = typename Proto::Evt; + using Id = typename Proto::Id; + using Transport = JSONTransport; + using MessageHandler = typename Transport::MessageHandler; + +public: + explicit Binder(Transport &transport) : m_transport(transport), m_seq(0) {} + + Binder(const Binder &) = delete; + Binder &operator=(const Binder &) = delete; + + /// Bind a handler on transport disconnect. + template + void OnDisconnect(Fn &&fn, Args &&...args); + + /// Bind a handler on error when communicating with the transport. + template + void OnError(Fn &&fn, Args &&...args); + + /// Bind a handler for an incoming request. + /// e.g. `bind("peek", &ThisModule::peek, this);`. + /// Handler should be e.g. `Expected peek(const PeekParams&);` + /// PeekParams must be JSON parsable and PeekResult must be serializable. + template + void Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args); + + /// Bind a handler for an incoming event. + /// e.g. `bind("peek", &ThisModule::peek, this);` + /// Handler should be e.g. `void peek(const PeekParams&);` + /// PeekParams must be JSON parsable. + template + void Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args); + + /// Bind a function object to be used for outgoing requests. + /// e.g. `OutgoingRequest Edit = bind("edit");` + /// Params must be JSON-serializable, Result must be parsable. + template + OutgoingRequest Bind(llvm::StringLiteral method); + + /// Bind a function object to be used for outgoing events. + /// e.g. `OutgoingEvent Log = bind("log");` + /// LogParams must be JSON-serializable. + template + OutgoingEvent Bind(llvm::StringLiteral method); + + void Received(const Evt &evt) override { + std::scoped_lock guard(m_mutex); + auto it = m_event_handlers.find(Proto::KeyFor(evt)); + if (it == m_event_handlers.end()) { + OnError(llvm::createStringError( + llvm::formatv("no handled for event {0}", toJSON(evt)))); + return; + } + it->second(evt); + } + + void Received(const Req &req) override { + ReplyOnce reply(req, &m_transport, this); + + std::scoped_lock guard(m_mutex); + auto it = m_request_handlers.find(Proto::KeyFor(req)); + if (it == m_request_handlers.end()) { + reply(Proto::Make(req, llvm::createStringError("method not found"))); + return; + } + + it->second(req, std::move(reply)); + } + + void Received(const Resp &resp) override { + std::scoped_lock guard(m_mutex); + + Id id = Proto::KeyFor(resp); + auto it = m_pending_responses.find(id); + if (it == m_pending_responses.end()) { + OnError(llvm::createStringError( + llvm::formatv("no pending request for {0}", toJSON(resp)))); + return; + } + + it->second(resp); + m_pending_responses.erase(it); + } + + void OnError(llvm::Error err) override { + std::scoped_lock guard(m_mutex); + if (m_error_handler) + m_error_handler(std::move(err)); + } + + void OnClosed() override { + std::scoped_lock guard(m_mutex); + if (m_disconnect_handler) + m_disconnect_handler(); + } + +private: + template + llvm::Expected static Parse(const llvm::json::Value &raw, + llvm::StringRef method); + + template using Callback = llvm::unique_function; + + std::recursive_mutex m_mutex; + Transport &m_transport; + Id m_seq; + std::map> m_pending_responses; + llvm::StringMap)>> + m_request_handlers; + llvm::StringMap> m_event_handlers; + Callback m_disconnect_handler; + Callback m_error_handler; + + /// Function object to reply to a call. + /// Each instance must be called exactly once, otherwise: + /// - the bug is logged, and (in debug mode) an assert will fire + /// - if there was no reply, an error reply is sent + /// - if there were multiple replies, only the first is sent + class ReplyOnce { + std::atomic replied = {false}; + const Req req; + Transport *transport; // Null when moved-from. + MessageHandler *handler; // Null when moved-from. + + public: + ReplyOnce(const Req req, Transport *transport, MessageHandler *handler) + : req(req), transport(transport), handler(handler) { + assert(handler); + } + ReplyOnce(ReplyOnce &&other) + : replied(other.replied.load()), req(other.req), + transport(other.transport), handler(other.handler) { + other.transport = nullptr; + other.handler = nullptr; + } + ReplyOnce &operator=(ReplyOnce &&) = delete; + ReplyOnce(const ReplyOnce &) = delete; + ReplyOnce &operator=(const ReplyOnce &) = delete; + + ~ReplyOnce() { + if (transport && handler && !replied) { + assert(false && "must reply to all calls!"); + (*this)(Proto::Make(req, llvm::createStringError("failed to reply"))); + } + } + + void operator()(const Resp &resp) { + assert(transport && handler && "moved-from!"); + if (replied.exchange(true)) { + assert(false && "must reply to each call only once!"); + return; + } + + if (llvm::Error error = transport->Send(resp)) + handler->OnError(std::move(error)); + } + }; +}; + +#if __cplusplus >= 202002L +template +#else +template +#endif +template +void Binder::OnDisconnect(Fn &&fn, Args &&...args) { + m_disconnect_handler = [fn, args...]() mutable { + std::invoke(std::forward(fn), std::forward(args)...); + }; +} + +#if __cplusplus >= 202002L +template +#else +template +#endif +template +void Binder::OnError(Fn &&fn, Args &&...args) { + m_error_handler = [fn, args...](llvm::Error error) mutable { + std::invoke(std::forward(fn), std::forward(args)..., + std::move(error)); + }; +} + +#if __cplusplus >= 202002L +template +#else +template +#endif +template +void Binder::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { + assert(m_request_handlers.find(method) == m_request_handlers.end() && + "request already bound"); + if constexpr (std::is_void_v && std::is_void_v) { + m_request_handlers[method] = + [fn, args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Error result = + std::invoke(std::forward(fn), std::forward(args)...); + reply(Proto::Make(req, std::move(result))); + }; + } else if constexpr (std::is_void_v) { + m_request_handlers[method] = + [fn, args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Expected result = + std::invoke(std::forward(fn), std::forward(args)...); + if (!result) + return reply(Proto::Make(req, result.takeError())); + reply(Proto::Make(req, toJSON(*result))); + }; + } else if constexpr (std::is_void_v) { + m_request_handlers[method] = + [method, fn, + args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Expected params = + Parse(Proto::Extract(req), method); + if (!params) + return reply(Proto::Make(req, params.takeError())); + + llvm::Error result = std::invoke( + std::forward(fn), std::forward(args)..., *params); + reply(Proto::Make(req, std::move(result))); + }; + } else { + m_request_handlers[method] = + [method, fn, + args...](const Req &req, + llvm::unique_function reply) mutable { + llvm::Expected params = + Parse(Proto::Extract(req), method); + if (!params) + return reply(Proto::Make(req, params.takeError())); + + llvm::Expected result = std::invoke( + std::forward(fn), std::forward(args)..., *params); + if (!result) + return reply(Proto::Make(req, result.takeError())); + + reply(Proto::Make(req, toJSON(*result))); + }; + } +} + +#if __cplusplus >= 202002L +template +#else +template +#endif +template +void Binder::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { + assert(m_event_handlers.find(method) == m_event_handlers.end() && + "event already bound"); + if constexpr (std::is_void_v) { + m_event_handlers[method] = [fn, args...](const Evt &) mutable { + std::invoke(std::forward(fn), std::forward(args)...); + }; + } else { + m_event_handlers[method] = [this, method, fn, + args...](const Evt &evt) mutable { + llvm::Expected params = + Parse(Proto::Extract(evt), method); + if (!params) + return OnError(params.takeError()); + std::invoke(std::forward(fn), std::forward(args)..., *params); + }; + } +} + +#if __cplusplus >= 202002L +template +#else +template +#endif +template +OutgoingRequest +Binder::Bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v && std::is_void_v) { + return [this, method](Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = Proto::Make(id, method, std::nullopt); + m_pending_responses[id] = [fn = std::move(fn)](const Resp &resp) mutable { + llvm::Expected result = Proto::Extract(resp); + if (!result) + return fn(result.takeError()); + fn(llvm::Error::success()); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } else if constexpr (std::is_void_v) { + return [this, method](Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = Proto::Make(id, method, std::nullopt); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected result = Proto::Extract(resp); + if (!result) + return fn(result.takeError()); + fn(Parse(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } else if constexpr (std::is_void_v) { + return [this, method](const Params ¶ms, Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = Proto::Make(id, method, llvm::json::Value(params)); + m_pending_responses[id] = [fn = std::move(fn)](const Resp &resp) mutable { + llvm::Expected result = Proto::Extract(resp); + if (!result) + return fn(result.takeError()); + fn(llvm::Error::success()); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params ¶ms, Reply fn) { + std::scoped_lock guard(m_mutex); + Id id = ++m_seq; + Req req = Proto::Make(id, method, llvm::json::Value(params)); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected result = Proto::Extract(resp); + if (llvm::Error err = result.takeError()) + return fn(std::move(err)); + fn(Parse(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } +} + +#if __cplusplus >= 202002L +template +#else +template +#endif +template +OutgoingEvent Binder::Bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v) { + return [this, method]() { + if (llvm::Error error = + m_transport.Send(Proto::Make(method, std::nullopt))) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params ¶ms) { + if (llvm::Error error = + m_transport.Send(Proto::Make(method, toJSON(params)))) + OnError(std::move(error)); + }; + } +} + +#if __cplusplus >= 202002L +template +#else +template +#endif +template +llvm::Expected Binder::Parse(const llvm::json::Value &raw, + llvm::StringRef method) { + T result; + llvm::json::Path::Root root; + if (!fromJSON(raw, result, root)) { + // Dump the relevant parts of the broken message. + std::string context; + llvm::raw_string_ostream OS(context); + root.printErrorContext(raw, OS); + return llvm::make_error(method.str(), context); + } + return std::move(result); +} + +} // namespace lldb_private::transport #endif diff --git a/lldb/include/lldb/Protocol/MCP/MCPError.h b/lldb/include/lldb/Protocol/MCP/MCPError.h index 55dd40f124a15..609a1733197d4 100644 --- a/lldb/include/lldb/Protocol/MCP/MCPError.h +++ b/lldb/include/lldb/Protocol/MCP/MCPError.h @@ -9,7 +9,6 @@ #ifndef LLDB_PROTOCOL_MCP_MCPERROR_H #define LLDB_PROTOCOL_MCP_MCPERROR_H -#include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Support/Error.h" #include @@ -26,14 +25,12 @@ class MCPError : public llvm::ErrorInfo { const std::string &getMessage() const { return m_message; } - lldb_protocol::mcp::Error toProtocolError() const; - static constexpr int64_t kResourceNotFound = -32002; static constexpr int64_t kInternalError = -32603; private: std::string m_message; - int64_t m_error_code; + int m_error_code; }; class UnsupportedURI : public llvm::ErrorInfo { diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 6e1ffcbe1f3e3..a0ba8659ffe24 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -14,6 +14,7 @@ #ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H #define LLDB_PROTOCOL_MCP_PROTOCOL_H +#include "llvm/ADT/StringRef.h" #include "llvm/Support/JSON.h" #include #include @@ -322,6 +323,10 @@ struct CallToolResult { llvm::json::Value toJSON(const CallToolResult &); bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path); +lldb_protocol::mcp::Request +MakeRequest(int64_t id, llvm::StringRef method, + std::optional params); + } // namespace lldb_protocol::mcp #endif diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 970980d075ea6..f185d51f41192 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -9,7 +9,6 @@ #ifndef LLDB_PROTOCOL_MCP_SERVER_H #define LLDB_PROTOCOL_MCP_SERVER_H -#include "lldb/Host/JSONTransport.h" #include "lldb/Host/MainLoop.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" @@ -19,75 +18,66 @@ #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" #include "llvm/Support/Signals.h" -#include #include #include #include namespace lldb_protocol::mcp { -class Server : public MCPTransport::MessageHandler { - using ClosedCallback = llvm::unique_function; +class Server { + + using MCPTransportUP = std::unique_ptr; + + using ReadHandleUP = lldb_private::MainLoop::ReadHandleUP; public: - Server(std::string name, std::string version, MCPTransport &client, - LogCallback log_callback = {}, ClosedCallback closed_callback = {}); + Server(std::string name, std::string version, LogCallback log_callback = {}); ~Server() = default; - using NotificationHandler = std::function; - void AddTool(std::unique_ptr tool); void AddResourceProvider(std::unique_ptr resource_provider); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); - -protected: - ServerCapabilities GetCapabilities(); - - using RequestHandler = - std::function(const Request &)>; - void AddRequestHandlers(); + llvm::Error Accept(lldb_private::MainLoop &, MCPTransportUP); - void AddRequestHandler(llvm::StringRef method, RequestHandler handler); - - llvm::Expected> HandleData(llvm::StringRef data); - - llvm::Expected Handle(const Request &request); - void Handle(const Notification ¬ification); +protected: + MCPBinderUP Bind(MCPTransport &); - llvm::Expected InitializeHandler(const Request &); + ServerCapabilities GetCapabilities(); - llvm::Expected ToolsListHandler(const Request &); - llvm::Expected ToolsCallHandler(const Request &); + llvm::Expected InitializeHandler(const InitializeParams &); - llvm::Expected ResourcesListHandler(const Request &); - llvm::Expected ResourcesReadHandler(const Request &); + llvm::Expected ToolsListHandler(); + llvm::Expected ToolsCallHandler(const CallToolParams &); - void Received(const Request &) override; - void Received(const Response &) override; - void Received(const Notification &) override; - void OnError(llvm::Error) override; - void OnClosed() override; + llvm::Expected ResourcesListHandler(); + llvm::Expected + ResourcesReadHandler(const ReadResourceParams &); -protected: - void Log(llvm::StringRef); + template inline auto Logv(const char *Fmt, Ts &&...Vals) { + Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); + } + void Log(llvm::StringRef message) { + if (m_log_callback) + m_log_callback(message); + } private: const std::string m_name; const std::string m_version; - MCPTransport &m_client; LogCallback m_log_callback; - ClosedCallback m_closed_callback; + struct Client { + ReadHandleUP handle; + MCPTransportUP transport; + MCPBinderUP binder; + }; + std::map m_instances; llvm::StringMap> m_tools; std::vector> m_resource_providers; - - llvm::StringMap m_request_handlers; - llvm::StringMap m_notification_handlers; }; class ServerInfoHandle; @@ -121,7 +111,7 @@ class ServerInfoHandle { ServerInfoHandle &operator=(const ServerInfoHandle &) = delete; /// @} - /// Remove the file. + /// Remove the file on disk, if one is tracked. void Remove(); private: diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h index 47c2ccfc44dfe..b7a1eb778d660 100644 --- a/lldb/include/lldb/Protocol/MCP/Transport.h +++ b/lldb/include/lldb/Protocol/MCP/Transport.h @@ -10,22 +10,78 @@ #define LLDB_PROTOCOL_MCP_TRANSPORT_H #include "lldb/Host/JSONTransport.h" +#include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include namespace lldb_protocol::mcp { +struct ProtocolDescriptor { + using Id = int64_t; + using Req = Request; + using Resp = Response; + using Evt = Notification; + + static inline Id InitialId() { return 0; } + static inline Request Make(Id id, llvm::StringRef method, + std::optional params) { + return Request{id, method.str(), params}; + } + static inline Notification Make(llvm::StringRef method, + std::optional params) { + return Notification{method.str(), params}; + } + static inline Response Make(Req req, llvm::Error error) { + lldb_protocol::mcp::Error protocol_error; + llvm::handleAllErrors( + std::move(error), [&](const llvm::ErrorInfoBase &err) { + std::error_code cerr = err.convertToErrorCode(); + protocol_error.code = + cerr == llvm::inconvertibleErrorCode() + ? lldb_protocol::mcp::eErrorCodeInternalError + : cerr.value(); + protocol_error.message = err.message(); + }); + + return Response{req.id, std::move(protocol_error)}; + } + static inline Response Make(Req req, + std::optional result) { + return Response{req.id, std::move(result)}; + } + static inline Id KeyFor(Response r) { return std::get(r.id); } + static inline std::string KeyFor(Request r) { return r.method; } + static inline std::string KeyFor(Notification n) { return n.method; } + static inline std::optional Extract(Request r) { + return r.params; + } + static inline llvm::Expected Extract(Response r) { + if (const lldb_protocol::mcp::Error *error = + std::get_if(&r.result)) + return llvm::make_error(error->message, + error->code); + return std::get(r.result); + } + static inline std::optional Extract(Notification n) { + return n.params; + } +}; + /// Generic transport that uses the MCP protocol. -using MCPTransport = lldb_private::Transport; +using MCPTransport = lldb_private::transport::JSONTransport; +using MCPBinder = lldb_private::transport::Binder; +using MCPBinderUP = std::unique_ptr; /// Generic logging callback, to allow the MCP server / client / transport layer /// to be independent of the lldb log implementation. using LogCallback = llvm::unique_function; class Transport final - : public lldb_private::JSONRPCTransport { + : public lldb_private::transport::JSONRPCTransport { public: Transport(lldb::IOObjectSP in, lldb::IOObjectSP out, LogCallback log_callback = {}); diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index c4b42eafc85d3..22de7fa8cbead 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -14,8 +14,7 @@ #include using namespace llvm; -using namespace lldb; -using namespace lldb_private; +using namespace lldb_private::transport; char TransportUnhandledContentsError::ID; @@ -23,10 +22,31 @@ TransportUnhandledContentsError::TransportUnhandledContentsError( std::string unhandled_contents) : m_unhandled_contents(unhandled_contents) {} -void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const { +void TransportUnhandledContentsError::log(raw_ostream &OS) const { OS << "transport EOF with unhandled contents: '" << m_unhandled_contents << "'"; } std::error_code TransportUnhandledContentsError::convertToErrorCode() const { return std::make_error_code(std::errc::bad_message); } + +char InvalidParams::ID; + +void InvalidParams::log(raw_ostream &OS) const { + OS << "invalid parameters for method '" << m_method << "': '" << m_context + << "'"; +} +std::error_code InvalidParams::convertToErrorCode() const { + return std::make_error_code(std::errc::invalid_argument); +} + +char MethodNotFound::ID; + +void MethodNotFound::log(raw_ostream &OS) const { + OS << "method not found: '" << m_method << "'"; +} + +std::error_code MethodNotFound::convertToErrorCode() const { + // JSON-RPC Method not found + return std::error_code(MethodNotFound::kErrorCode, std::generic_category()); +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index d7293fc28c524..33bdd5eec3644 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -52,11 +52,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { } void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { - server.AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); server.AddTool( std::make_unique("command", "Run an lldb command.")); server.AddTool(std::make_unique( @@ -74,26 +69,9 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { io_sp, io_sp, [client_name](llvm::StringRef message) { LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message); }); - MCPTransport *transport_ptr = transport_up.get(); - auto instance_up = std::make_unique( - std::string(kName), std::string(kVersion), *transport_up, - /*log_callback=*/ - [client_name](llvm::StringRef message) { - LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name, - message); - }, - /*closed_callback=*/ - [this, transport_ptr]() { m_instances.erase(transport_ptr); }); - Extend(*instance_up); - llvm::Expected handle = - transport_up->RegisterMessageHandler(m_loop, *instance_up); - if (!handle) { - LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}"); - return; - } - m_instances[transport_ptr] = - std::make_tuple( - std::move(instance_up), std::move(*handle), std::move(transport_up)); + + if (auto error = m_server->Accept(m_loop, std::move(transport_up))) + LLDB_LOG_ERROR(log, std::move(error), "{0}:"); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { @@ -124,14 +102,21 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { llvm::join(m_listener->GetListeningConnectionURI(), ", "); ServerInfo info{listening_uris[0]}; - llvm::Expected handle = ServerInfo::Write(info); - if (!handle) - return handle.takeError(); + llvm::Expected server_info_handle = ServerInfo::Write(info); + if (!server_info_handle) + return server_info_handle.takeError(); + + m_client_count = 0; + m_server = std::make_unique( + std::string(kName), std::string(kVersion), [](StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "MCP Server: {0}", message); + }); + Extend(*m_server); m_running = true; - m_server_info_handle = std::move(*handle); - m_listen_handlers = std::move(*handles); - m_loop_thread = std::thread([=] { + m_server_info_handle = std::move(*server_info_handle); + m_accept_handles = std::move(*handles); + m_loop_thread = std::thread([this] { llvm::set_thread_name("protocol-server.mcp"); m_loop.Run(); }); @@ -155,9 +140,10 @@ llvm::Error ProtocolServerMCP::Stop() { if (m_loop_thread.joinable()) m_loop_thread.join(); + m_accept_handles.clear(); + + m_server.reset(nullptr); m_server_info_handle.Remove(); - m_listen_handlers.clear(); - m_instances.clear(); return llvm::Error::success(); } diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index b325a3681bccb..e0f2a6ccea1f5 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -23,16 +23,17 @@ namespace lldb_private::mcp { class ProtocolServerMCP : public ProtocolServer { - using ReadHandleUP = MainLoopBase::ReadHandleUP; - using TransportUP = std::unique_ptr; + using ServerUP = std::unique_ptr; + using ReadHandleUP = MainLoop::ReadHandleUP; + public: ProtocolServerMCP(); - virtual ~ProtocolServerMCP() override; + ~ProtocolServerMCP() override; - virtual llvm::Error Start(ProtocolServer::Connection connection) override; - virtual llvm::Error Stop() override; + llvm::Error Start(ProtocolServer::Connection connection) override; + llvm::Error Stop() override; static void Initialize(); static void Terminate(); @@ -56,19 +57,18 @@ class ProtocolServerMCP : public ProtocolServer { bool m_running = false; - lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; lldb_private::MainLoop m_loop; std::thread m_loop_thread; std::mutex m_mutex; size_t m_client_count = 0; std::unique_ptr m_listener; + std::vector m_accept_handles; - std::vector m_listen_handlers; - std::map> - m_instances; + ServerUP m_server; + lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; }; + } // namespace lldb_private::mcp #endif diff --git a/lldb/source/Protocol/MCP/MCPError.cpp b/lldb/source/Protocol/MCP/MCPError.cpp index e140d11e12cfe..cfac055ba5f11 100644 --- a/lldb/source/Protocol/MCP/MCPError.cpp +++ b/lldb/source/Protocol/MCP/MCPError.cpp @@ -22,14 +22,7 @@ MCPError::MCPError(std::string message, int64_t error_code) void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; } std::error_code MCPError::convertToErrorCode() const { - return llvm::inconvertibleErrorCode(); -} - -lldb_protocol::mcp::Error MCPError::toProtocolError() const { - lldb_protocol::mcp::Error error; - error.code = m_error_code; - error.message = m_message; - return error; + return std::error_code(m_error_code, std::generic_category()); } UnsupportedURI::UnsupportedURI(std::string uri) : m_uri(uri) {} diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 19030a3a4e5d6..71323adbac5f5 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -12,6 +12,7 @@ #include "lldb/Host/HostInfo.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Transport.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" @@ -108,48 +109,9 @@ Expected> ServerInfo::Load() { return infos; } -Server::Server(std::string name, std::string version, MCPTransport &client, - LogCallback log_callback, ClosedCallback closed_callback) - : m_name(std::move(name)), m_version(std::move(version)), m_client(client), - m_log_callback(std::move(log_callback)), - m_closed_callback(std::move(closed_callback)) { - AddRequestHandlers(); -} - -void Server::AddRequestHandlers() { - AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this, - std::placeholders::_1)); - AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler, - this, std::placeholders::_1)); - AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler, - this, std::placeholders::_1)); -} - -llvm::Expected Server::Handle(const Request &request) { - auto it = m_request_handlers.find(request.method); - if (it != m_request_handlers.end()) { - llvm::Expected response = it->second(request); - if (!response) - return response; - response->id = request.id; - return *response; - } - - return llvm::make_error( - llvm::formatv("no handler for request: {0}", request.method).str()); -} - -void Server::Handle(const Notification ¬ification) { - auto it = m_notification_handlers.find(notification.method); - if (it != m_notification_handlers.end()) { - it->second(notification); - return; - } -} +Server::Server(std::string name, std::string version, LogCallback log_callback) + : m_name(std::move(name)), m_version(std::move(version)), + m_log_callback(std::move(log_callback)) {} void Server::AddTool(std::unique_ptr tool) { if (!tool) @@ -164,48 +126,64 @@ void Server::AddResourceProvider( m_resource_providers.push_back(std::move(resource_provider)); } -void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { - m_request_handlers[method] = std::move(handler); -} - -void Server::AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler) { - m_notification_handlers[method] = std::move(handler); -} - -llvm::Expected Server::InitializeHandler(const Request &request) { - Response response; +MCPBinderUP Server::Bind(MCPTransport &transport) { + MCPBinderUP binder_up = std::make_unique(transport); + binder_up->Bind( + "initialize", &Server::InitializeHandler, this); + binder_up->Bind("tools/list", + &Server::ToolsListHandler, this); + binder_up->Bind( + "tools/call", &Server::ToolsCallHandler, this); + binder_up->Bind( + "resources/list", &Server::ResourcesListHandler, this); + binder_up->Bind( + "resources/read", &Server::ResourcesReadHandler, this); + binder_up->Bind("notifications/initialized", + [this]() { Log("MCP initialization complete"); }); + return binder_up; +} + +llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { + MCPBinderUP binder = Bind(*transport); + MCPTransport *transport_ptr = transport.get(); + binder->OnDisconnect([this, transport_ptr]() { + assert(m_instances.find(transport_ptr) != m_instances.end() && + "Client not found in m_instances"); + m_instances.erase(transport_ptr); + }); + binder->OnError([this](llvm::Error err) { + Logv("Transport error: {0}", llvm::toString(std::move(err))); + }); + + auto handle = transport->RegisterMessageHandler(loop, *binder); + if (!handle) + return handle.takeError(); + + m_instances[transport_ptr] = + Client{std::move(*handle), std::move(transport), std::move(binder)}; + return llvm::Error::success(); +} + +Expected +Server::InitializeHandler(const InitializeParams &request) { InitializeResult result; result.protocolVersion = mcp::kProtocolVersion; result.capabilities = GetCapabilities(); result.serverInfo.name = m_name; result.serverInfo.version = m_version; - response.result = std::move(result); - return response; + return result; } -llvm::Expected Server::ToolsListHandler(const Request &request) { - Response response; - +llvm::Expected Server::ToolsListHandler() { ListToolsResult result; for (const auto &tool : m_tools) result.tools.emplace_back(tool.second->GetDefinition()); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected Server::ToolsCallHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no tool parameters"); - CallToolParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - +llvm::Expected +Server::ToolsCallHandler(const CallToolParams ¶ms) { llvm::StringRef tool_name = params.name; if (tool_name.empty()) return llvm::createStringError("no tool name"); @@ -222,113 +200,50 @@ llvm::Expected Server::ToolsCallHandler(const Request &request) { if (!text_result) return text_result.takeError(); - response.result = toJSON(*text_result); - - return response; + return text_result; } -llvm::Expected Server::ResourcesListHandler(const Request &request) { - Response response; - +llvm::Expected Server::ResourcesListHandler() { ListResourcesResult result; for (std::unique_ptr &resource_provider_up : m_resource_providers) for (const Resource &resource : resource_provider_up->GetResources()) result.resources.push_back(resource); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected Server::ResourcesReadHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no resource parameters"); - - ReadResourceParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - - llvm::StringRef uri_str = params.uri; +Expected +Server::ResourcesReadHandler(const ReadResourceParams ¶ms) { + StringRef uri_str = params.uri; if (uri_str.empty()) - return llvm::createStringError("no resource uri"); + return createStringError("no resource uri"); for (std::unique_ptr &resource_provider_up : m_resource_providers) { - llvm::Expected result = + Expected result = resource_provider_up->ReadResource(uri_str); if (result.errorIsA()) { - llvm::consumeError(result.takeError()); + consumeError(result.takeError()); continue; } if (!result) return result.takeError(); - Response response; - response.result = std::move(*result); - return response; + return *result; } return make_error( - llvm::formatv("no resource handler for uri: {0}", uri_str).str(), + formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } ServerCapabilities Server::GetCapabilities() { lldb_protocol::mcp::ServerCapabilities capabilities; capabilities.supportsToolsList = true; + capabilities.supportsResourcesList = true; // FIXME: Support sending notifications when a debugger/target are // added/removed. - capabilities.supportsResourcesList = false; + capabilities.supportsResourcesSubscribe = false; return capabilities; } - -void Server::Log(llvm::StringRef message) { - if (m_log_callback) - m_log_callback(message); -} - -void Server::Received(const Request &request) { - auto SendResponse = [this](const Response &response) { - if (llvm::Error error = m_client.Send(response)) - Log(llvm::toString(std::move(error))); - }; - - llvm::Expected response = Handle(request); - if (response) - return SendResponse(*response); - - lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request.id; - error_response.result = std::move(protocol_error); - SendResponse(error_response); -} - -void Server::Received(const Response &response) { - Log("unexpected MCP message: response"); -} - -void Server::Received(const Notification ¬ification) { - Handle(notification); -} - -void Server::OnError(llvm::Error error) { - Log(llvm::toString(std::move(error))); -} - -void Server::OnClosed() { - Log("EOF"); - if (m_closed_callback) - m_closed_callback(); -} diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index 71681fd4b51ed..a90ddf59671ee 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -78,11 +78,9 @@ enum DAPBroadcasterBits { enum class ReplMode { Variable = 0, Command, Auto }; -using DAPTransport = - lldb_private::Transport; +using DAPTransport = lldb_private::transport::JSONTransport; -struct DAP final : private DAPTransport::MessageHandler { +struct DAP final : public DAPTransport::MessageHandler { /// Path to the lldb-dap binary itself. static llvm::StringRef debug_adapter_path; diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h index 0a9ef538a7398..92e41b1dbf595 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h @@ -30,6 +30,8 @@ namespace lldb_dap::protocol { // MARK: Base Protocol +using Id = int64_t; + /// A client or debug adapter initiated request. struct Request { /// Sequence number of the message (also known as message ID). The `seq` for @@ -39,7 +41,7 @@ struct Request { /// associate requests with their corresponding responses. For protocol /// messages of type `request` the sequence number can be used to cancel the /// request. - int64_t seq; + Id seq; /// The command to execute. std::string command; @@ -76,7 +78,7 @@ enum ResponseMessage : unsigned { /// Response for a request. struct Response { /// Sequence number of the corresponding request. - int64_t request_seq; + Id request_seq; /// The command requested. std::string command; diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index 4a9dd76c2303e..58c48c133f9cb 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -22,11 +22,18 @@ namespace lldb_dap { +struct ProtocolDescriptor { + using Id = protocol::Id; + using Req = protocol::Request; + using Resp = protocol::Response; + using Evt = protocol::Event; +}; + /// A transport class that performs the Debug Adapter Protocol communication /// with the client. class Transport final - : public lldb_private::HTTPDelimitedJSONTransport< - protocol::Request, protocol::Response, protocol::Event> { + : public lldb_private::transport::HTTPDelimitedJSONTransport< + ProtocolDescriptor> { public: Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output); diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index 2090fe6896d6b..4fd6cd546e6fa 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -9,13 +9,10 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" #include "TestBase.h" -#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include -using namespace llvm; -using namespace lldb; using namespace lldb_dap; using namespace lldb_dap_tests; using namespace lldb_dap::protocol; @@ -24,18 +21,7 @@ using namespace testing; class DAPTest : public TransportBase {}; TEST_F(DAPTest, SendProtocolMessages) { - DAP dap{ - /*log=*/nullptr, - /*default_repl_mode=*/ReplMode::Auto, - /*pre_init_commands=*/{}, - /*no_lldbinit=*/false, - /*client_name=*/"test_client", - /*transport=*/*transport, - /*loop=*/loop, - }; - dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); - loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); - EXPECT_CALL(client, Received(IsEvent("my-event", std::nullopt))); - ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded()); + dap->Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); + EXPECT_CALL(client, Received(IsEvent("my-event"))); + Run(); } diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp index c6ff1f90b01d5..88d6e9a69eca3 100644 --- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp +++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp @@ -31,7 +31,7 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) { DisconnectRequestHandler handler(*dap); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); EXPECT_CALL(client, Received(IsEvent("terminated", _))); - RunOnce(); + Run(); } TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { @@ -53,5 +53,5 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { EXPECT_CALL(client, Received(Output("(lldb) script print(2)\n"))); EXPECT_CALL(client, Received(Output("Running terminateCommands:\n"))); EXPECT_CALL(client, Received(IsEvent("terminated", _))); - RunOnce(); + Run(); } diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index ba7baf2103799..3721e09d8b699 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -32,23 +32,9 @@ using lldb_private::FileSystem; using lldb_private::MainLoop; using lldb_private::Pipe; -Expected -TestTransport::RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) { - Expected dummy_file = FileSystem::Instance().Open( - FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite); - if (!dummy_file) - return dummy_file.takeError(); - m_dummy_file = std::move(*dummy_file); - lldb_private::Status status; - auto handle = loop.RegisterReadObject( - m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); - if (status.Fail()) - return status.takeError(); - return handle; -} +void TransportBase::SetUp() { + std::tie(to_client, to_server) = TestDAPTransport::createPair(); -void DAPTestBase::SetUp() { - TransportBase::SetUp(); std::error_code EC; log = std::make_unique("-", EC); dap = std::make_unique( @@ -57,16 +43,30 @@ void DAPTestBase::SetUp() { /*pre_init_commands=*/std::vector(), /*no_lldbinit=*/false, /*client_name=*/"test_client", - /*transport=*/*transport, /*loop=*/loop); + /*transport=*/*to_client, /*loop=*/loop); + + auto server_handle = to_server->RegisterMessageHandler(loop, *dap.get()); + EXPECT_THAT_EXPECTED(server_handle, Succeeded()); + handles[0] = std::move(*server_handle); + + auto client_handle = to_client->RegisterMessageHandler(loop, client); + EXPECT_THAT_EXPECTED(client_handle, Succeeded()); + handles[1] = std::move(*client_handle); } +void TransportBase::Run() { + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); +} + +void DAPTestBase::SetUp() { TransportBase::SetUp(); } + void DAPTestBase::TearDown() { - if (core) { + if (core) ASSERT_THAT_ERROR(core->discard(), Succeeded()); - } - if (binary) { + if (binary) ASSERT_THAT_ERROR(binary->discard(), Succeeded()); - } } void DAPTestBase::SetUpTestSuite() { diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index c19eead4e37e7..c32f3a769c737 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -7,73 +7,48 @@ //===----------------------------------------------------------------------===// #include "DAP.h" +#include "DAPLog.h" #include "Protocol/ProtocolBase.h" #include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" +#include "Transport.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" -#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include +#include + +/// Helpers for gtest printing. +namespace lldb_dap::protocol { + +inline void PrintTo(const Request &req, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(req)).str(); +} + +inline void PrintTo(const Response &resp, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(resp)).str(); +} + +inline void PrintTo(const Event &evt, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(evt)).str(); +} + +inline void PrintTo(const Message &message, std::ostream *os) { + return std::visit([os](auto &&message) { return PrintTo(message, os); }, + message); +} + +} // namespace lldb_dap::protocol namespace lldb_dap_tests { -class TestTransport final - : public lldb_private::Transport { -public: - using Message = lldb_private::Transport::Message; - - TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler) - : m_loop(loop), m_handler(handler) {} - - llvm::Error Send(const lldb_dap::protocol::Event &e) override { - m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) { - this->m_handler.Received(e); - }); - return llvm::Error::success(); - } - - llvm::Error Send(const lldb_dap::protocol::Request &r) override { - m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { - this->m_handler.Received(r); - }); - return llvm::Error::success(); - } - - llvm::Error Send(const lldb_dap::protocol::Response &r) override { - m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { - this->m_handler.Received(r); - }); - return llvm::Error::success(); - } - - llvm::Expected - RegisterMessageHandler(lldb_private::MainLoop &loop, - MessageHandler &handler) override; - - void Log(llvm::StringRef message) override { - log_messages.emplace_back(message); - } - - std::vector log_messages; - -private: - lldb_private::MainLoop &m_loop; - MessageHandler &m_handler; - lldb::FileSP m_dummy_file; -}; +using TestDAPTransport = TestTransport; /// A base class for tests that need transport configured for communicating DAP /// messages. @@ -82,22 +57,36 @@ class TransportBase : public testing::Test { lldb_private::SubsystemRAII subsystems; lldb_private::MainLoop loop; - std::unique_ptr transport; - MockMessageHandler - client; - - void SetUp() override { - transport = std::make_unique(loop, client); - } + lldb_private::MainLoop::ReadHandleUP handles[2]; + + std::unique_ptr log; + + std::unique_ptr to_client; + MockMessageHandler client; + + std::unique_ptr to_server; + std::unique_ptr dap; + + void SetUp() override; + + void Run(); }; /// A matcher for a DAP event. -template +template inline testing::Matcher -IsEvent(const M1 &m1, const M2 &m2) { - return testing::AllOf(testing::Field(&lldb_dap::protocol::Event::event, m1), - testing::Field(&lldb_dap::protocol::Event::body, m2)); +IsEvent(const EventMatcher &event_matcher, const BodyMatcher &body_matcher) { + return testing::AllOf( + testing::Field(&lldb_dap::protocol::Event::event, event_matcher), + testing::Field(&lldb_dap::protocol::Event::body, body_matcher)); +} + +template +inline testing::Matcher +IsEvent(const EventMatcher &event_matcher) { + return testing::AllOf( + testing::Field(&lldb_dap::protocol::Event::event, event_matcher), + testing::Field(&lldb_dap::protocol::Event::body, std::nullopt)); } /// Matches an "output" event. @@ -110,8 +99,6 @@ inline auto Output(llvm::StringRef o, llvm::StringRef cat = "console") { /// A base class for tests that interact with a `lldb_dap::DAP` instance. class DAPTestBase : public TransportBase { protected: - std::unique_ptr log; - std::unique_ptr dap; std::optional core; std::optional binary; @@ -126,12 +113,6 @@ class DAPTestBase : public TransportBase { bool GetDebuggerSupportsTarget(llvm::StringRef platform); void CreateDebugger(); void LoadCore(); - - void RunOnce() { - loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); - ASSERT_THAT_ERROR(dap->Loop(), llvm::Succeeded()); - } }; } // namespace lldb_dap_tests diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 3a36bf21f07ff..7db6508aa4a3c 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -9,6 +9,7 @@ #include "lldb/Host/JSONTransport.h" #include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/Host/PipeTestUtilities.h" +#include "TestingSupport/SubsystemRAII.h" #include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" @@ -25,27 +26,45 @@ #include #include #include +#include #include +#include using namespace llvm; using namespace lldb_private; +using namespace lldb_private::transport; using testing::_; using testing::HasSubstr; using testing::InSequence; +using testing::Ref; + +namespace llvm::json { +static bool fromJSON(const Value &V, Value &T, Path P) { + T = V; + return true; +} +} // namespace llvm::json namespace { namespace test_protocol { struct Req { + int id = 0; std::string name; + std::optional params; }; -json::Value toJSON(const Req &T) { return json::Object{{"req", T.name}}; } +json::Value toJSON(const Req &T) { + return json::Object{{"name", T.name}, {"id", T.id}, {"params", T.params}}; +} bool fromJSON(const json::Value &V, Req &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("req", T.name); + return O && O.map("name", T.name) && O.map("id", T.id) && + O.map("params", T.params); +} +bool operator==(const Req &a, const Req &b) { + return a.name == b.name && a.id == b.id && a.params == b.params; } -bool operator==(const Req &a, const Req &b) { return a.name == b.name; } inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Req &V) { OS << toJSON(V); return OS; @@ -58,14 +77,22 @@ void PrintTo(const Req &message, std::ostream *os) { } struct Resp { - std::string name; + int id = 0; + int errorCode = 0; + std::optional result; }; -json::Value toJSON(const Resp &T) { return json::Object{{"resp", T.name}}; } +json::Value toJSON(const Resp &T) { + return json::Object{ + {"id", T.id}, {"errorCode", T.errorCode}, {"result", T.result}}; +} bool fromJSON(const json::Value &V, Resp &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("resp", T.name); + return O && O.map("id", T.id) && O.mapOptional("errorCode", T.errorCode) && + O.map("result", T.result); +} +bool operator==(const Resp &a, const Resp &b) { + return a.id == b.id && a.errorCode == b.errorCode && a.result == b.result; } -bool operator==(const Resp &a, const Resp &b) { return a.name == b.name; } inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) { OS << toJSON(V); return OS; @@ -79,11 +106,14 @@ void PrintTo(const Resp &message, std::ostream *os) { struct Evt { std::string name; + std::optional params; }; -json::Value toJSON(const Evt &T) { return json::Object{{"evt", T.name}}; } +json::Value toJSON(const Evt &T) { + return json::Object{{"name", T.name}, {"params", T.params}}; +} bool fromJSON(const json::Value &V, Evt &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("evt", T.name); + return O && O.map("name", T.name) && O.map("params", T.params); } bool operator==(const Evt &a, const Evt &b) { return a.name == b.name; } inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Evt &V) { @@ -107,41 +137,114 @@ bool fromJSON(const json::Value &V, Message &msg, json::Path P) { P.report("expected object"); return false; } - if (O->get("req")) { - Req R; - if (!fromJSON(V, R, P)) + + if (O->find("id") == O->end()) { + Evt E; + if (!fromJSON(V, E, P)) return false; - msg = std::move(R); + msg = std::move(E); return true; } - if (O->get("resp")) { - Resp R; + + if (O->get("name")) { + Req R; if (!fromJSON(V, R, P)) return false; msg = std::move(R); return true; } - if (O->get("evt")) { - Evt E; - if (!fromJSON(V, E, P)) - return false; - msg = std::move(E); - return true; - } - P.report("unknown message type"); - return false; + Resp R; + if (!fromJSON(V, R, P)) + return false; + + msg = std::move(R); + return true; } -} // namespace test_protocol +struct MyFnParams { + int a = 0; + int b = 0; +}; +json::Value toJSON(const MyFnParams &T) { + return json::Object{{"a", T.a}, {"b", T.b}}; +} +bool fromJSON(const json::Value &V, MyFnParams &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("a", T.a) && O.map("b", T.b); +} + +struct MyFnResult { + int c = 0; +}; +json::Value toJSON(const MyFnResult &T) { return json::Object{{"c", T.c}}; } +bool fromJSON(const json::Value &V, MyFnResult &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("c", T.c); +} -template -class JSONTransportTest : public PipePairTest { +struct ProtoDesc { + using Id = int; + using Req = Req; + using Resp = Resp; + using Evt = Evt; + static inline Id InitialId() { return 0; } + static inline Req Make(Id id, llvm::StringRef method, + std::optional params) { + return Req{id, method.str(), params}; + } + static inline Evt Make(llvm::StringRef method, + std::optional params) { + return Evt{method.str(), params}; + } + static inline Resp Make(Req req, llvm::Error error) { + Resp resp; + resp.id = req.id; + llvm::handleAllErrors( + std::move(error), [&](const llvm::ErrorInfoBase &err) { + std::error_code cerr = err.convertToErrorCode(); + resp.errorCode = + cerr == llvm::inconvertibleErrorCode() ? 1 : cerr.value(); + resp.result = err.message(); + }); + return resp; + } + static inline Resp Make(Req req, std::optional result) { + return Resp{req.id, 0, std::move(result)}; + } + static inline Id KeyFor(Resp r) { return r.id; } + static inline std::string KeyFor(Req r) { return r.name; } + static inline std::string KeyFor(Evt e) { return e.name; } + static inline std::optional Extract(Req r) { + return r.params; + } + static inline llvm::Expected Extract(Resp r) { + if (r.errorCode != 0) + return llvm::createStringError( + std::error_code(r.errorCode, std::generic_category()), + r.result && r.result->getAsString() ? *r.result->getAsString() + : "no-message"); + return r.result; + } + static inline std::optional Extract(Evt e) { + return e.params; + } +}; + +using Transport = TestTransport; +using Binder = lldb_private::transport::Binder; +using MessageHandler = MockMessageHandler; + +} // namespace test_protocol + +template class JSONTransportTest : public PipePairTest { protected: - MockMessageHandler message_handler; + SubsystemRAII subsystems; + + test_protocol::MessageHandler message_handler; std::unique_ptr transport; MainLoop loop; @@ -191,8 +294,7 @@ class JSONTransportTest : public PipePairTest { }; class TestHTTPDelimitedJSONTransport final - : public HTTPDelimitedJSONTransport { + : public HTTPDelimitedJSONTransport { public: using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport; @@ -204,9 +306,7 @@ class TestHTTPDelimitedJSONTransport final }; class HTTPDelimitedJSONTransportTest - : public JSONTransportTest { + : public JSONTransportTest { public: using JSONTransportTest::JSONTransportTest; @@ -222,8 +322,7 @@ class HTTPDelimitedJSONTransportTest }; class TestJSONRPCTransport final - : public JSONRPCTransport { + : public JSONRPCTransport { public: using JSONRPCTransport::JSONRPCTransport; @@ -234,9 +333,7 @@ class TestJSONRPCTransport final std::vector log_messages; }; -class JSONRPCTransportTest - : public JSONTransportTest { +class JSONRPCTransportTest : public JSONTransportTest { public: using JSONTransportTest::JSONTransportTest; @@ -248,6 +345,33 @@ class JSONRPCTransportTest } }; +class TransportBinderTest : public testing::Test { +protected: + SubsystemRAII subsystems; + + std::unique_ptr to_remote; + std::unique_ptr from_remote; + std::unique_ptr binder; + test_protocol::MessageHandler remote; + MainLoop loop; + + void SetUp() override { + std::tie(to_remote, from_remote) = test_protocol::Transport::createPair(); + binder = std::make_unique(*to_remote); + + auto binder_handle = to_remote->RegisterMessageHandler(loop, remote); + EXPECT_THAT_EXPECTED(binder_handle, Succeeded()); + + auto remote_handle = from_remote->RegisterMessageHandler(loop, *binder); + EXPECT_THAT_EXPECTED(remote_handle, Succeeded()); + } + + void Run() { + loop.AddPendingCallback([](auto &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); + } +}; + } // namespace // Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446. @@ -269,35 +393,45 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { } TEST_F(HTTPDelimitedJSONTransportTest, Read) { - Write(Req{"foo"}); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + Write(Req{6, "foo", std::nullopt}); + EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { InSequence seq; - Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); - EXPECT_CALL(message_handler, Received(Req{"one"})); - EXPECT_CALL(message_handler, Received(Evt{"two"})); - EXPECT_CALL(message_handler, Received(Resp{"three"})); + Write( + Message{ + Req{6, "one", std::nullopt}, + }, + Message{ + Evt{"two", std::nullopt}, + }, + Message{ + Resp{2, 0, std::nullopt}, + }); + EXPECT_CALL(message_handler, Received(Req{6, "one", std::nullopt})); + EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt})); + EXPECT_CALL(message_handler, Received(Resp{2, 0, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { std::string long_str = std::string( - HTTPDelimitedJSONTransport::kReadBufferSize * 2, 'x'); - Write(Req{long_str}); - EXPECT_CALL(message_handler, Received(Req{long_str})); + HTTPDelimitedJSONTransport::kReadBufferSize * 2, + 'x'); + Write(Req{5, long_str, std::nullopt}); + EXPECT_CALL(message_handler, Received(Req{5, long_str, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { - std::string message = Encode(Req{"foo"}); + std::string message = Encode(Req{5, "foo", std::nullopt}); auto split_at = message.size() / 2; std::string part1 = message.substr(0, split_at); std::string part2 = message.substr(split_at); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + EXPECT_CALL(message_handler, Received(Req{5, "foo", std::nullopt})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); loop.AddPendingCallback( @@ -309,12 +443,12 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { - std::string message = Encode(Req{"foo"}); + std::string message = Encode(Req{6, "foo", std::nullopt}); auto split_at = message.size() / 2; std::string part1 = message.substr(0, split_at); std::string part2 = message.substr(split_at); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); @@ -366,20 +500,21 @@ TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { } TEST_F(HTTPDelimitedJSONTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{7, "foo", std::nullopt}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{5, 0, "bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n" - R"({"req":"foo"})" - "Content-Length: 14\r\n\r\n" - R"({"resp":"bar"})" - "Content-Length: 13\r\n\r\n" - R"({"evt":"baz"})")); + ASSERT_EQ(StringRef(buf, *bytes_read), + StringRef("Content-Length: 35\r\n\r\n" + R"({"id":7,"name":"foo","params":null})" + "Content-Length: 37\r\n\r\n" + R"({"errorCode":0,"id":5,"result":"bar"})" + "Content-Length: 28\r\n\r\n" + R"({"name":"baz","params":null})")); } TEST_F(JSONRPCTransportTest, MalformedRequests) { @@ -395,37 +530,38 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { } TEST_F(JSONRPCTransportTest, Read) { - Write(Message{Req{"foo"}}); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + Write(Message{Req{1, "foo", std::nullopt}}); + EXPECT_CALL(message_handler, Received(Req{1, "foo", std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadMultipleMessagesInSingleWrite) { InSequence seq; - Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); - EXPECT_CALL(message_handler, Received(Req{"one"})); - EXPECT_CALL(message_handler, Received(Evt{"two"})); - EXPECT_CALL(message_handler, Received(Resp{"three"})); + Write(Message{Req{1, "one", std::nullopt}}, Message{Evt{"two", std::nullopt}}, + Message{Resp{3, 0, "three"}}); + EXPECT_CALL(message_handler, Received(Req{1, "one", std::nullopt})); + EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt})); + EXPECT_CALL(message_handler, Received(Resp{3, 0, "three"})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { // Use a string longer than the chunk size to ensure we split the message // across the chunk boundary. - std::string long_str = - std::string(IOTransport::kReadBufferSize * 2, 'x'); - Write(Req{long_str}); - EXPECT_CALL(message_handler, Received(Req{long_str})); + std::string long_str = std::string( + IOTransport::kReadBufferSize * 2, 'x'); + Write(Req{42, long_str, std::nullopt}); + EXPECT_CALL(message_handler, Received(Req{42, long_str, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadPartialMessage) { - std::string message = R"({"req": "foo"})" + std::string message = R"({"id":42,"name":"foo","params":null})" "\n"; std::string part1 = message.substr(0, 7); std::string part2 = message.substr(7); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + EXPECT_CALL(message_handler, Received(Req{42, "foo", std::nullopt})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); loop.AddPendingCallback( @@ -455,20 +591,21 @@ TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { } TEST_F(JSONRPCTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{11, "foo", std::nullopt}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{14, 0, "bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"req":"foo"})" - "\n" - R"({"resp":"bar"})" - "\n" - R"({"evt":"baz"})" - "\n")); + ASSERT_EQ(StringRef(buf, *bytes_read), + StringRef(R"({"id":11,"name":"foo","params":null})" + "\n" + R"({"errorCode":0,"id":14,"result":"bar"})" + "\n" + R"({"name":"baz","params":null})" + "\n")); } TEST_F(JSONRPCTransportTest, InvalidTransport) { @@ -477,4 +614,183 @@ TEST_F(JSONRPCTransportTest, InvalidTransport) { FailedWithMessage("IO object is not valid.")); } +// Out-bound binding request handler. +TEST_F(TransportBinderTest, OutBoundRequests) { + OutgoingRequest addFn = + binder->Bind("add"); + bool replied = false; + addFn(MyFnParams{1, 2}, [&](Expected result) { + EXPECT_THAT_EXPECTED(result, Succeeded()); + EXPECT_EQ(result->c, 3); + replied = true; + }); + EXPECT_CALL(remote, Received(Req{1, "add", MyFnParams{1, 2}})); + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, toJSON(MyFnResult{3})}), + Succeeded()); + Run(); + EXPECT_TRUE(replied); +} + +TEST_F(TransportBinderTest, OutBoundRequestsVoidParams) { + OutgoingRequest voidParamFn = + binder->Bind("voidParam"); + bool replied = false; + voidParamFn([&](Expected result) { + EXPECT_THAT_EXPECTED(result, Succeeded()); + EXPECT_EQ(result->c, 3); + replied = true; + }); + EXPECT_CALL(remote, Received(Req{1, "voidParam", std::nullopt})); + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, toJSON(MyFnResult{3})}), + Succeeded()); + Run(); + EXPECT_TRUE(replied); +} + +TEST_F(TransportBinderTest, OutBoundRequestsVoidResult) { + OutgoingRequest voidResultFn = + binder->Bind("voidResult"); + bool replied = false; + voidResultFn(MyFnParams{4, 5}, [&](llvm::Error error) { + EXPECT_THAT_ERROR(std::move(error), Succeeded()); + replied = true; + }); + EXPECT_CALL(remote, Received(Req{1, "voidResult", MyFnParams{4, 5}})); + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, std::nullopt}), Succeeded()); + Run(); + EXPECT_TRUE(replied); +} + +TEST_F(TransportBinderTest, OutBoundRequestsVoidParamsAndVoidResult) { + OutgoingRequest voidParamAndResultFn = + binder->Bind("voidParamAndResult"); + bool replied = false; + voidParamAndResultFn([&](llvm::Error error) { + EXPECT_THAT_ERROR(std::move(error), Succeeded()); + replied = true; + }); + EXPECT_CALL(remote, Received(Req{1, "voidParamAndResult", std::nullopt})); + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, std::nullopt}), Succeeded()); + Run(); + EXPECT_TRUE(replied); +} + +// In-bound binding request handler. +TEST_F(TransportBinderTest, InBoundRequests) { + bool called = false; + binder->Bind( + "add", + [&](const int captured_param, + const MyFnParams ¶ms) -> Expected { + called = true; + return MyFnResult{params.a + params.b + captured_param}; + }, + 2); + EXPECT_THAT_ERROR(from_remote->Send(Req{1, "add", MyFnParams{3, 4}}), + Succeeded()); + + EXPECT_CALL(remote, Received(Resp{1, 0, MyFnResult{9}})); + Run(); + EXPECT_TRUE(called); +} + +TEST_F(TransportBinderTest, InBoundRequestsVoidParams) { + bool called = false; + binder->Bind( + "voidParam", + [&](const int captured_param) -> Expected { + called = true; + return MyFnResult{captured_param}; + }, + 2); + EXPECT_THAT_ERROR(from_remote->Send(Req{2, "voidParam", std::nullopt}), + Succeeded()); + EXPECT_CALL(remote, Received(Resp{2, 0, MyFnResult{2}})); + Run(); + EXPECT_TRUE(called); +} + +TEST_F(TransportBinderTest, InBoundRequestsVoidResult) { + bool called = false; + binder->Bind( + "voidResult", + [&](const int captured_param, const MyFnParams ¶ms) -> llvm::Error { + called = true; + EXPECT_EQ(captured_param, 2); + EXPECT_EQ(params.a, 3); + EXPECT_EQ(params.b, 4); + return llvm::Error::success(); + }, + 2); + EXPECT_THAT_ERROR(from_remote->Send(Req{3, "voidResult", MyFnParams{3, 4}}), + Succeeded()); + EXPECT_CALL(remote, Received(Resp{3, 0, std::nullopt})); + Run(); + EXPECT_TRUE(called); +} +TEST_F(TransportBinderTest, InBoundRequestsVoidParamsAndResult) { + bool called = false; + binder->Bind( + "voidParamAndResult", + [&](const int captured_param) -> llvm::Error { + called = true; + EXPECT_EQ(captured_param, 2); + return llvm::Error::success(); + }, + 2); + EXPECT_THAT_ERROR( + from_remote->Send(Req{4, "voidParamAndResult", std::nullopt}), + Succeeded()); + EXPECT_CALL(remote, Received(Resp{4, 0, std::nullopt})); + Run(); + EXPECT_TRUE(called); +} + +// Out-bound binding event handler. +TEST_F(TransportBinderTest, OutBoundEvents) { + OutgoingEvent emitEvent = binder->Bind("evt"); + emitEvent(MyFnParams{1, 2}); + EXPECT_CALL(remote, Received(Evt{"evt", MyFnParams{1, 2}})); + Run(); +} + +TEST_F(TransportBinderTest, OutBoundEventsVoidParams) { + OutgoingEvent emitEvent = binder->Bind("evt"); + emitEvent(); + EXPECT_CALL(remote, Received(Evt{"evt", std::nullopt})); + Run(); +} + +// In-bound binding event handler. +TEST_F(TransportBinderTest, InBoundEvents) { + bool called = false; + binder->Bind( + "evt", + [&](const int captured_arg, const MyFnParams ¶ms) { + EXPECT_EQ(captured_arg, 42); + EXPECT_EQ(params.a, 3); + EXPECT_EQ(params.b, 4); + called = true; + }, + 42); + EXPECT_THAT_ERROR(from_remote->Send(Evt{"evt", MyFnParams{3, 4}}), + Succeeded()); + Run(); + EXPECT_TRUE(called); +} + +TEST_F(TransportBinderTest, InBoundEventsVoidParams) { + bool called = false; + binder->Bind( + "evt", + [&](const int captured_arg) { + EXPECT_EQ(captured_arg, 42); + called = true; + }, + 42); + EXPECT_THAT_ERROR(from_remote->Send(Evt{"evt", std::nullopt}), Succeeded()); + Run(); + EXPECT_TRUE(called); +} + #endif diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index f3ca4cfc01788..9628cbd91e9ce 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -6,9 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "ProtocolMCPTestUtilities.h" +#include "ProtocolMCPTestUtilities.h" // IWYU pragma: keep #include "TestingSupport/Host/JSONTransportTestUtilities.h" -#include "TestingSupport/Host/PipeTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" @@ -28,20 +27,22 @@ #include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include -#include +#include +#include +#include +#include using namespace llvm; using namespace lldb; using namespace lldb_private; +using namespace lldb_private::transport; using namespace lldb_protocol::mcp; namespace { -class TestServer : public Server { -public: - using Server::Server; -}; +template Response make_response(T &&result, Id id = 1) { + return Response{id, std::forward(result)}; +} /// Test tool that returns it argument as text. class TestTool : public Tool { @@ -101,7 +102,9 @@ class ErrorTool : public Tool { using Tool::Tool; llvm::Expected Call(const ToolArguments &args) override { - return llvm::createStringError("error"); + return llvm::createStringError( + std::error_code(eErrorCodeInternalError, std::generic_category()), + "error"); } }; @@ -118,195 +121,207 @@ class FailTool : public Tool { } }; -class ProtocolServerMCPTest : public PipePairTest { +class TestServer : public Server { +public: + using Server::Bind; + using Server::Server; +}; + +using Transport = TestTransport; + +class ProtocolServerMCPTest : public testing::Test { public: SubsystemRAII subsystems; MainLoop loop; + lldb_private::MainLoop::ReadHandleUP handles[2]; - std::unique_ptr from_client; - std::unique_ptr to_client; - MainLoopBase::ReadHandleUP handles[2]; - + std::unique_ptr to_server; + MCPBinderUP binder; std::unique_ptr server_up; - MockMessageHandler message_handler; - llvm::Error Write(llvm::StringRef message) { - llvm::Expected value = json::parse(message); - if (!value) - return value.takeError(); - return from_client->Write(*value); - } + std::unique_ptr to_client; + MockMessageHandler client; - llvm::Error Write(json::Value value) { return from_client->Write(value); } + std::vector logged_messages; - /// Run the transport MainLoop and return any messages received. - llvm::Error Run() { - loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, - std::chrono::milliseconds(10)); - return loop.Run().takeError(); + /// Runs the MainLoop a single time, executing any pending callbacks. + void Run() { + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); } void SetUp() override { - PipePairTest::SetUp(); - - from_client = std::make_unique( - std::make_shared(input.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(output.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned), - [](StringRef message) { - // Uncomment for debugging - // llvm::errs() << "from_client: " << message << '\n'; - }); - to_client = std::make_unique( - std::make_shared(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned), - [](StringRef message) { - // Uncomment for debugging - // llvm::errs() << "to_client: " << message << '\n'; - }); - - server_up = std::make_unique("lldb-mcp", "0.1.0", *to_client, - [](StringRef message) { - // Uncomment for debugging - // llvm::errs() << "server: " << - // message << '\n'; - }); - - auto maybe_from_client_handle = - from_client->RegisterMessageHandler(loop, message_handler); - EXPECT_THAT_EXPECTED(maybe_from_client_handle, Succeeded()); - handles[0] = std::move(*maybe_from_client_handle); - - auto maybe_to_client_handle = - to_client->RegisterMessageHandler(loop, *server_up); - EXPECT_THAT_EXPECTED(maybe_to_client_handle, Succeeded()); - handles[1] = std::move(*maybe_to_client_handle); + std::tie(to_client, to_server) = Transport::createPair(); + + server_up = std::make_unique( + "lldb-mcp", "0.1.0", + [this](StringRef msg) { logged_messages.push_back(msg.str()); }); + binder = server_up->Bind(*to_client); + auto server_handle = to_server->RegisterMessageHandler(loop, *binder); + EXPECT_THAT_EXPECTED(server_handle, Succeeded()); + binder->OnError([](llvm::Error error) { + llvm::errs() << formatv("Server transport error: {0}", error); + }); + handles[0] = std::move(*server_handle); + + auto client_handle = to_client->RegisterMessageHandler(loop, client); + EXPECT_THAT_EXPECTED(client_handle, Succeeded()); + handles[1] = std::move(*client_handle); + } + + template + Expected Call(StringRef method, const Params ¶ms) { + std::promise promised_result; + Request req = + lldb_protocol::mcp::Request{/*id=*/1, method.str(), toJSON(params)}; + EXPECT_THAT_ERROR(to_server->Send(req), Succeeded()); + EXPECT_CALL(client, Received(testing::An())) + .WillOnce( + [&](const Response &resp) { promised_result.set_value(resp); }); + Run(); + Response resp = promised_result.get_future().get(); + return toJSON(resp); + } + + template + Expected + Capture(llvm::unique_function)> &fn) { + std::promise> promised_result; + fn([&promised_result](llvm::Expected result) { + promised_result.set_value(std::move(result)); + }); + Run(); + llvm::Expected result = promised_result.get_future().get(); + if (!result) + return result.takeError(); + return toJSON(*result); + } + + template + Expected + Capture(llvm::unique_function)> &fn, + const Params ¶ms) { + std::promise> promised_result; + fn(params, [&promised_result](llvm::Expected result) { + promised_result.set_value(std::move(result)); + }); + Run(); + llvm::Expected result = promised_result.get_future().get(); + if (!result) + return result.takeError(); + return toJSON(*result); } }; template -Request make_request(StringLiteral method, T &¶ms, Id id = 1) { - return Request{id, method.str(), toJSON(std::forward(params))}; -} - -template Response make_response(T &&result, Id id = 1) { - return Response{id, std::forward(result)}; +inline testing::internal::EqMatcher HasJSON(T x) { + return testing::internal::EqMatcher(toJSON(x)); } } // namespace TEST_F(ProtocolServerMCPTest, Initialization) { - Request request = make_request( - "initialize", InitializeParams{/*protocolVersion=*/"2024-11-05", - /*capabilities=*/{}, - /*clientInfo=*/{"lldb-unit", "0.1.0"}}); - Response response = make_response( - InitializeResult{/*protocolVersion=*/"2024-11-05", - /*capabilities=*/{/*supportsToolsList=*/true}, - /*serverInfo=*/{"lldb-mcp", "0.1.0"}}); - - ASSERT_THAT_ERROR(Write(request), Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED( + (Call( + "initialize", + InitializeParams{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/{}, + /*clientInfo=*/{"lldb-unit", "0.1.0"}})), + HasValue(make_response( + InitializeResult{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/ + { + /*supportsToolsList=*/true, + /*supportsResourcesList=*/true, + }, + /*serverInfo=*/{"lldb-mcp", "0.1.0"}}))); } TEST_F(ProtocolServerMCPTest, ToolsList) { server_up->AddTool(std::make_unique("test", "test tool")); - Request request = make_request("tools/list", Void{}, /*id=*/"one"); - ToolDefinition test_tool; test_tool.name = "test"; test_tool.description = "test tool"; test_tool.inputSchema = json::Object{{"type", "object"}}; - Response response = make_response(ListToolsResult{{test_tool}}, /*id=*/"one"); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED(Call("tools/list", Void{}), + HasValue(make_response(ListToolsResult{{test_tool}}))); } TEST_F(ProtocolServerMCPTest, ResourcesList) { server_up->AddResourceProvider(std::make_unique()); - Request request = make_request("resources/list", Void{}); - Response response = make_response(ListResourcesResult{ - {{/*uri=*/"lldb://foo/bar", /*name=*/"name", - /*description=*/"description", /*mimeType=*/"application/json"}}}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED(Call("resources/list", Void{}), + HasValue(make_response(ListResourcesResult{{ + { + /*uri=*/"lldb://foo/bar", + /*name=*/"name", + /*description=*/"description", + /*mimeType=*/"application/json", + }, + }}))); } TEST_F(ProtocolServerMCPTest, ToolsCall) { server_up->AddTool(std::make_unique("test", "test tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"test", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = make_response(CallToolResult{{{/*text=*/"foo"}}}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED( + (Call("tools/call", + CallToolParams{ + /*name=*/"test", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(CallToolResult{{{/*text=*/"foo"}}}))); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { server_up->AddTool(std::make_unique("error", "error tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"error", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = - make_response(lldb_protocol::mcp::Error{eErrorCodeInternalError, - /*message=*/"error"}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED((Call( + "tools/call", CallToolParams{ + /*name=*/"error", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(lldb_protocol::mcp::Error{ + eErrorCodeInternalError, "error"}))); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { server_up->AddTool(std::make_unique("fail", "fail tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"fail", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = - make_response(CallToolResult{{{/*text=*/"failed"}}, /*isError=*/true}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED((Call( + "tools/call", CallToolParams{ + /*name=*/"fail", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(CallToolResult{ + {{/*text=*/"failed"}}, + /*isError=*/true, + }))); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { - bool handler_called = false; - std::condition_variable cv; - - server_up->AddNotificationHandler( - "notifications/initialized", - [&](const Notification ¬ification) { handler_called = true; }); - llvm::StringLiteral request = - R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_THAT_ERROR(Run(), Succeeded()); - EXPECT_TRUE(handler_called); + EXPECT_THAT_ERROR(to_server->Send(lldb_protocol::mcp::Notification{ + "notifications/initialized", + std::nullopt, + }), + Succeeded()); + Run(); + EXPECT_THAT(logged_messages, + testing::Contains("MCP initialization complete")); } diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h index 5a9eb8e59f2b6..bacf8ca36aa07 100644 --- a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -6,19 +6,105 @@ // //===----------------------------------------------------------------------===// -#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H -#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H +#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H +#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H +#include "lldb/Host/FileSystem.h" #include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Utility/FileSpec.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" +#include "gtest/gtest.h" +#include +#include +#include -template +template +class TestTransport final + : public lldb_private::transport::JSONTransport { +public: + using MessageHandler = + typename lldb_private::transport::JSONTransport::MessageHandler; + + static std::pair>, + std::unique_ptr>> + createPair() { + std::unique_ptr> transports[2] = { + std::make_unique>(), + std::make_unique>()}; + return std::make_pair(std::move(transports[0]), std::move(transports[1])); + } + + explicit TestTransport() { + llvm::Expected dummy_file = + lldb_private::FileSystem::Instance().Open( + lldb_private::FileSpec(lldb_private::FileSystem::DEV_NULL), + lldb_private::File::eOpenOptionReadWrite); + EXPECT_THAT_EXPECTED(dummy_file, llvm::Succeeded()); + m_dummy_file = std::move(*dummy_file); + } + + llvm::Error Send(const typename Proto::Evt &evt) override { + EXPECT_TRUE(m_loop && m_handler) + << "Send called before RegisterMessageHandler"; + m_loop->AddPendingCallback([this, evt](lldb_private::MainLoopBase &) { + m_handler->Received(evt); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const typename Proto::Req &req) override { + EXPECT_TRUE(m_loop && m_handler) + << "Send called before RegisterMessageHandler"; + m_loop->AddPendingCallback([this, req](lldb_private::MainLoopBase &) { + m_handler->Received(req); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const typename Proto::Resp &resp) override { + EXPECT_TRUE(m_loop && m_handler) + << "Send called before RegisterMessageHandler"; + m_loop->AddPendingCallback([this, resp](lldb_private::MainLoopBase &) { + m_handler->Received(resp); + }); + return llvm::Error::success(); + } + + llvm::Expected + RegisterMessageHandler(lldb_private::MainLoop &loop, + MessageHandler &handler) override { + if (!m_loop) + m_loop = &loop; + if (!m_handler) + m_handler = &handler; + lldb_private::Status status; + auto handle = loop.RegisterReadObject( + m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); + if (status.Fail()) + return status.takeError(); + return handle; + } + +protected: + void Log(llvm::StringRef message) override {}; + +private: + lldb_private::MainLoop *m_loop = nullptr; + MessageHandler *m_handler = nullptr; + // Dummy file for registering with the MainLoop. + lldb::FileSP m_dummy_file = nullptr; +}; + +template class MockMessageHandler final - : public lldb_private::Transport::MessageHandler { + : public lldb_private::transport::JSONTransport::MessageHandler { public: - MOCK_METHOD(void, Received, (const Evt &), (override)); - MOCK_METHOD(void, Received, (const Req &), (override)); - MOCK_METHOD(void, Received, (const Resp &), (override)); + MOCK_METHOD(void, Received, (const typename Proto::Req &), (override)); + MOCK_METHOD(void, Received, (const typename Proto::Resp &), (override)); + MOCK_METHOD(void, Received, (const typename Proto::Evt &), (override)); MOCK_METHOD(void, OnError, (llvm::Error), (override)); MOCK_METHOD(void, OnClosed, (), (override)); };