Skip to content

Commit d9318a3

Browse files
committed
net: split ConnectToSocket() from ConnectDirectly() for unix sockets
1 parent ac2ecf3 commit d9318a3

File tree

2 files changed

+78
-30
lines changed

2 files changed

+78
-30
lines changed

src/netbase.cpp

Lines changed: 76 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -535,24 +535,10 @@ static void LogConnectFailure(bool manual_connection, const char* fmt, const Arg
535535
}
536536
}
537537

538-
std::unique_ptr<Sock> ConnectDirectly(const CService& dest, bool manual_connection)
538+
static bool ConnectToSocket(const Sock& sock, struct sockaddr* sockaddr, socklen_t len, const std::string& dest_str, bool manual_connection)
539539
{
540-
auto sock = CreateSock(dest.GetSAFamily());
541-
if (!sock) {
542-
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", dest.ToStringAddrPort());
543-
return {};
544-
}
545-
546-
// Create a sockaddr from the specified service.
547-
struct sockaddr_storage sockaddr;
548-
socklen_t len = sizeof(sockaddr);
549-
if (!dest.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
550-
LogPrintf("Cannot connect to %s: unsupported network\n", dest.ToStringAddrPort());
551-
return {};
552-
}
553-
554-
// Connect to the dest service on the hSocket socket.
555-
if (sock->Connect(reinterpret_cast<struct sockaddr*>(&sockaddr), len) == SOCKET_ERROR) {
540+
// Connect to `sockaddr` using `sock`.
541+
if (sock.Connect(sockaddr, len) == SOCKET_ERROR) {
556542
int nErr = WSAGetLastError();
557543
// WSAEINVAL is here because some legacy version of winsock uses it
558544
if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL)
@@ -562,14 +548,14 @@ std::unique_ptr<Sock> ConnectDirectly(const CService& dest, bool manual_connecti
562548
// synchronously to check for successful connection with a timeout.
563549
const Sock::Event requested = Sock::RECV | Sock::SEND;
564550
Sock::Event occurred;
565-
if (!sock->Wait(std::chrono::milliseconds{nConnectTimeout}, requested, &occurred)) {
551+
if (!sock.Wait(std::chrono::milliseconds{nConnectTimeout}, requested, &occurred)) {
566552
LogPrintf("wait for connect to %s failed: %s\n",
567-
dest.ToStringAddrPort(),
553+
dest_str,
568554
NetworkErrorString(WSAGetLastError()));
569-
return {};
555+
return false;
570556
} else if (occurred == 0) {
571-
LogPrint(BCLog::NET, "connection attempt to %s timed out\n", dest.ToStringAddrPort());
572-
return {};
557+
LogPrint(BCLog::NET, "connection attempt to %s timed out\n", dest_str);
558+
return false;
573559
}
574560

575561
// Even if the wait was successful, the connect might not
@@ -578,17 +564,17 @@ std::unique_ptr<Sock> ConnectDirectly(const CService& dest, bool manual_connecti
578564
// sockerr here.
579565
int sockerr;
580566
socklen_t sockerr_len = sizeof(sockerr);
581-
if (sock->GetSockOpt(SOL_SOCKET, SO_ERROR, (sockopt_arg_type)&sockerr, &sockerr_len) ==
567+
if (sock.GetSockOpt(SOL_SOCKET, SO_ERROR, (sockopt_arg_type)&sockerr, &sockerr_len) ==
582568
SOCKET_ERROR) {
583-
LogPrintf("getsockopt() for %s failed: %s\n", dest.ToStringAddrPort(), NetworkErrorString(WSAGetLastError()));
584-
return {};
569+
LogPrintf("getsockopt() for %s failed: %s\n", dest_str, NetworkErrorString(WSAGetLastError()));
570+
return false;
585571
}
586572
if (sockerr != 0) {
587573
LogConnectFailure(manual_connection,
588574
"connect() to %s failed after wait: %s",
589-
dest.ToStringAddrPort(),
575+
dest_str,
590576
NetworkErrorString(sockerr));
591-
return {};
577+
return false;
592578
}
593579
}
594580
#ifdef WIN32
@@ -597,11 +583,71 @@ std::unique_ptr<Sock> ConnectDirectly(const CService& dest, bool manual_connecti
597583
else
598584
#endif
599585
{
600-
LogConnectFailure(manual_connection, "connect() to %s failed: %s", dest.ToStringAddrPort(), NetworkErrorString(WSAGetLastError()));
601-
return {};
586+
LogConnectFailure(manual_connection, "connect() to %s failed: %s", dest_str, NetworkErrorString(WSAGetLastError()));
587+
return false;
602588
}
603589
}
590+
return true;
591+
}
592+
593+
std::unique_ptr<Sock> ConnectDirectly(const CService& dest, bool manual_connection)
594+
{
595+
auto sock = CreateSock(dest.GetSAFamily());
596+
if (!sock) {
597+
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", dest.ToStringAddrPort());
598+
return {};
599+
}
600+
601+
// Create a sockaddr from the specified service.
602+
struct sockaddr_storage sockaddr;
603+
socklen_t len = sizeof(sockaddr);
604+
if (!dest.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
605+
LogPrintf("Cannot get sockaddr for %s: unsupported network\n", dest.ToStringAddrPort());
606+
return {};
607+
}
608+
609+
if (!ConnectToSocket(*sock, (struct sockaddr*)&sockaddr, len, dest.ToStringAddrPort(), manual_connection)) {
610+
LogPrintf("Cannot connect to socket for %s\n", dest.ToStringAddrPort());
611+
return {};
612+
}
613+
614+
return sock;
615+
}
616+
617+
std::unique_ptr<Sock> Proxy::Connect() const
618+
{
619+
if (!IsValid()) {
620+
LogPrintf("Cannot connect to invalid Proxy\n");
621+
return {};
622+
}
623+
624+
if (!m_is_unix_socket) return ConnectDirectly(proxy, /*manual_connection=*/true);
625+
626+
#if HAVE_SOCKADDR_UN
627+
auto sock = CreateSock(AF_UNIX);
628+
if (!sock) {
629+
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", m_unix_socket_path);
630+
return {};
631+
}
632+
633+
const std::string path{m_unix_socket_path.substr(ADDR_PREFIX_UNIX.length())};
634+
635+
struct sockaddr_un addrun;
636+
memset(&addrun, 0, sizeof(addrun));
637+
addrun.sun_family = AF_UNIX;
638+
// leave the last char in addrun.sun_path[] to be always '\0'
639+
memcpy(addrun.sun_path, path.c_str(), std::min(sizeof(addrun.sun_path) - 1, path.length()));
640+
socklen_t len = sizeof(addrun);
641+
642+
if(!ConnectToSocket(*sock, (struct sockaddr*)&addrun, len, path, /*manual_connection=*/true)) {
643+
LogPrintf("Cannot connect to socket for %s\n", path);
644+
return {};
645+
}
646+
604647
return sock;
648+
#else
649+
return {};
650+
#endif
605651
}
606652

607653
bool SetProxy(enum Network net, const Proxy &addrProxy) {
@@ -658,7 +704,7 @@ std::unique_ptr<Sock> ConnectThroughProxy(const Proxy& proxy,
658704
bool& proxy_connection_failed)
659705
{
660706
// first connect to proxy server
661-
auto sock = ConnectDirectly(proxy.proxy, /*manual_connection=*/true);
707+
auto sock = proxy.Connect();
662708
if (!sock) {
663709
proxy_connection_failed = true;
664710
return {};

src/netbase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ class Proxy
8484
if (m_is_unix_socket) return m_unix_socket_path;
8585
return proxy.ToStringAddrPort();
8686
}
87+
88+
std::unique_ptr<Sock> Connect() const;
8789
};
8890

8991
/** Credentials for proxy authentication */

0 commit comments

Comments
 (0)