diff --git a/lldb/include/lldb/Host/Socket.h b/lldb/include/lldb/Host/Socket.h index 14468c98ac5a3..0e542b05a085c 100644 --- a/lldb/include/lldb/Host/Socket.h +++ b/lldb/include/lldb/Host/Socket.h @@ -13,6 +13,7 @@ #include #include "lldb/Host/MainLoopBase.h" +#include "lldb/Utility/Timeout.h" #include "lldb/lldb-private.h" #include "lldb/Host/SocketAddress.h" @@ -108,7 +109,7 @@ class Socket : public IOObject { // Accept a single connection and "return" it in the pointer argument. This // function blocks until the connection arrives. - virtual Status Accept(Socket *&socket); + virtual Status Accept(const Timeout &timeout, Socket *&socket); // Initialize a Tcp Socket object in listening mode. listen and accept are // implemented separately because the caller may wish to manipulate or query diff --git a/lldb/source/Host/common/Socket.cpp b/lldb/source/Host/common/Socket.cpp index d69eb60820403..63396f7b4abc9 100644 --- a/lldb/source/Host/common/Socket.cpp +++ b/lldb/source/Host/common/Socket.cpp @@ -460,7 +460,8 @@ NativeSocket Socket::CreateSocket(const int domain, const int type, return sock; } -Status Socket::Accept(Socket *&socket) { +Status Socket::Accept(const Timeout &timeout, Socket *&socket) { + socket = nullptr; MainLoop accept_loop; llvm::Expected> expected_handles = Accept(accept_loop, @@ -470,7 +471,15 @@ Status Socket::Accept(Socket *&socket) { }); if (!expected_handles) return Status::FromError(expected_handles.takeError()); - return accept_loop.Run(); + if (timeout) { + accept_loop.AddCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }, *timeout); + } + if (Status status = accept_loop.Run(); status.Fail()) + return status; + if (socket) + return Status(); + return Status(std::make_error_code(std::errc::timed_out)); } NativeSocket Socket::AcceptSocket(NativeSocket sockfd, struct sockaddr *addr, diff --git a/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp b/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp index 8a03e47ef3d90..903bfc50def3a 100644 --- a/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp +++ b/lldb/source/Host/posix/ConnectionFileDescriptorPosix.cpp @@ -543,7 +543,7 @@ lldb::ConnectionStatus ConnectionFileDescriptor::AcceptSocket( if (!error.Fail()) { post_listen_callback(*listening_socket); - error = listening_socket->Accept(accepted_socket); + error = listening_socket->Accept(/*timeout=*/std::nullopt, accepted_socket); } if (!error.Fail()) { diff --git a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp index 7eacd605362e7..67b41b1e77a53 100644 --- a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp +++ b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunication.cpp @@ -1223,10 +1223,6 @@ GDBRemoteCommunication::ConnectLocally(GDBRemoteCommunication &client, listen_socket.Listen("localhost:0", backlog).ToError()) return error; - Socket *accept_socket = nullptr; - std::future accept_status = std::async( - std::launch::async, [&] { return listen_socket.Accept(accept_socket); }); - llvm::SmallString<32> remote_addr; llvm::raw_svector_ostream(remote_addr) << "connect://localhost:" << listen_socket.GetLocalPortNumber(); @@ -1238,10 +1234,15 @@ GDBRemoteCommunication::ConnectLocally(GDBRemoteCommunication &client, return llvm::createStringError(llvm::inconvertibleErrorCode(), "Unable to connect: %s", status.AsCString()); - client.SetConnection(std::move(conn_up)); - if (llvm::Error error = accept_status.get().ToError()) - return error; + // The connection was already established above, so a short timeout is + // sufficient. + Socket *accept_socket = nullptr; + if (Status accept_status = + listen_socket.Accept(std::chrono::seconds(1), accept_socket); + accept_status.Fail()) + return accept_status.takeError(); + client.SetConnection(std::move(conn_up)); server.SetConnection( std::make_unique(accept_socket)); return llvm::Error::success(); diff --git a/lldb/unittests/Host/MainLoopTest.cpp b/lldb/unittests/Host/MainLoopTest.cpp index e7425b737a6da..462aae5085043 100644 --- a/lldb/unittests/Host/MainLoopTest.cpp +++ b/lldb/unittests/Host/MainLoopTest.cpp @@ -42,7 +42,8 @@ class MainLoopTest : public testing::Test { llvm::formatv("localhost:{0}", listen_socket_up->GetLocalPortNumber()) .str()); ASSERT_TRUE(error.Success()); - ASSERT_TRUE(listen_socket_up->Accept(accept_socket).Success()); + ASSERT_TRUE(listen_socket_up->Accept(std::chrono::seconds(1), accept_socket) + .Success()); callback_count = 0; socketpair[0] = std::move(connect_socket_up); diff --git a/lldb/unittests/Host/SocketTest.cpp b/lldb/unittests/Host/SocketTest.cpp index 4befddc0e21ce..c020f1bff0479 100644 --- a/lldb/unittests/Host/SocketTest.cpp +++ b/lldb/unittests/Host/SocketTest.cpp @@ -12,7 +12,9 @@ #include "lldb/Host/MainLoop.h" #include "lldb/Utility/UriParser.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include using namespace lldb_private; @@ -133,6 +135,27 @@ TEST_P(SocketTest, TCPListen0ConnectAccept) { &socket_b_up); } +TEST_P(SocketTest, TCPAcceptTimeout) { + if (!HostSupportsProtocol()) + return; + + const bool child_processes_inherit = false; + auto listen_socket_up = + std::make_unique(true, child_processes_inherit); + Status error = listen_socket_up->Listen( + llvm::formatv("[{0}]:0", GetParam().localhost_ip).str(), 5); + ASSERT_THAT_ERROR(error.ToError(), llvm::Succeeded()); + ASSERT_TRUE(listen_socket_up->IsValid()); + + Socket *socket; + ASSERT_THAT_ERROR( + listen_socket_up->Accept(std::chrono::milliseconds(10), socket) + .takeError(), + llvm::Failed( + testing::Property(&llvm::ErrorInfoBase::convertToErrorCode, + std::make_error_code(std::errc::timed_out)))); +} + TEST_P(SocketTest, TCPMainLoopAccept) { if (!HostSupportsProtocol()) return; diff --git a/lldb/unittests/TestingSupport/Host/SocketTestUtilities.cpp b/lldb/unittests/TestingSupport/Host/SocketTestUtilities.cpp index 2455a4f6f5d49..417900e7674dc 100644 --- a/lldb/unittests/TestingSupport/Host/SocketTestUtilities.cpp +++ b/lldb/unittests/TestingSupport/Host/SocketTestUtilities.cpp @@ -19,11 +19,6 @@ using namespace lldb_private; -static void AcceptThread(Socket *listen_socket, bool child_processes_inherit, - Socket **accept_socket, Status *error) { - *error = listen_socket->Accept(*accept_socket); -} - template void lldb_private::CreateConnectedSockets( llvm::StringRef listen_remote_address, @@ -38,12 +33,6 @@ void lldb_private::CreateConnectedSockets( ASSERT_THAT_ERROR(error.ToError(), llvm::Succeeded()); ASSERT_TRUE(listen_socket_up->IsValid()); - Status accept_error; - Socket *accept_socket; - std::thread accept_thread(AcceptThread, listen_socket_up.get(), - child_processes_inherit, &accept_socket, - &accept_error); - std::string connect_remote_address = get_connect_addr(*listen_socket_up); std::unique_ptr connect_socket_up( new SocketType(true, child_processes_inherit)); @@ -55,9 +44,13 @@ void lldb_private::CreateConnectedSockets( a_up->swap(connect_socket_up); ASSERT_TRUE((*a_up)->IsValid()); - accept_thread.join(); + Socket *accept_socket; + ASSERT_THAT_ERROR( + listen_socket_up->Accept(std::chrono::seconds(1), accept_socket) + .takeError(), + llvm::Succeeded()); + b_up->reset(static_cast(accept_socket)); - ASSERT_THAT_ERROR(accept_error.ToError(), llvm::Succeeded()); ASSERT_NE(nullptr, b_up->get()); ASSERT_TRUE((*b_up)->IsValid()); diff --git a/lldb/unittests/tools/lldb-server/tests/TestClient.cpp b/lldb/unittests/tools/lldb-server/tests/TestClient.cpp index a6f2dc32c6d0c..9cbdf5278816d 100644 --- a/lldb/unittests/tools/lldb-server/tests/TestClient.cpp +++ b/lldb/unittests/tools/lldb-server/tests/TestClient.cpp @@ -26,9 +26,13 @@ using namespace lldb_private; using namespace llvm; using namespace llgs_tests; +static std::chrono::seconds GetDefaultTimeout() { + return std::chrono::seconds{10}; +} + TestClient::TestClient(std::unique_ptr Conn) { SetConnection(std::move(Conn)); - SetPacketTimeout(std::chrono::seconds(10)); + SetPacketTimeout(GetDefaultTimeout()); } TestClient::~TestClient() { @@ -117,7 +121,10 @@ TestClient::launchCustom(StringRef Log, bool disable_stdio, return status.ToError(); Socket *accept_socket; - listen_socket.Accept(accept_socket); + if (llvm::Error E = + listen_socket.Accept(2 * GetDefaultTimeout(), accept_socket) + .takeError()) + return E; auto Conn = std::make_unique(accept_socket); auto Client = std::unique_ptr(new TestClient(std::move(Conn)));