Skip to content

Commit 04ae846

Browse files
net: use Sock in InterruptibleRecv() and Socks5()
Use the `Sock` class instead of `SOCKET` for `InterruptibleRecv()` and `Socks5()`. This way the `Socks5()` function can be tested by giving it a mocked instance of a socket. Co-authored-by: practicalswift <[email protected]>
1 parent ba9d732 commit 04ae846

File tree

3 files changed

+12
-27
lines changed

3 files changed

+12
-27
lines changed

src/net.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
440440
return nullptr;
441441
}
442442
connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(),
443-
sock->Get(), nConnectTimeout, proxyConnectionFailed);
443+
*sock, nConnectTimeout, proxyConnectionFailed);
444444
} else {
445445
// no proxy needed (none set for target network)
446446
sock = CreateSock(addrConnect);
@@ -464,7 +464,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
464464
int port = default_port;
465465
SplitHostPort(std::string(pszDest), port, host);
466466
bool proxyConnectionFailed;
467-
connected = ConnectThroughProxy(proxy, host, port, sock->Get(), nConnectTimeout,
467+
connected = ConnectThroughProxy(proxy, host, port, *sock, nConnectTimeout,
468468
proxyConnectionFailed);
469469
}
470470
if (!connected) {

src/netbase.cpp

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -343,15 +343,15 @@ enum class IntrRecvError {
343343
* Sockets can be made non-blocking with SetSocketNonBlocking(const
344344
* SOCKET&, bool).
345345
*/
346-
static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const SOCKET& hSocket)
346+
static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const Sock& hSocket)
347347
{
348348
int64_t curTime = GetTimeMillis();
349349
int64_t endTime = curTime + timeout;
350350
// Maximum time to wait for I/O readiness. It will take up until this time
351351
// (in millis) to break off in case of an interruption.
352352
const int64_t maxWait = 1000;
353353
while (len > 0 && curTime < endTime) {
354-
ssize_t ret = recv(hSocket, (char*)data, len, 0); // Optimistically try the recv first
354+
ssize_t ret = hSocket.Recv(data, len, 0); // Optimistically try the recv first
355355
if (ret > 0) {
356356
len -= ret;
357357
data += ret;
@@ -360,25 +360,10 @@ static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, c
360360
} else { // Other error or blocking
361361
int nErr = WSAGetLastError();
362362
if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL) {
363-
if (!IsSelectableSocket(hSocket)) {
364-
return IntrRecvError::NetworkError;
365-
}
366363
// Only wait at most maxWait milliseconds at a time, unless
367364
// we're approaching the end of the specified total timeout
368365
int timeout_ms = std::min(endTime - curTime, maxWait);
369-
#ifdef USE_POLL
370-
struct pollfd pollfd = {};
371-
pollfd.fd = hSocket;
372-
pollfd.events = POLLIN;
373-
int nRet = poll(&pollfd, 1, timeout_ms);
374-
#else
375-
struct timeval tval = MillisToTimeval(timeout_ms);
376-
fd_set fdset;
377-
FD_ZERO(&fdset);
378-
FD_SET(hSocket, &fdset);
379-
int nRet = select(hSocket + 1, &fdset, nullptr, nullptr, &tval);
380-
#endif
381-
if (nRet == SOCKET_ERROR) {
366+
if (!hSocket.Wait(std::chrono::milliseconds{timeout_ms}, Sock::RECV)) {
382367
return IntrRecvError::NetworkError;
383368
}
384369
} else {
@@ -442,7 +427,7 @@ static std::string Socks5ErrorString(uint8_t err)
442427
* @see <a href="https://www.ietf.org/rfc/rfc1928.txt">RFC1928: SOCKS Protocol
443428
* Version 5</a>
444429
*/
445-
static bool Socks5(const std::string& strDest, int port, const ProxyCredentials *auth, const SOCKET& hSocket)
430+
static bool Socks5(const std::string& strDest, int port, const ProxyCredentials* auth, const Sock& hSocket)
446431
{
447432
IntrRecvError recvr;
448433
LogPrint(BCLog::NET, "SOCKS5 connecting %s\n", strDest);
@@ -460,7 +445,7 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
460445
vSocks5Init.push_back(0x01); // 1 method identifier follows...
461446
vSocks5Init.push_back(SOCKS5Method::NOAUTH);
462447
}
463-
ssize_t ret = send(hSocket, (const char*)vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL);
448+
ssize_t ret = hSocket.Send(vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL);
464449
if (ret != (ssize_t)vSocks5Init.size()) {
465450
return error("Error sending to proxy");
466451
}
@@ -482,7 +467,7 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
482467
vAuth.insert(vAuth.end(), auth->username.begin(), auth->username.end());
483468
vAuth.push_back(auth->password.size());
484469
vAuth.insert(vAuth.end(), auth->password.begin(), auth->password.end());
485-
ret = send(hSocket, (const char*)vAuth.data(), vAuth.size(), MSG_NOSIGNAL);
470+
ret = hSocket.Send(vAuth.data(), vAuth.size(), MSG_NOSIGNAL);
486471
if (ret != (ssize_t)vAuth.size()) {
487472
return error("Error sending authentication to proxy");
488473
}
@@ -508,7 +493,7 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
508493
vSocks5.insert(vSocks5.end(), strDest.begin(), strDest.end());
509494
vSocks5.push_back((port >> 8) & 0xFF);
510495
vSocks5.push_back((port >> 0) & 0xFF);
511-
ret = send(hSocket, (const char*)vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL);
496+
ret = hSocket.Send(vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL);
512497
if (ret != (ssize_t)vSocks5.size()) {
513498
return error("Error sending to proxy");
514499
}
@@ -787,10 +772,10 @@ bool IsProxy(const CNetAddr &addr) {
787772
*
788773
* @returns Whether or not the operation succeeded.
789774
*/
790-
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocket, int nTimeout, bool& outProxyConnectionFailed)
775+
bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& hSocket, int nTimeout, bool& outProxyConnectionFailed)
791776
{
792777
// first connect to proxy server
793-
if (!ConnectSocketDirectly(proxy.proxy, hSocket, nTimeout, true)) {
778+
if (!ConnectSocketDirectly(proxy.proxy, hSocket.Get(), nTimeout, true)) {
794779
outProxyConnectionFailed = true;
795780
return false;
796781
}

src/netbase.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ std::unique_ptr<Sock> CreateSockTCP(const CService& address_family);
6868
extern std::function<std::unique_ptr<Sock>(const CService&)> CreateSock;
6969

7070
bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection);
71-
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool& outProxyConnectionFailed);
71+
bool ConnectThroughProxy(const proxyType& proxy, const std::string& strDest, int port, const Sock& hSocketRet, int nTimeout, bool& outProxyConnectionFailed);
7272
/** Disable or enable blocking-mode for a socket */
7373
bool SetSocketNonBlocking(const SOCKET& hSocket, bool fNonBlocking);
7474
/** Set the TCP_NODELAY flag on a socket */

0 commit comments

Comments
 (0)