Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lldb/include/lldb/Host/linux/AbstractSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace lldb_private {
class AbstractSocket : public DomainSocket {
public:
AbstractSocket();
AbstractSocket(NativeSocket socket, bool should_close);

protected:
size_t GetNameOffset() const override;
Expand Down
4 changes: 4 additions & 0 deletions lldb/include/lldb/Host/posix/DomainSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ class DomainSocket : public Socket {

std::vector<std::string> GetListeningConnectionURI() const override;

static llvm::Expected<std::unique_ptr<DomainSocket>>
FromBoundNativeSocket(NativeSocket sockfd, bool should_close);

protected:
DomainSocket(SocketProtocol protocol);
DomainSocket(SocketProtocol protocol, NativeSocket socket, bool should_close);

virtual size_t GetNameOffset() const;
virtual void DeleteSocketFile(llvm::StringRef name);
Expand Down
3 changes: 3 additions & 0 deletions lldb/source/Host/linux/AbstractSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ using namespace lldb_private;

AbstractSocket::AbstractSocket() : DomainSocket(ProtocolUnixAbstract) {}

AbstractSocket::AbstractSocket(NativeSocket socket, bool should_close)
: DomainSocket(ProtocolUnixAbstract, socket, should_close) {}

size_t AbstractSocket::GetNameOffset() const { return 1; }

void AbstractSocket::DeleteSocketFile(llvm::StringRef name) {}
26 changes: 26 additions & 0 deletions lldb/source/Host/posix/DomainSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#include "lldb/Host/posix/DomainSocket.h"
#include "lldb/Utility/LLDBLog.h"
#ifdef __linux__
#include <lldb/Host/linux/AbstractSocket.h>
#endif

#include "llvm/Support/Errno.h"
#include "llvm/Support/FileSystem.h"
Expand Down Expand Up @@ -67,6 +70,12 @@ DomainSocket::DomainSocket(NativeSocket socket,
m_socket = socket;
}

DomainSocket::DomainSocket(SocketProtocol protocol, NativeSocket socket,
bool should_close)
: Socket(protocol, should_close) {
m_socket = socket;
}

Status DomainSocket::Connect(llvm::StringRef name) {
sockaddr_un saddr_un;
socklen_t saddr_un_len;
Expand Down Expand Up @@ -182,3 +191,20 @@ std::vector<std::string> DomainSocket::GetListeningConnectionURI() const {

return {llvm::formatv("unix-connect://{0}", addr.sun_path)};
}

llvm::Expected<std::unique_ptr<DomainSocket>>
DomainSocket::FromBoundNativeSocket(NativeSocket sockfd, bool should_close) {
#ifdef __linux__
// Check if fd represents domain socket or abstract socket.
struct sockaddr_un addr;
socklen_t addr_len = sizeof(addr);
if (getsockname(sockfd, (struct sockaddr *)&addr, &addr_len) == -1)
return llvm::createStringError("not a socket or error occurred");
if (addr.sun_family != AF_UNIX)
return llvm::createStringError("Bad socket type");
if (addr_len > offsetof(struct sockaddr_un, sun_path) &&
addr.sun_path[0] == '\0')
return std::make_unique<AbstractSocket>(sockfd, should_close);
#endif
return std::make_unique<DomainSocket>(sockfd, should_close);
}
39 changes: 18 additions & 21 deletions lldb/tools/lldb-server/lldb-platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,37 +455,33 @@ int main_platform(int argc, char *argv[]) {
lldb_private::Args inferior_arguments;
inferior_arguments.SetArguments(argc, const_cast<const char **>(argv));

Socket::SocketProtocol protocol = Socket::ProtocolUnixDomain;

Log *log = GetLog(LLDBLog::Platform);
if (fd != SharedSocket::kInvalidFD) {
// Child process will handle the connection and exit.
if (gdbserver_port)
protocol = Socket::ProtocolTcp;

Log *log = GetLog(LLDBLog::Platform);

NativeSocket sockfd;
error = SharedSocket::GetNativeSocket(fd, sockfd);
if (error.Fail()) {
LLDB_LOGF(log, "lldb-platform child: %s", error.AsCString());
return socket_error;
}

GDBRemoteCommunicationServerPlatform platform(protocol, gdbserver_port);
Socket *socket;
if (protocol == Socket::ProtocolTcp)
socket = new TCPSocket(sockfd, /*should_close=*/true);
else {
#if LLDB_ENABLE_POSIX
socket = new DomainSocket(sockfd, /*should_close=*/true);
#else
WithColor::error() << "lldb-platform child: Unix domain sockets are not "
"supported on this platform.";
return socket_error;
#endif
std::unique_ptr<Socket> socket;
if (gdbserver_port) {
socket = std::make_unique<TCPSocket>(sockfd, /*should_close=*/true);
} else {
llvm::Expected<std::unique_ptr<DomainSocket>> domain_socket =
DomainSocket::FromBoundNativeSocket(sockfd, /*should_close=*/true);
if (!domain_socket) {
LLDB_LOGF(log, "Failed to create socket: %s\n", error.AsCString());
return socket_error;
}
socket = std::move(domain_socket.get());
}
platform.SetConnection(
std::unique_ptr<Connection>(new ConnectionFileDescriptor(socket)));

GDBRemoteCommunicationServerPlatform platform(socket->GetSocketProtocol(),
gdbserver_port);
platform.SetConnection(std::unique_ptr<Connection>(
new ConnectionFileDescriptor(socket.release())));
client_handle(platform, inferior_arguments);
return 0;
}
Expand All @@ -498,6 +494,7 @@ int main_platform(int argc, char *argv[]) {
return 1;
}

Socket::SocketProtocol protocol = Socket::ProtocolUnixDomain;
std::string address;
std::string gdb_address;
uint16_t platform_port = 0;
Expand Down
45 changes: 45 additions & 0 deletions lldb/unittests/Host/SocketTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <chrono>
#if __linux__
#include <lldb/Host/linux/AbstractSocket.h>
#endif

using namespace lldb_private;

Expand Down Expand Up @@ -339,6 +342,48 @@ TEST_F(SocketTest, DomainGetConnectURI) {
}
#endif

#if LLDB_ENABLE_POSIX
TEST_F(SocketTest, DomainSocketFromBoundNativeSocket) {
// Generate a name for the domain socket.
llvm::SmallString<64> name;
std::error_code EC = llvm::sys::fs::createUniqueDirectory(
"DomainSocketFromBoundNativeSocket", name);
ASSERT_FALSE(EC);
llvm::sys::path::append(name, "test");

DomainSocket socket(true);
Status error = socket.Listen(name, /*backlog=*/10);
ASSERT_FALSE(error.ToError());
NativeSocket native_socket = socket.GetNativeSocket();

llvm::Expected<std::unique_ptr<DomainSocket>> sock =
DomainSocket::FromBoundNativeSocket(native_socket, /*should_close=*/true);
ASSERT_THAT_EXPECTED(sock, llvm::Succeeded());
ASSERT_EQ(Socket::ProtocolUnixDomain, sock->get()->GetSocketProtocol());
}
#endif

#if __linux__
TEST_F(SocketTest, AbstractSocketFromBoundNativeSocket) {
// Generate a name for the abstract socket.
llvm::SmallString<64> name;
std::error_code EC = llvm::sys::fs::createUniqueDirectory(
"AbstractSocketFromBoundNativeSocket", name);
ASSERT_FALSE(EC);
llvm::sys::path::append(name, "test");

AbstractSocket socket;
Status error = socket.Listen(name, /*backlog=*/10);
ASSERT_FALSE(error.ToError());
NativeSocket native_socket = socket.GetNativeSocket();

llvm::Expected<std::unique_ptr<DomainSocket>> sock =
DomainSocket::FromBoundNativeSocket(native_socket, /*should_close=*/true);
ASSERT_THAT_EXPECTED(sock, llvm::Succeeded());
ASSERT_EQ(Socket::ProtocolUnixAbstract, sock->get()->GetSocketProtocol());
}
#endif

INSTANTIATE_TEST_SUITE_P(
SocketTests, SocketTest,
testing::Values(SocketTestParams{/*is_ipv6=*/false,
Expand Down
Loading