diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index c43b06809bd3f..6448416eee08f 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -21,7 +21,7 @@ namespace lldb_protocol::mcp { -static llvm::StringLiteral kVersion = "2024-11-05"; +static llvm::StringLiteral kProtocolVersion = "2024-11-05"; /// A request that expects a response. struct Request { diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h new file mode 100644 index 0000000000000..2ac05880de86b --- /dev/null +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PROTOCOL_MCP_SERVER_H +#define LLDB_PROTOCOL_MCP_SERVER_H + +#include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Resource.h" +#include "lldb/Protocol/MCP/Tool.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Support/Error.h" +#include + +namespace lldb_protocol::mcp { + +class Server { +public: + Server(std::string name, std::string version); + virtual ~Server() = default; + + void AddTool(std::unique_ptr tool); + void AddResourceProvider(std::unique_ptr resource_provider); + +protected: + virtual Capabilities GetCapabilities() = 0; + + using RequestHandler = + std::function(const Request &)>; + using NotificationHandler = std::function; + + void AddRequestHandlers(); + + void AddRequestHandler(llvm::StringRef method, RequestHandler handler); + void AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler); + + llvm::Expected> HandleData(llvm::StringRef data); + + llvm::Expected Handle(Request request); + void Handle(Notification notification); + + llvm::Expected InitializeHandler(const Request &); + + llvm::Expected ToolsListHandler(const Request &); + llvm::Expected ToolsCallHandler(const Request &); + + llvm::Expected ResourcesListHandler(const Request &); + llvm::Expected ResourcesReadHandler(const Request &); + + std::mutex m_mutex; + +private: + const std::string m_name; + const std::string m_version; + + llvm::StringMap> m_tools; + std::vector> m_resource_providers; + + llvm::StringMap m_request_handlers; + llvm::StringMap m_notification_handlers; +}; + +} // namespace lldb_protocol::mcp + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index c9fe474d45c49..c359663239dcc 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -27,25 +27,12 @@ using namespace llvm; LLDB_PLUGIN_DEFINE(ProtocolServerMCP) static constexpr size_t kChunkSize = 1024; +static constexpr llvm::StringLiteral kName = "lldb-mcp"; +static constexpr llvm::StringLiteral kVersion = "0.1.0"; -ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() { - AddRequestHandler("initialize", - std::bind(&ProtocolServerMCP::InitializeHandler, this, - std::placeholders::_1)); - - AddRequestHandler("tools/list", - std::bind(&ProtocolServerMCP::ToolsListHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/call", - std::bind(&ProtocolServerMCP::ToolsCallHandler, this, - std::placeholders::_1)); - - AddRequestHandler("resources/list", - std::bind(&ProtocolServerMCP::ResourcesListHandler, this, - std::placeholders::_1)); - AddRequestHandler("resources/read", - std::bind(&ProtocolServerMCP::ResourcesReadHandler, this, - std::placeholders::_1)); +ProtocolServerMCP::ProtocolServerMCP() + : ProtocolServer(), + lldb_protocol::mcp::Server(std::string(kName), std::string(kVersion)) { AddNotificationHandler("notifications/initialized", [](const lldb_protocol::mcp::Notification &) { LLDB_LOG(GetLog(LLDBLog::Host), @@ -77,32 +64,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { return "MCP Server."; } -llvm::Expected -ProtocolServerMCP::Handle(lldb_protocol::mcp::Request request) { - auto it = m_request_handlers.find(request.method); - if (it != m_request_handlers.end()) { - llvm::Expected response = it->second(request); - if (!response) - return response; - response->id = request.id; - return *response; - } - - return make_error( - llvm::formatv("no handler for request: {0}", request.method).str()); -} - -void ProtocolServerMCP::Handle(lldb_protocol::mcp::Notification notification) { - auto it = m_notification_handlers.find(notification.method); - if (it != m_notification_handlers.end()) { - it->second(notification); - return; - } - - LLDB_LOG(GetLog(LLDBLog::Host), "MPC notification: {0} ({1})", - notification.method, notification.params); -} - void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", m_clients.size() + 1); @@ -157,7 +118,7 @@ llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { - std::lock_guard guard(m_server_mutex); + std::lock_guard guard(m_mutex); if (m_running) return llvm::createStringError("the MCP server is already running"); @@ -189,7 +150,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { llvm::Error ProtocolServerMCP::Stop() { { - std::lock_guard guard(m_server_mutex); + std::lock_guard guard(m_mutex); if (!m_running) return createStringError("the MCP sever is not running"); m_running = false; @@ -204,7 +165,7 @@ llvm::Error ProtocolServerMCP::Stop() { m_loop_thread.join(); { - std::lock_guard guard(m_server_mutex); + std::lock_guard guard(m_mutex); m_listener.reset(); m_listen_handlers.clear(); m_clients.clear(); @@ -213,48 +174,6 @@ llvm::Error ProtocolServerMCP::Stop() { return llvm::Error::success(); } -llvm::Expected> -ProtocolServerMCP::HandleData(llvm::StringRef data) { - auto message = llvm::json::parse(/*JSON=*/data); - if (!message) - return message.takeError(); - - if (const lldb_protocol::mcp::Request *request = - std::get_if(&(*message))) { - llvm::Expected response = Handle(*request); - - // Handle failures by converting them into an Error message. - if (!response) { - lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.error.code = MCPError::kInternalError; - protocol_error.error.message = err.message(); - }); - protocol_error.id = request->id; - return protocol_error; - } - - return *response; - } - - if (const lldb_protocol::mcp::Notification *notification = - std::get_if(&(*message))) { - Handle(*notification); - return std::nullopt; - } - - if (std::get_if(&(*message))) - return llvm::createStringError("unexpected MCP message: error"); - - if (std::get_if(&(*message))) - return llvm::createStringError("unexpected MCP message: response"); - - llvm_unreachable("all message types handled"); -} - lldb_protocol::mcp::Capabilities ProtocolServerMCP::GetCapabilities() { lldb_protocol::mcp::Capabilities capabilities; capabilities.tools.listChanged = true; @@ -263,158 +182,3 @@ lldb_protocol::mcp::Capabilities ProtocolServerMCP::GetCapabilities() { capabilities.resources.listChanged = false; return capabilities; } - -void ProtocolServerMCP::AddTool(std::unique_ptr tool) { - std::lock_guard guard(m_server_mutex); - - if (!tool) - return; - m_tools[tool->GetName()] = std::move(tool); -} - -void ProtocolServerMCP::AddResourceProvider( - std::unique_ptr resource_provider) { - std::lock_guard guard(m_server_mutex); - - if (!resource_provider) - return; - m_resource_providers.push_back(std::move(resource_provider)); -} - -void ProtocolServerMCP::AddRequestHandler(llvm::StringRef method, - RequestHandler handler) { - std::lock_guard guard(m_server_mutex); - m_request_handlers[method] = std::move(handler); -} - -void ProtocolServerMCP::AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler) { - std::lock_guard guard(m_server_mutex); - m_notification_handlers[method] = std::move(handler); -} - -llvm::Expected -ProtocolServerMCP::InitializeHandler( - const lldb_protocol::mcp::Request &request) { - lldb_protocol::mcp::Response response; - response.result.emplace(llvm::json::Object{ - {"protocolVersion", lldb_protocol::mcp::kVersion}, - {"capabilities", GetCapabilities()}, - {"serverInfo", - llvm::json::Object{{"name", kName}, {"version", kVersion}}}}); - return response; -} - -llvm::Expected -ProtocolServerMCP::ToolsListHandler( - const lldb_protocol::mcp::Request &request) { - lldb_protocol::mcp::Response response; - - llvm::json::Array tools; - for (const auto &tool : m_tools) - tools.emplace_back(toJSON(tool.second->GetDefinition())); - - response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); - - return response; -} - -llvm::Expected -ProtocolServerMCP::ToolsCallHandler( - const lldb_protocol::mcp::Request &request) { - lldb_protocol::mcp::Response response; - - if (!request.params) - return llvm::createStringError("no tool parameters"); - - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no tool parameters"); - - const json::Value *name = param_obj->get("name"); - if (!name) - return llvm::createStringError("no tool name"); - - llvm::StringRef tool_name = name->getAsString().value_or(""); - if (tool_name.empty()) - return llvm::createStringError("no tool name"); - - auto it = m_tools.find(tool_name); - if (it == m_tools.end()) - return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); - - lldb_protocol::mcp::ToolArguments tool_args; - if (const json::Value *args = param_obj->get("arguments")) - tool_args = *args; - - llvm::Expected text_result = - it->second->Call(tool_args); - if (!text_result) - return text_result.takeError(); - - response.result.emplace(toJSON(*text_result)); - - return response; -} - -llvm::Expected -ProtocolServerMCP::ResourcesListHandler( - const lldb_protocol::mcp::Request &request) { - lldb_protocol::mcp::Response response; - - llvm::json::Array resources; - - std::lock_guard guard(m_server_mutex); - for (std::unique_ptr &resource_provider_up : - m_resource_providers) { - for (const lldb_protocol::mcp::Resource &resource : - resource_provider_up->GetResources()) - resources.push_back(resource); - } - response.result.emplace( - llvm::json::Object{{"resources", std::move(resources)}}); - - return response; -} - -llvm::Expected -ProtocolServerMCP::ResourcesReadHandler( - const lldb_protocol::mcp::Request &request) { - lldb_protocol::mcp::Response response; - - if (!request.params) - return llvm::createStringError("no resource parameters"); - - const json::Object *param_obj = request.params->getAsObject(); - if (!param_obj) - return llvm::createStringError("no resource parameters"); - - const json::Value *uri = param_obj->get("uri"); - if (!uri) - return llvm::createStringError("no resource uri"); - - llvm::StringRef uri_str = uri->getAsString().value_or(""); - if (uri_str.empty()) - return llvm::createStringError("no resource uri"); - - std::lock_guard guard(m_server_mutex); - for (std::unique_ptr &resource_provider_up : - m_resource_providers) { - llvm::Expected result = - resource_provider_up->ReadResource(uri_str); - if (result.errorIsA()) { - llvm::consumeError(result.takeError()); - continue; - } - if (!result) - return result.takeError(); - - lldb_protocol::mcp::Response response; - response.result.emplace(std::move(*result)); - return response; - } - - return make_error( - llvm::formatv("no resource handler for uri: {0}", uri_str).str(), - MCPError::kResourceNotFound); -} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 2ea9585a2334b..7fe909a728b85 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -13,14 +13,13 @@ #include "lldb/Host/MainLoop.h" #include "lldb/Host/Socket.h" #include "lldb/Protocol/MCP/Protocol.h" -#include "lldb/Protocol/MCP/Resource.h" -#include "lldb/Protocol/MCP/Tool.h" -#include "llvm/ADT/StringMap.h" +#include "lldb/Protocol/MCP/Server.h" #include namespace lldb_private::mcp { -class ProtocolServerMCP : public ProtocolServer { +class ProtocolServerMCP : public ProtocolServer, + public lldb_protocol::mcp::Server { public: ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; @@ -40,48 +39,10 @@ class ProtocolServerMCP : public ProtocolServer { Socket *GetSocket() const override { return m_listener.get(); } -protected: - using RequestHandler = - std::function( - const lldb_protocol::mcp::Request &)>; - using NotificationHandler = - std::function; - - void AddTool(std::unique_ptr tool); - void AddResourceProvider( - std::unique_ptr resource_provider); - - void AddRequestHandler(llvm::StringRef method, RequestHandler handler); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); - private: void AcceptCallback(std::unique_ptr socket); - llvm::Expected> - HandleData(llvm::StringRef data); - - llvm::Expected - Handle(lldb_protocol::mcp::Request request); - void Handle(lldb_protocol::mcp::Notification notification); - - llvm::Expected - InitializeHandler(const lldb_protocol::mcp::Request &); - - llvm::Expected - ToolsListHandler(const lldb_protocol::mcp::Request &); - llvm::Expected - ToolsCallHandler(const lldb_protocol::mcp::Request &); - - llvm::Expected - ResourcesListHandler(const lldb_protocol::mcp::Request &); - llvm::Expected - ResourcesReadHandler(const lldb_protocol::mcp::Request &); - - lldb_protocol::mcp::Capabilities GetCapabilities(); - - llvm::StringLiteral kName = "lldb-mcp"; - llvm::StringLiteral kVersion = "0.1.0"; + lldb_protocol::mcp::Capabilities GetCapabilities() override; bool m_running = false; @@ -98,14 +59,6 @@ class ProtocolServerMCP : public ProtocolServer { }; llvm::Error ReadCallback(Client &client); std::vector> m_clients; - - std::mutex m_server_mutex; - llvm::StringMap> m_tools; - std::vector> - m_resource_providers; - - llvm::StringMap m_request_handlers; - llvm::StringMap m_notification_handlers; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Protocol/MCP/CMakeLists.txt b/lldb/source/Protocol/MCP/CMakeLists.txt index f1b1098e064a5..a73e7e6a7cab1 100644 --- a/lldb/source/Protocol/MCP/CMakeLists.txt +++ b/lldb/source/Protocol/MCP/CMakeLists.txt @@ -1,6 +1,7 @@ add_lldb_library(lldbProtocolMCP NO_PLUGIN_DEPENDENCIES MCPError.cpp Protocol.cpp + Server.cpp Tool.cpp LINK_COMPONENTS diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp new file mode 100644 index 0000000000000..4ec127fe75bdd --- /dev/null +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -0,0 +1,236 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Protocol/MCP/Server.h" +#include "lldb/Protocol/MCP/MCPError.h" + +using namespace lldb_protocol::mcp; +using namespace llvm; + +Server::Server(std::string name, std::string version) + : m_name(std::move(name)), m_version(std::move(version)) { + AddRequestHandlers(); +} + +void Server::AddRequestHandlers() { + AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this, + std::placeholders::_1)); + AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this, + std::placeholders::_1)); + AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this, + std::placeholders::_1)); + AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler, + this, std::placeholders::_1)); + AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler, + this, std::placeholders::_1)); +} + +llvm::Expected Server::Handle(Request request) { + auto it = m_request_handlers.find(request.method); + if (it != m_request_handlers.end()) { + llvm::Expected response = it->second(request); + if (!response) + return response; + response->id = request.id; + return *response; + } + + return llvm::make_error( + llvm::formatv("no handler for request: {0}", request.method).str()); +} + +void Server::Handle(Notification notification) { + auto it = m_notification_handlers.find(notification.method); + if (it != m_notification_handlers.end()) { + it->second(notification); + return; + } +} + +llvm::Expected> +Server::HandleData(llvm::StringRef data) { + auto message = llvm::json::parse(/*JSON=*/data); + if (!message) + return message.takeError(); + + if (const Request *request = std::get_if(&(*message))) { + llvm::Expected response = Handle(*request); + + // Handle failures by converting them into an Error message. + if (!response) { + Error protocol_error; + llvm::handleAllErrors( + response.takeError(), + [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.error.code = MCPError::kInternalError; + protocol_error.error.message = err.message(); + }); + protocol_error.id = request->id; + return protocol_error; + } + + return *response; + } + + if (const Notification *notification = + std::get_if(&(*message))) { + Handle(*notification); + return std::nullopt; + } + + if (std::get_if(&(*message))) + return llvm::createStringError("unexpected MCP message: error"); + + if (std::get_if(&(*message))) + return llvm::createStringError("unexpected MCP message: response"); + + llvm_unreachable("all message types handled"); +} + +void Server::AddTool(std::unique_ptr tool) { + std::lock_guard guard(m_mutex); + + if (!tool) + return; + m_tools[tool->GetName()] = std::move(tool); +} + +void Server::AddResourceProvider( + std::unique_ptr resource_provider) { + std::lock_guard guard(m_mutex); + + if (!resource_provider) + return; + m_resource_providers.push_back(std::move(resource_provider)); +} + +void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { + std::lock_guard guard(m_mutex); + m_request_handlers[method] = std::move(handler); +} + +void Server::AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler) { + std::lock_guard guard(m_mutex); + m_notification_handlers[method] = std::move(handler); +} + +llvm::Expected Server::InitializeHandler(const Request &request) { + Response response; + response.result.emplace(llvm::json::Object{ + {"protocolVersion", mcp::kProtocolVersion}, + {"capabilities", GetCapabilities()}, + {"serverInfo", + llvm::json::Object{{"name", m_name}, {"version", m_version}}}}); + return response; +} + +llvm::Expected Server::ToolsListHandler(const Request &request) { + Response response; + + llvm::json::Array tools; + for (const auto &tool : m_tools) + tools.emplace_back(toJSON(tool.second->GetDefinition())); + + response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); + + return response; +} + +llvm::Expected Server::ToolsCallHandler(const Request &request) { + Response response; + + if (!request.params) + return llvm::createStringError("no tool parameters"); + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return llvm::createStringError("no tool parameters"); + + const json::Value *name = param_obj->get("name"); + if (!name) + return llvm::createStringError("no tool name"); + + llvm::StringRef tool_name = name->getAsString().value_or(""); + if (tool_name.empty()) + return llvm::createStringError("no tool name"); + + auto it = m_tools.find(tool_name); + if (it == m_tools.end()) + return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); + + ToolArguments tool_args; + if (const json::Value *args = param_obj->get("arguments")) + tool_args = *args; + + llvm::Expected text_result = it->second->Call(tool_args); + if (!text_result) + return text_result.takeError(); + + response.result.emplace(toJSON(*text_result)); + + return response; +} + +llvm::Expected Server::ResourcesListHandler(const Request &request) { + Response response; + + llvm::json::Array resources; + + std::lock_guard guard(m_mutex); + for (std::unique_ptr &resource_provider_up : + m_resource_providers) { + for (const Resource &resource : resource_provider_up->GetResources()) + resources.push_back(resource); + } + response.result.emplace( + llvm::json::Object{{"resources", std::move(resources)}}); + + return response; +} + +llvm::Expected Server::ResourcesReadHandler(const Request &request) { + Response response; + + if (!request.params) + return llvm::createStringError("no resource parameters"); + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return llvm::createStringError("no resource parameters"); + + const json::Value *uri = param_obj->get("uri"); + if (!uri) + return llvm::createStringError("no resource uri"); + + llvm::StringRef uri_str = uri->getAsString().value_or(""); + if (uri_str.empty()) + return llvm::createStringError("no resource uri"); + + std::lock_guard guard(m_mutex); + for (std::unique_ptr &resource_provider_up : + m_resource_providers) { + llvm::Expected result = + resource_provider_up->ReadResource(uri_str); + if (result.errorIsA()) { + llvm::consumeError(result.takeError()); + continue; + } + if (!result) + return result.takeError(); + + Response response; + response.result.emplace(std::move(*result)); + return response; + } + + return make_error( + llvm::formatv("no resource handler for uri: {0}", uri_str).str(), + MCPError::kResourceNotFound); +}