Skip to content

Commit ba9d732

Browse files
committed
net: add RAII socket and use it instead of bare SOCKET
Introduce a class to manage the lifetime of a socket - when the object that contains the socket goes out of scope, the underlying socket will be closed. In addition, the new `Sock` class has a `Send()`, `Recv()` and `Wait()` methods that can be overridden by unit tests to mock the socket operations. The `Wait()` method also hides the `#ifdef USE_POLL poll() #else select() #endif` technique from higher level code.
1 parent dec9b5e commit ba9d732

File tree

7 files changed

+250
-41
lines changed

7 files changed

+250
-41
lines changed

src/net.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -429,51 +429,53 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
429429

430430
// Connect
431431
bool connected = false;
432-
SOCKET hSocket = INVALID_SOCKET;
432+
std::unique_ptr<Sock> sock;
433433
proxyType proxy;
434434
if (addrConnect.IsValid()) {
435435
bool proxyConnectionFailed = false;
436436

437437
if (GetProxy(addrConnect.GetNetwork(), proxy)) {
438-
hSocket = CreateSocket(proxy.proxy);
439-
if (hSocket == INVALID_SOCKET) {
438+
sock = CreateSock(proxy.proxy);
439+
if (!sock) {
440440
return nullptr;
441441
}
442-
connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), hSocket, nConnectTimeout, proxyConnectionFailed);
442+
connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(),
443+
sock->Get(), nConnectTimeout, proxyConnectionFailed);
443444
} else {
444445
// no proxy needed (none set for target network)
445-
hSocket = CreateSocket(addrConnect);
446-
if (hSocket == INVALID_SOCKET) {
446+
sock = CreateSock(addrConnect);
447+
if (!sock) {
447448
return nullptr;
448449
}
449-
connected = ConnectSocketDirectly(addrConnect, hSocket, nConnectTimeout, conn_type == ConnectionType::MANUAL);
450+
connected = ConnectSocketDirectly(addrConnect, sock->Get(), nConnectTimeout,
451+
conn_type == ConnectionType::MANUAL);
450452
}
451453
if (!proxyConnectionFailed) {
452454
// If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to
453455
// the proxy, mark this as an attempt.
454456
addrman.Attempt(addrConnect, fCountFailure);
455457
}
456458
} else if (pszDest && GetNameProxy(proxy)) {
457-
hSocket = CreateSocket(proxy.proxy);
458-
if (hSocket == INVALID_SOCKET) {
459+
sock = CreateSock(proxy.proxy);
460+
if (!sock) {
459461
return nullptr;
460462
}
461463
std::string host;
462464
int port = default_port;
463465
SplitHostPort(std::string(pszDest), port, host);
464466
bool proxyConnectionFailed;
465-
connected = ConnectThroughProxy(proxy, host, port, hSocket, nConnectTimeout, proxyConnectionFailed);
467+
connected = ConnectThroughProxy(proxy, host, port, sock->Get(), nConnectTimeout,
468+
proxyConnectionFailed);
466469
}
467470
if (!connected) {
468-
CloseSocket(hSocket);
469471
return nullptr;
470472
}
471473

472474
// Add node
473475
NodeId id = GetNewNodeId();
474476
uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize();
475-
CAddress addr_bind = GetBindAddress(hSocket);
476-
CNode* pnode = new CNode(id, nLocalServices, hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type);
477+
CAddress addr_bind = GetBindAddress(sock->Get());
478+
CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type);
477479
pnode->AddRef();
478480

479481
// We're making a new connection, harvest entropy from the time (and our peer count)
@@ -2177,53 +2179,50 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
21772179
return false;
21782180
}
21792181

2180-
SOCKET hListenSocket = CreateSocket(addrBind);
2181-
if (hListenSocket == INVALID_SOCKET)
2182-
{
2182+
std::unique_ptr<Sock> sock = CreateSock(addrBind);
2183+
if (!sock) {
21832184
strError = strprintf(Untranslated("Error: Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError()));
21842185
LogPrintf("%s\n", strError.original);
21852186
return false;
21862187
}
21872188

21882189
// Allow binding if the port is still in TIME_WAIT state after
21892190
// the program was closed and restarted.
2190-
setsockopt(hListenSocket, SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int));
2191+
setsockopt(sock->Get(), SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int));
21912192

21922193
// some systems don't have IPV6_V6ONLY but are always v6only; others do have the option
21932194
// and enable it by default or not. Try to enable it, if possible.
21942195
if (addrBind.IsIPv6()) {
21952196
#ifdef IPV6_V6ONLY
2196-
setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int));
2197+
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int));
21972198
#endif
21982199
#ifdef WIN32
21992200
int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED;
2200-
setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int));
2201+
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int));
22012202
#endif
22022203
}
22032204

2204-
if (::bind(hListenSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR)
2205+
if (::bind(sock->Get(), (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR)
22052206
{
22062207
int nErr = WSAGetLastError();
22072208
if (nErr == WSAEADDRINUSE)
22082209
strError = strprintf(_("Unable to bind to %s on this computer. %s is probably already running."), addrBind.ToString(), PACKAGE_NAME);
22092210
else
22102211
strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToString(), NetworkErrorString(nErr));
22112212
LogPrintf("%s\n", strError.original);
2212-
CloseSocket(hListenSocket);
22132213
return false;
22142214
}
22152215
LogPrintf("Bound to %s\n", addrBind.ToString());
22162216

22172217
// Listen for incoming connections
2218-
if (listen(hListenSocket, SOMAXCONN) == SOCKET_ERROR)
2218+
if (listen(sock->Get(), SOMAXCONN) == SOCKET_ERROR)
22192219
{
22202220
strError = strprintf(_("Error: Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError()));
22212221
LogPrintf("%s\n", strError.original);
2222-
CloseSocket(hListenSocket);
22232222
return false;
22242223
}
22252224

2226-
vhListenSocket.push_back(ListenSocket(hListenSocket, permissions));
2225+
vhListenSocket.push_back(ListenSocket(sock->Release(), permissions));
22272226
return true;
22282227
}
22292228

src/netbase.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
#include <atomic>
1717
#include <cstdint>
18+
#include <functional>
1819
#include <limits>
20+
#include <memory>
1921

2022
#ifndef WIN32
2123
#include <fcntl.h>
@@ -559,34 +561,28 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
559561
return true;
560562
}
561563

562-
/**
563-
* Try to create a socket file descriptor with specific properties in the
564-
* communications domain (address family) of the specified service.
565-
*
566-
* For details on the desired properties, see the inline comments in the source
567-
* code.
568-
*/
569-
SOCKET CreateSocket(const CService &addrConnect)
564+
std::unique_ptr<Sock> CreateSockTCP(const CService& address_family)
570565
{
571566
// Create a sockaddr from the specified service.
572567
struct sockaddr_storage sockaddr;
573568
socklen_t len = sizeof(sockaddr);
574-
if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
575-
LogPrintf("Cannot create socket for %s: unsupported network\n", addrConnect.ToString());
576-
return INVALID_SOCKET;
569+
if (!address_family.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
570+
LogPrintf("Cannot create socket for %s: unsupported network\n", address_family.ToString());
571+
return nullptr;
577572
}
578573

579574
// Create a TCP socket in the address family of the specified service.
580575
SOCKET hSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP);
581-
if (hSocket == INVALID_SOCKET)
582-
return INVALID_SOCKET;
576+
if (hSocket == INVALID_SOCKET) {
577+
return nullptr;
578+
}
583579

584580
// Ensure that waiting for I/O on this socket won't result in undefined
585581
// behavior.
586582
if (!IsSelectableSocket(hSocket)) {
587583
CloseSocket(hSocket);
588584
LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n");
589-
return INVALID_SOCKET;
585+
return nullptr;
590586
}
591587

592588
#ifdef SO_NOSIGPIPE
@@ -602,11 +598,14 @@ SOCKET CreateSocket(const CService &addrConnect)
602598
// Set the non-blocking option on the socket.
603599
if (!SetSocketNonBlocking(hSocket, true)) {
604600
CloseSocket(hSocket);
605-
LogPrintf("CreateSocket: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError()));
601+
LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError()));
602+
return nullptr;
606603
}
607-
return hSocket;
604+
return std::make_unique<Sock>(hSocket);
608605
}
609606

607+
std::function<std::unique_ptr<Sock>(const CService&)> CreateSock = CreateSockTCP;
608+
610609
template<typename... Args>
611610
static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) {
612611
std::string error_message = tfm::format(fmt, args...);

src/netbase.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
#include <compat.h>
1313
#include <netaddress.h>
1414
#include <serialize.h>
15+
#include <util/sock.h>
1516

17+
#include <functional>
18+
#include <memory>
1619
#include <stdint.h>
1720
#include <string>
1821
#include <vector>
@@ -51,7 +54,19 @@ bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllo
5154
bool Lookup(const std::string& name, std::vector<CService>& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions);
5255
CService LookupNumeric(const std::string& name, int portDefault = 0);
5356
bool LookupSubNet(const std::string& strSubnet, CSubNet& subnet);
54-
SOCKET CreateSocket(const CService &addrConnect);
57+
58+
/**
59+
* Create a TCP socket in the given address family.
60+
* @param[in] address_family The socket is created in the same address family as this address.
61+
* @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure
62+
*/
63+
std::unique_ptr<Sock> CreateSockTCP(const CService& address_family);
64+
65+
/**
66+
* Socket factory. Defaults to `CreateSockTCP()`, but can be overridden by unit tests.
67+
*/
68+
extern std::function<std::unique_ptr<Sock>(const CService&)> CreateSock;
69+
5570
bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection);
5671
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool& outProxyConnectionFailed);
5772
/** Disable or enable blocking-mode for a socket */

src/util/sock.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,97 @@
66
#include <logging.h>
77
#include <tinyformat.h>
88
#include <util/sock.h>
9+
#include <util/system.h>
10+
#include <util/time.h>
911

1012
#include <codecvt>
1113
#include <cwchar>
1214
#include <locale>
1315
#include <string>
1416

17+
#ifdef USE_POLL
18+
#include <poll.h>
19+
#endif
20+
21+
Sock::Sock() : m_socket(INVALID_SOCKET) {}
22+
23+
Sock::Sock(SOCKET s) : m_socket(s) {}
24+
25+
Sock::Sock(Sock&& other)
26+
{
27+
m_socket = other.m_socket;
28+
other.m_socket = INVALID_SOCKET;
29+
}
30+
31+
Sock::~Sock() { Reset(); }
32+
33+
Sock& Sock::operator=(Sock&& other)
34+
{
35+
Reset();
36+
m_socket = other.m_socket;
37+
other.m_socket = INVALID_SOCKET;
38+
return *this;
39+
}
40+
41+
SOCKET Sock::Get() const { return m_socket; }
42+
43+
SOCKET Sock::Release()
44+
{
45+
const SOCKET s = m_socket;
46+
m_socket = INVALID_SOCKET;
47+
return s;
48+
}
49+
50+
void Sock::Reset() { CloseSocket(m_socket); }
51+
52+
ssize_t Sock::Send(const void* data, size_t len, int flags) const
53+
{
54+
return send(m_socket, static_cast<const char*>(data), len, flags);
55+
}
56+
57+
ssize_t Sock::Recv(void* buf, size_t len, int flags) const
58+
{
59+
return recv(m_socket, static_cast<char*>(buf), len, flags);
60+
}
61+
62+
bool Sock::Wait(std::chrono::milliseconds timeout, Event requested) const
63+
{
64+
#ifdef USE_POLL
65+
pollfd fd;
66+
fd.fd = m_socket;
67+
fd.events = 0;
68+
if (requested & RECV) {
69+
fd.events |= POLLIN;
70+
}
71+
if (requested & SEND) {
72+
fd.events |= POLLOUT;
73+
}
74+
75+
return poll(&fd, 1, count_milliseconds(timeout)) != SOCKET_ERROR;
76+
#else
77+
if (!IsSelectableSocket(m_socket)) {
78+
return false;
79+
}
80+
81+
fd_set fdset_recv;
82+
fd_set fdset_send;
83+
FD_ZERO(&fdset_recv);
84+
FD_ZERO(&fdset_send);
85+
86+
if (requested & RECV) {
87+
FD_SET(m_socket, &fdset_recv);
88+
}
89+
90+
if (requested & SEND) {
91+
FD_SET(m_socket, &fdset_send);
92+
}
93+
94+
timeval timeout_struct = MillisToTimeval(timeout);
95+
96+
return select(m_socket + 1, &fdset_recv, &fdset_send, nullptr, &timeout_struct) != SOCKET_ERROR;
97+
#endif /* USE_POLL */
98+
}
99+
15100
#ifdef WIN32
16101
std::string NetworkErrorString(int err)
17102
{

0 commit comments

Comments
 (0)