Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lldb/tools/lldb-dap/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,12 @@ def pre_init_command: S<"pre-init-command">,
def: Separate<["-"], "c">,
Alias<pre_init_command>,
HelpText<"Alias for --pre-init-command">;

def time_to_live: S<"time-to-live">,
MetaVarName<"<ttl>">,
HelpText<"When using --connection, the number of seconds to wait for "
"new connections after the server has started and after all clients "
"have disconnected. New connections will reset the timer. When the "
"timer is reached, the server will be closed and the process will "
"exit. Not specifying this argument or specifying non-positive values "
"will cause the server to wait for new connections indefinitely.">;
63 changes: 61 additions & 2 deletions lldb/tools/lldb-dap/tool/lldb-dap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,36 @@ static int DuplicateFileDescriptor(int fd) {
#endif
}

static void ResetTimeToLive(std::mutex &ttl_mutex,
MainLoopBase::TimePoint &ttl_time_point) {
std::scoped_lock<std::mutex> lock(ttl_mutex);
ttl_time_point = MainLoopBase::TimePoint();
}

static void TrackTimeToLive(MainLoop &loop, std::mutex &ttl_mutex,
MainLoopBase::TimePoint &ttl_time_point,
std::chrono::seconds ttl_seconds) {
MainLoopBase::TimePoint next_checkpoint =
std::chrono::steady_clock::now() + std::chrono::seconds(ttl_seconds);
{
std::scoped_lock<std::mutex> lock(ttl_mutex);
// We don't need to take the max of `ttl_time_point` and `next_checkpoint`,
// because `next_checkpoint` must be the latest.
ttl_time_point = next_checkpoint;
}
loop.AddCallback(
[&ttl_mutex, &ttl_time_point, next_checkpoint](MainLoopBase &loop) {
bool should_request_terimation;
{
std::scoped_lock<std::mutex> lock(ttl_mutex);
should_request_terimation = ttl_time_point == next_checkpoint;
}
if (should_request_terimation)
loop.RequestTermination();
},
next_checkpoint);
}

static llvm::Expected<std::pair<Socket::SocketProtocol, std::string>>
validateConnection(llvm::StringRef conn) {
auto uri = lldb_private::URI::Parse(conn);
Expand Down Expand Up @@ -258,7 +288,8 @@ validateConnection(llvm::StringRef conn) {
static llvm::Error
serveConnection(const Socket::SocketProtocol &protocol, const std::string &name,
Log *log, const ReplMode default_repl_mode,
const std::vector<std::string> &pre_init_commands) {
const std::vector<std::string> &pre_init_commands,
std::optional<std::chrono::seconds> ttl_seconds) {
Status status;
static std::unique_ptr<Socket> listener = Socket::Create(protocol, status);
if (status.Fail()) {
Expand All @@ -283,6 +314,10 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name,
g_loop.AddPendingCallback(
[](MainLoopBase &loop) { loop.RequestTermination(); });
});
static MainLoopBase::TimePoint g_ttl_time_point;
static std::mutex g_ttl_mutex;
if (ttl_seconds)
TrackTimeToLive(g_loop, g_ttl_mutex, g_ttl_time_point, ttl_seconds.value());
std::condition_variable dap_sessions_condition;
std::mutex dap_sessions_mutex;
std::map<MainLoop *, DAP *> dap_sessions;
Expand All @@ -291,6 +326,10 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name,
&dap_sessions_mutex, &dap_sessions,
&clientCount](
std::unique_ptr<Socket> sock) {
// Reset the keep alive timer, because we won't be killing the server
// while this connection is being served.
if (ttl_seconds)
ResetTimeToLive(g_ttl_mutex, g_ttl_time_point);
std::string client_name = llvm::formatv("client_{0}", clientCount++).str();
DAP_LOG(log, "({0}) client connected", client_name);

Expand Down Expand Up @@ -327,6 +366,11 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name,
std::unique_lock<std::mutex> lock(dap_sessions_mutex);
dap_sessions.erase(&loop);
std::notify_all_at_thread_exit(dap_sessions_condition, std::move(lock));

// Start the countdown to kill the server at the end of each connection.
if (ttl_seconds)
TrackTimeToLive(g_loop, g_ttl_mutex, g_ttl_time_point,
ttl_seconds.value());
});
client.detach();
});
Expand Down Expand Up @@ -509,6 +553,21 @@ int main(int argc, char *argv[]) {
}

if (!connection.empty()) {
std::optional<std::chrono::seconds> ttl_seconds;
llvm::opt::Arg *time_to_live = input_args.getLastArg(OPT_time_to_live);
if (time_to_live) {
llvm::StringRef time_to_live_string_value = time_to_live->getValue();
int time_to_live_int_value;
if (time_to_live_string_value.getAsInteger(10, time_to_live_int_value)) {
llvm::errs() << "'" << time_to_live_string_value
<< "' is not a valid time-to-live value\n";
return EXIT_FAILURE;
}
// Ignore non-positive values.
if (time_to_live_int_value > 0)
ttl_seconds = std::chrono::seconds(time_to_live_int_value);
}

auto maybeProtoclAndName = validateConnection(connection);
if (auto Err = maybeProtoclAndName.takeError()) {
llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(),
Expand All @@ -520,7 +579,7 @@ int main(int argc, char *argv[]) {
std::string name;
std::tie(protocol, name) = *maybeProtoclAndName;
if (auto Err = serveConnection(protocol, name, log.get(), default_repl_mode,
pre_init_commands)) {
pre_init_commands, ttl_seconds)) {
llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(),
"Connection failed: ");
return EXIT_FAILURE;
Expand Down