Skip to content

Commit 1729c29

Browse files
committed
net: split socket creation out of connection
Also, check for the correct error during socket creation
1 parent 6f01dcf commit 1729c29

File tree

3 files changed

+44
-17
lines changed

3 files changed

+44
-17
lines changed

src/net.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,16 +417,30 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
417417
if (addrConnect.IsValid()) {
418418
bool proxyConnectionFailed = false;
419419

420-
if (GetProxy(addrConnect.GetNetwork(), proxy))
420+
if (GetProxy(addrConnect.GetNetwork(), proxy)) {
421+
hSocket = CreateSocket(proxy.proxy);
422+
if (hSocket == INVALID_SOCKET) {
423+
return nullptr;
424+
}
421425
connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), hSocket, nConnectTimeout, &proxyConnectionFailed);
422-
else // no proxy needed (none set for target network)
426+
} else {
427+
// no proxy needed (none set for target network)
428+
hSocket = CreateSocket(addrConnect);
429+
if (hSocket == INVALID_SOCKET) {
430+
return nullptr;
431+
}
423432
connected = ConnectSocketDirectly(addrConnect, hSocket, nConnectTimeout);
433+
}
424434
if (!proxyConnectionFailed) {
425435
// If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to
426436
// the proxy, mark this as an attempt.
427437
addrman.Attempt(addrConnect, fCountFailure);
428438
}
429439
} else if (pszDest && GetNameProxy(proxy)) {
440+
hSocket = CreateSocket(proxy.proxy);
441+
if (hSocket == INVALID_SOCKET) {
442+
return nullptr;
443+
}
430444
std::string host;
431445
int port = default_port;
432446
SplitHostPort(std::string(pszDest), port, host);

src/netbase.cpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -452,20 +452,18 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
452452
return true;
453453
}
454454

455-
bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocketRet, int nTimeout)
455+
SOCKET CreateSocket(const CService &addrConnect)
456456
{
457-
hSocketRet = INVALID_SOCKET;
458-
459457
struct sockaddr_storage sockaddr;
460458
socklen_t len = sizeof(sockaddr);
461459
if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
462-
LogPrintf("Cannot connect to %s: unsupported network\n", addrConnect.ToString());
463-
return false;
460+
LogPrintf("Cannot create socket for %s: unsupported network\n", addrConnect.ToString());
461+
return INVALID_SOCKET;
464462
}
465463

466464
SOCKET hSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP);
467465
if (hSocket == INVALID_SOCKET)
468-
return false;
466+
return INVALID_SOCKET;
469467

470468
#ifdef SO_NOSIGPIPE
471469
int set = 1;
@@ -479,9 +477,24 @@ bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocketRet, int
479477
// Set to non-blocking
480478
if (!SetSocketNonBlocking(hSocket, true)) {
481479
CloseSocket(hSocket);
482-
return error("ConnectSocketDirectly: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError()));
480+
LogPrintf("ConnectSocketDirectly: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError()));
483481
}
482+
return hSocket;
483+
}
484484

485+
bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocket, int nTimeout)
486+
{
487+
struct sockaddr_storage sockaddr;
488+
socklen_t len = sizeof(sockaddr);
489+
if (hSocket == INVALID_SOCKET) {
490+
LogPrintf("Cannot connect to %s: invalid socket\n", addrConnect.ToString());
491+
return false;
492+
}
493+
if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
494+
LogPrintf("Cannot connect to %s: unsupported network\n", addrConnect.ToString());
495+
CloseSocket(hSocket);
496+
return false;
497+
}
485498
if (connect(hSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR)
486499
{
487500
int nErr = WSAGetLastError();
@@ -534,8 +547,6 @@ bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocketRet, int
534547
return false;
535548
}
536549
}
537-
538-
hSocketRet = hSocket;
539550
return true;
540551
}
541552

@@ -587,9 +598,8 @@ bool IsProxy(const CNetAddr &addr) {
587598
return false;
588599
}
589600

590-
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, SOCKET& hSocketRet, int nTimeout, bool *outProxyConnectionFailed)
601+
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, SOCKET& hSocket, int nTimeout, bool *outProxyConnectionFailed)
591602
{
592-
SOCKET hSocket = INVALID_SOCKET;
593603
// first connect to proxy server
594604
if (!ConnectSocketDirectly(proxy.proxy, hSocket, nTimeout)) {
595605
if (outProxyConnectionFailed)
@@ -601,14 +611,16 @@ bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int
601611
ProxyCredentials random_auth;
602612
static std::atomic_int counter(0);
603613
random_auth.username = random_auth.password = strprintf("%i", counter++);
604-
if (!Socks5(strDest, (unsigned short)port, &random_auth, hSocket))
614+
if (!Socks5(strDest, (unsigned short)port, &random_auth, hSocket)) {
615+
CloseSocket(hSocket);
605616
return false;
617+
}
606618
} else {
607-
if (!Socks5(strDest, (unsigned short)port, 0, hSocket))
619+
if (!Socks5(strDest, (unsigned short)port, 0, hSocket)) {
620+
CloseSocket(hSocket);
608621
return false;
622+
}
609623
}
610-
611-
hSocketRet = hSocket;
612624
return true;
613625
}
614626
bool LookupSubNet(const char* pszName, CSubNet& ret)

src/netbase.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ bool Lookup(const char *pszName, CService& addr, int portDefault, bool fAllowLoo
5151
bool Lookup(const char *pszName, std::vector<CService>& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions);
5252
CService LookupNumeric(const char *pszName, int portDefault = 0);
5353
bool LookupSubNet(const char *pszName, CSubNet& subnet);
54+
SOCKET CreateSocket(const CService &addrConnect);
5455
bool ConnectSocketDirectly(const CService &addrConnect, SOCKET& hSocketRet, int nTimeout);
5556
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, SOCKET& hSocketRet, int nTimeout, bool *outProxyConnectionFailed);
5657
/** Return readable error string for a network error code */

0 commit comments

Comments
 (0)