@@ -223,31 +223,30 @@ static int DuplicateFileDescriptor(int fd) {
223223#endif
224224}
225225
226- static void ResetTimeToLive (std::mutex &ttl_mutex,
227- MainLoopBase::TimePoint &ttl_time_point) {
228- std::scoped_lock<std::mutex> lock (ttl_mutex);
229- ttl_time_point = MainLoopBase::TimePoint ();
226+ static void
227+ ResetConnectionTimeout (std::mutex &connection_timeout_mutex,
228+ MainLoopBase::TimePoint &conncetion_timeout_time_point) {
229+ std::scoped_lock<std::mutex> lock (connection_timeout_mutex);
230+ conncetion_timeout_time_point = MainLoopBase::TimePoint ();
230231}
231232
232- static void TrackTimeToLive (MainLoop &loop, std::mutex &ttl_mutex,
233- MainLoopBase::TimePoint &ttl_time_point,
234- std::chrono::seconds ttl_seconds) {
233+ static void
234+ TrackConnectionTimeout (MainLoop &loop, std::mutex &connection_timeout_mutex,
235+ MainLoopBase::TimePoint &conncetion_timeout_time_point,
236+ std::chrono::seconds ttl_seconds) {
235237 MainLoopBase::TimePoint next_checkpoint =
236238 std::chrono::steady_clock::now () + std::chrono::seconds (ttl_seconds);
237239 {
238- std::scoped_lock<std::mutex> lock (ttl_mutex );
240+ std::scoped_lock<std::mutex> lock (connection_timeout_mutex );
239241 // We don't need to take the max of `ttl_time_point` and `next_checkpoint`,
240242 // because `next_checkpoint` must be the latest.
241- ttl_time_point = next_checkpoint;
243+ conncetion_timeout_time_point = next_checkpoint;
242244 }
243245 loop.AddCallback (
244- [&ttl_mutex, &ttl_time_point, next_checkpoint](MainLoopBase &loop) {
245- bool should_request_terimation;
246- {
247- std::scoped_lock<std::mutex> lock (ttl_mutex);
248- should_request_terimation = ttl_time_point == next_checkpoint;
249- }
250- if (should_request_terimation)
246+ [&connection_timeout_mutex, &conncetion_timeout_time_point,
247+ next_checkpoint](MainLoopBase &loop) {
248+ std::scoped_lock<std::mutex> lock (connection_timeout_mutex);
249+ if (conncetion_timeout_time_point == next_checkpoint)
251250 loop.RequestTermination ();
252251 },
253252 next_checkpoint);
@@ -285,11 +284,11 @@ validateConnection(llvm::StringRef conn) {
285284 return make_error ();
286285}
287286
288- static llvm::Error
289- serveConnection ( const Socket::SocketProtocol &protocol, const std::string &name,
290- Log *log, const ReplMode default_repl_mode,
291- const std::vector<std::string> &pre_init_commands,
292- std::optional<std::chrono::seconds> ttl_seconds ) {
287+ static llvm::Error serveConnection (
288+ const Socket::SocketProtocol &protocol, const std::string &name, Log *log ,
289+ const ReplMode default_repl_mode,
290+ const std::vector<std::string> &pre_init_commands,
291+ std::optional<std::chrono::seconds> connection_timeout_seconds ) {
293292 Status status;
294293 static std::unique_ptr<Socket> listener = Socket::Create (protocol, status);
295294 if (status.Fail ()) {
@@ -314,10 +313,12 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name,
314313 g_loop.AddPendingCallback (
315314 [](MainLoopBase &loop) { loop.RequestTermination (); });
316315 });
317- static MainLoopBase::TimePoint g_ttl_time_point;
318- static std::mutex g_ttl_mutex;
319- if (ttl_seconds)
320- TrackTimeToLive (g_loop, g_ttl_mutex, g_ttl_time_point, ttl_seconds.value ());
316+ static MainLoopBase::TimePoint g_connection_timeout_time_point;
317+ static std::mutex g_connection_timeout_mutex;
318+ if (connection_timeout_seconds)
319+ TrackConnectionTimeout (g_loop, g_connection_timeout_mutex,
320+ g_connection_timeout_time_point,
321+ connection_timeout_seconds.value ());
321322 std::condition_variable dap_sessions_condition;
322323 std::mutex dap_sessions_mutex;
323324 std::map<MainLoop *, DAP *> dap_sessions;
@@ -328,8 +329,9 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name,
328329 std::unique_ptr<Socket> sock) {
329330 // Reset the keep alive timer, because we won't be killing the server
330331 // while this connection is being served.
331- if (ttl_seconds)
332- ResetTimeToLive (g_ttl_mutex, g_ttl_time_point);
332+ if (connection_timeout_seconds)
333+ ResetConnectionTimeout (g_connection_timeout_mutex,
334+ g_connection_timeout_time_point);
333335 std::string client_name = llvm::formatv (" client_{0}" , clientCount++).str ();
334336 DAP_LOG (log, " ({0}) client connected" , client_name);
335337
@@ -368,9 +370,10 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name,
368370 std::notify_all_at_thread_exit (dap_sessions_condition, std::move (lock));
369371
370372 // Start the countdown to kill the server at the end of each connection.
371- if (ttl_seconds)
372- TrackTimeToLive (g_loop, g_ttl_mutex, g_ttl_time_point,
373- ttl_seconds.value ());
373+ if (connection_timeout_seconds)
374+ TrackConnectionTimeout (g_loop, g_connection_timeout_mutex,
375+ g_connection_timeout_time_point,
376+ connection_timeout_seconds.value ());
374377 });
375378 client.detach ();
376379 });
@@ -500,6 +503,31 @@ int main(int argc, char *argv[]) {
500503 connection.assign (path);
501504 }
502505
506+ std::optional<std::chrono::seconds> connection_timeout_seconds;
507+ if (llvm::opt::Arg *connection_timeout_arg =
508+ input_args.getLastArg (OPT_connection_timeout)) {
509+ if (!connection.empty ()) {
510+ llvm::StringRef connection_timeout_string_value =
511+ connection_timeout_arg->getValue ();
512+ int connection_timeout_int_value;
513+ if (connection_timeout_string_value.getAsInteger (
514+ 10 , connection_timeout_int_value)) {
515+ llvm::errs () << " '" << connection_timeout_string_value
516+ << " ' is not a valid connection timeout value\n " ;
517+ return EXIT_FAILURE;
518+ }
519+ // Ignore non-positive values.
520+ if (connection_timeout_int_value > 0 )
521+ connection_timeout_seconds =
522+ std::chrono::seconds (connection_timeout_int_value);
523+ } else {
524+ llvm::errs ()
525+ << " \" --connection-timeout\" requires \" --connection\" to be "
526+ " specified\n " ;
527+ return EXIT_FAILURE;
528+ }
529+ }
530+
503531#if !defined(_WIN32)
504532 if (input_args.hasArg (OPT_wait_for_debugger)) {
505533 printf (" Paused waiting for debugger to attach (pid = %i)...\n " , getpid ());
@@ -553,21 +581,6 @@ int main(int argc, char *argv[]) {
553581 }
554582
555583 if (!connection.empty ()) {
556- std::optional<std::chrono::seconds> ttl_seconds;
557- llvm::opt::Arg *time_to_live = input_args.getLastArg (OPT_time_to_live);
558- if (time_to_live) {
559- llvm::StringRef time_to_live_string_value = time_to_live->getValue ();
560- int time_to_live_int_value;
561- if (time_to_live_string_value.getAsInteger (10 , time_to_live_int_value)) {
562- llvm::errs () << " '" << time_to_live_string_value
563- << " ' is not a valid time-to-live value\n " ;
564- return EXIT_FAILURE;
565- }
566- // Ignore non-positive values.
567- if (time_to_live_int_value > 0 )
568- ttl_seconds = std::chrono::seconds (time_to_live_int_value);
569- }
570-
571584 auto maybeProtoclAndName = validateConnection (connection);
572585 if (auto Err = maybeProtoclAndName.takeError ()) {
573586 llvm::logAllUnhandledErrors (std::move (Err), llvm::errs (),
@@ -578,8 +591,9 @@ int main(int argc, char *argv[]) {
578591 Socket::SocketProtocol protocol;
579592 std::string name;
580593 std::tie (protocol, name) = *maybeProtoclAndName;
581- if (auto Err = serveConnection (protocol, name, log.get (), default_repl_mode,
582- pre_init_commands, ttl_seconds)) {
594+ if (auto Err =
595+ serveConnection (protocol, name, log.get (), default_repl_mode,
596+ pre_init_commands, connection_timeout_seconds)) {
583597 llvm::logAllUnhandledErrors (std::move (Err), llvm::errs (),
584598 " Connection failed: " );
585599 return EXIT_FAILURE;
0 commit comments