Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions lldb/include/lldb/Core/Debugger.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,10 +602,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
void FlushProcessOutput(Process &process, bool flush_stdout,
bool flush_stderr);

void AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
void RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
lldb::ProtocolServerSP GetProtocolServer(llvm::StringRef protocol) const;

SourceManager::SourceFileCache &GetSourceFileCache() {
return m_source_file_cache;
}
Expand Down Expand Up @@ -776,8 +772,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
mutable std::mutex m_progress_reports_mutex;
/// @}

llvm::SmallVector<lldb::ProtocolServerSP> m_protocol_servers;

std::mutex m_destroy_callback_mutex;
lldb::callback_token_t m_destroy_callback_next_token = 0;
struct DestroyCallbackInfo {
Expand Down
5 changes: 3 additions & 2 deletions lldb/include/lldb/Core/ProtocolServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class ProtocolServer : public PluginInterface {
ProtocolServer() = default;
virtual ~ProtocolServer() = default;

static lldb::ProtocolServerSP Create(llvm::StringRef name,
Debugger &debugger);
static ProtocolServer *GetOrCreate(llvm::StringRef name);

static std::vector<llvm::StringRef> GetSupportedProtocols();

struct Connection {
Socket::SocketProtocol protocol;
Expand Down
2 changes: 1 addition & 1 deletion lldb/include/lldb/lldb-forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ typedef std::shared_ptr<lldb_private::Platform> PlatformSP;
typedef std::shared_ptr<lldb_private::Process> ProcessSP;
typedef std::shared_ptr<lldb_private::ProcessAttachInfo> ProcessAttachInfoSP;
typedef std::shared_ptr<lldb_private::ProcessLaunchInfo> ProcessLaunchInfoSP;
typedef std::shared_ptr<lldb_private::ProtocolServer> ProtocolServerSP;
typedef std::unique_ptr<lldb_private::ProtocolServer> ProtocolServerUP;
typedef std::weak_ptr<lldb_private::Process> ProcessWP;
typedef std::shared_ptr<lldb_private::RegisterCheckpoint> RegisterCheckpointSP;
typedef std::shared_ptr<lldb_private::RegisterContext> RegisterContextSP;
Expand Down
3 changes: 1 addition & 2 deletions lldb/include/lldb/lldb-private-interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force,
typedef lldb::ProcessSP (*ProcessCreateInstance)(
lldb::TargetSP target_sp, lldb::ListenerSP listener_sp,
const FileSpec *crash_file_path, bool can_connect);
typedef lldb::ProtocolServerSP (*ProtocolServerCreateInstance)(
Debugger &debugger);
typedef lldb::ProtocolServerUP (*ProtocolServerCreateInstance)();
typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)(
Target &target);
typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)(
Expand Down
51 changes: 9 additions & 42 deletions lldb/source/Commands/CommandObjectProtocolServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,6 @@ using namespace lldb_private;
#define LLDB_OPTIONS_mcp
#include "CommandOptions.inc"

static std::vector<llvm::StringRef> GetSupportedProtocols() {
std::vector<llvm::StringRef> supported_protocols;
size_t i = 0;

for (llvm::StringRef protocol_name =
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
!protocol_name.empty();
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
supported_protocols.push_back(protocol_name);
}

return supported_protocols;
}

class CommandObjectProtocolServerStart : public CommandObjectParsed {
public:
CommandObjectProtocolServerStart(CommandInterpreter &interpreter)
Expand All @@ -57,12 +43,11 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
}

llvm::StringRef protocol = args.GetArgumentAtIndex(0);
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
if (llvm::find(supported_protocols, protocol) ==
supported_protocols.end()) {
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
if (!server) {
result.AppendErrorWithFormatv(
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
llvm::join(GetSupportedProtocols(), ", "));
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
return;
}

Expand All @@ -72,10 +57,6 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
}
llvm::StringRef connection_uri = args.GetArgumentAtIndex(1);

ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol);
if (!server_sp)
server_sp = ProtocolServer::Create(protocol, GetDebugger());

const char *connection_error =
"unsupported connection specifier, expected 'accept:///path' or "
"'listen://[host]:port', got '{0}'.";
Expand All @@ -98,14 +79,12 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname,
uri->port.value_or(0));

if (llvm::Error error = server_sp->Start(connection)) {
if (llvm::Error error = server->Start(connection)) {
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
return;
}

GetDebugger().AddProtocolServer(server_sp);

if (Socket *socket = server_sp->GetSocket()) {
if (Socket *socket = server->GetSocket()) {
std::string address =
llvm::join(socket->GetListeningConnectionURI(), ", ");
result.AppendMessageWithFormatv(
Expand Down Expand Up @@ -134,30 +113,18 @@ class CommandObjectProtocolServerStop : public CommandObjectParsed {
}

llvm::StringRef protocol = args.GetArgumentAtIndex(0);
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
if (llvm::find(supported_protocols, protocol) ==
supported_protocols.end()) {
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
if (!server) {
result.AppendErrorWithFormatv(
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
llvm::join(GetSupportedProtocols(), ", "));
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
return;
}

Debugger &debugger = GetDebugger();

ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol);
if (!server_sp) {
result.AppendError(
llvm::formatv("no {0} protocol server running", protocol).str());
return;
}

if (llvm::Error error = server_sp->Stop()) {
if (llvm::Error error = server->Stop()) {
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
return;
}

debugger.RemoveProtocolServer(server_sp);
}
};

Expand Down
23 changes: 0 additions & 23 deletions lldb/source/Core/Debugger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2376,26 +2376,3 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() {
"Debugger::GetThreadPool called before Debugger::Initialize");
return *g_thread_pool;
}

void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
assert(protocol_server_sp &&
GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr);
m_protocol_servers.push_back(protocol_server_sp);
}

void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
auto it = llvm::find(m_protocol_servers, protocol_server_sp);
if (it != m_protocol_servers.end())
m_protocol_servers.erase(it);
}

lldb::ProtocolServerSP
Debugger::GetProtocolServer(llvm::StringRef protocol) const {
for (ProtocolServerSP protocol_server_sp : m_protocol_servers) {
if (!protocol_server_sp)
continue;
if (protocol_server_sp->GetPluginName() == protocol)
return protocol_server_sp;
}
return nullptr;
}
34 changes: 30 additions & 4 deletions lldb/source/Core/ProtocolServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,36 @@
using namespace lldb_private;
using namespace lldb;

ProtocolServerSP ProtocolServer::Create(llvm::StringRef name,
Debugger &debugger) {
ProtocolServer *ProtocolServer::GetOrCreate(llvm::StringRef name) {
static std::mutex g_mutex;
static llvm::StringMap<ProtocolServerUP> g_protocol_server_instances;

std::lock_guard<std::mutex> guard(g_mutex);

auto it = g_protocol_server_instances.find(name);
if (it != g_protocol_server_instances.end())
return it->second.get();

if (ProtocolServerCreateInstance create_callback =
PluginManager::GetProtocolCreateCallbackForPluginName(name))
return create_callback(debugger);
PluginManager::GetProtocolCreateCallbackForPluginName(name)) {
auto pair =
g_protocol_server_instances.try_emplace(name, create_callback());
return pair.first->second.get();
}

return nullptr;
}

std::vector<llvm::StringRef> ProtocolServer::GetSupportedProtocols() {
std::vector<llvm::StringRef> supported_protocols;
size_t i = 0;

for (llvm::StringRef protocol_name =
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
!protocol_name.empty();
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
supported_protocols.push_back(protocol_name);
}

return supported_protocols;
}
26 changes: 13 additions & 13 deletions lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ LLDB_PLUGIN_DEFINE(ProtocolServerMCP)

static constexpr size_t kChunkSize = 1024;

ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
: ProtocolServer(), m_debugger(debugger) {
ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {
AddRequestHandler("initialize",
std::bind(&ProtocolServerMCP::InitializeHandler, this,
std::placeholders::_1));
Expand All @@ -39,8 +38,10 @@ ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
"notifications/initialized", [](const protocol::Notification &) {
LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete");
});
AddTool(std::make_unique<LLDBCommandTool>(
"lldb_command", "Run an lldb command.", m_debugger));
AddTool(
std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));
AddTool(std::make_unique<DebuggerListTool>(
"lldb_debugger_list", "List debugger instances with their debugger_id."));
}

ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); }
Expand All @@ -54,8 +55,8 @@ void ProtocolServerMCP::Terminate() {
PluginManager::UnregisterPlugin(CreateInstance);
}

lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) {
return std::make_shared<ProtocolServerMCP>(debugger);
lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() {
return std::make_unique<ProtocolServerMCP>();
}

llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
Expand Down Expand Up @@ -145,7 +146,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
std::lock_guard<std::mutex> guard(m_server_mutex);

if (m_running)
return llvm::createStringError("server already running");
return llvm::createStringError("the MCP server is already running");

Status status;
m_listener = Socket::Create(connection.protocol, status);
Expand All @@ -162,10 +163,10 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
if (llvm::Error error = handles.takeError())
return error;

m_running = true;
m_listen_handlers = std::move(*handles);
m_loop_thread = std::thread([=] {
llvm::set_thread_name(
llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID()));
llvm::set_thread_name("protocol-server.mcp");
m_loop.Run();
});

Expand All @@ -175,6 +176,8 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
llvm::Error ProtocolServerMCP::Stop() {
{
std::lock_guard<std::mutex> guard(m_server_mutex);
if (!m_running)
return createStringError("the MCP sever is not running");
m_running = false;
}

Expand Down Expand Up @@ -312,10 +315,7 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) {
return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name));

const json::Value *args = param_obj->get("arguments");
if (!args)
return llvm::createStringError("no tool arguments");

llvm::Expected<protocol::TextResult> text_result = it->second->Call(*args);
llvm::Expected<protocol::TextResult> text_result = it->second->Call(args);
if (!text_result)
return text_result.takeError();

Expand Down
6 changes: 2 additions & 4 deletions lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace lldb_private::mcp {

class ProtocolServerMCP : public ProtocolServer {
public:
ProtocolServerMCP(Debugger &debugger);
ProtocolServerMCP();
virtual ~ProtocolServerMCP() override;

virtual llvm::Error Start(ProtocolServer::Connection connection) override;
Expand All @@ -33,7 +33,7 @@ class ProtocolServerMCP : public ProtocolServer {
static llvm::StringRef GetPluginNameStatic() { return "MCP"; }
static llvm::StringRef GetPluginDescriptionStatic();

static lldb::ProtocolServerSP CreateInstance(Debugger &debugger);
static lldb::ProtocolServerUP CreateInstance();

llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); }

Expand Down Expand Up @@ -71,8 +71,6 @@ class ProtocolServerMCP : public ProtocolServer {
llvm::StringLiteral kName = "lldb-mcp";
llvm::StringLiteral kVersion = "0.1.0";

Debugger &m_debugger;

bool m_running = false;

MainLoop m_loop;
Expand Down
Loading
Loading