Skip to content

Commit 508044c

Browse files
committed
merge bitcoin#21879: wrap accept() and extend usage of Sock
1 parent 2f93ee4 commit 508044c

File tree

7 files changed

+102
-30
lines changed

7 files changed

+102
-30
lines changed

src/net.cpp

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,9 +1220,10 @@ bool CConnman::AttemptToEvictConnection()
12201220
void CConnman::AcceptConnection(const ListenSocket& hListenSocket, CMasternodeSync& mn_sync) {
12211221
struct sockaddr_storage sockaddr;
12221222
socklen_t len = sizeof(sockaddr);
1223-
SOCKET hSocket = accept(hListenSocket.socket, (struct sockaddr*)&sockaddr, &len);
1223+
auto sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len);
12241224
CAddress addr;
1225-
if (hSocket == INVALID_SOCKET) {
1225+
1226+
if (!sock) {
12261227
const int nErr = WSAGetLastError();
12271228
if (nErr != WSAEWOULDBLOCK) {
12281229
LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr));
@@ -1236,15 +1237,15 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket, CMasternodeSy
12361237
addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE};
12371238
}
12381239

1239-
const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(hSocket)), NODE_NONE};
1240+
const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(sock->Get())), NODE_NONE};
12401241

12411242
NetPermissionFlags permissionFlags = NetPermissionFlags::None;
12421243
hListenSocket.AddSocketPermissionFlags(permissionFlags);
12431244

1244-
CreateNodeFromAcceptedSocket(hSocket, permissionFlags, addr_bind, addr, mn_sync);
1245+
CreateNodeFromAcceptedSocket(std::move(sock), permissionFlags, addr_bind, addr, mn_sync);
12451246
}
12461247

1247-
void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
1248+
void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,
12481249
NetPermissionFlags permissionFlags,
12491250
const CAddress& addr_bind,
12501251
const CAddress& addr,
@@ -1287,27 +1288,24 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
12871288

12881289
if (!fNetworkActive) {
12891290
LogPrint(BCLog::NET_NETCONN, "%s: not accepting new connections\n", strDropped);
1290-
CloseSocket(hSocket);
12911291
return;
12921292
}
12931293

1294-
if (!IsSelectableSocket(hSocket))
1294+
if (!IsSelectableSocket(sock->Get()))
12951295
{
12961296
LogPrintf("%s: non-selectable socket\n", strDropped);
1297-
CloseSocket(hSocket);
12981297
return;
12991298
}
13001299

13011300
// According to the internet TCP_NODELAY is not carried into accepted sockets
13021301
// on all platforms. Set it again here just to be sure.
1303-
SetSocketNoDelay(hSocket);
1302+
SetSocketNoDelay(sock->Get());
13041303

13051304
// Don't accept connections from banned peers.
13061305
bool banned = m_banman && m_banman->IsBanned(addr);
13071306
if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && banned)
13081307
{
13091308
LogPrint(BCLog::NET, "%s (banned)\n", strDropped);
1310-
CloseSocket(hSocket);
13111309
return;
13121310
}
13131311

@@ -1316,7 +1314,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
13161314
if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && nInbound + 1 >= nMaxInbound && discouraged)
13171315
{
13181316
LogPrint(BCLog::NET, "connection from %s dropped (discouraged)\n", addr.ToString());
1319-
CloseSocket(hSocket);
13201317
return;
13211318
}
13221319

@@ -1330,7 +1327,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
13301327
if (!AttemptToEvictConnection()) {
13311328
// No connection to evict, disconnect the new connection
13321329
LogPrint(BCLog::NET, "failed to find an eviction candidate - connection dropped (full)\n");
1333-
CloseSocket(hSocket);
13341330
return;
13351331
}
13361332
nInbound--;
@@ -1339,7 +1335,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
13391335
// don't accept incoming connections until blockchain is synced
13401336
if (fMasternodeMode && !mn_sync.IsBlockchainSynced()) {
13411337
LogPrint(BCLog::NET, "AcceptConnection -- blockchain is not synced yet, skipping inbound connection attempt\n");
1342-
CloseSocket(hSocket);
13431338
return;
13441339
}
13451340

@@ -1352,7 +1347,7 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
13521347
}
13531348

13541349
const bool inbound_onion = std::find(m_onion_binds.begin(), m_onion_binds.end(), addr_bind) != m_onion_binds.end();
1355-
CNode* pnode = new CNode(id, nodeServices, hSocket, addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion);
1350+
CNode* pnode = new CNode(id, nodeServices, sock->Release(), addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion);
13561351
pnode->AddRef();
13571352
pnode->m_permissionFlags = permissionFlags;
13581353
// If this flag is present, the user probably expect that RPC and QT report it as whitelisted (backward compatibility)
@@ -1361,17 +1356,19 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
13611356
m_msgproc->InitializeNode(pnode);
13621357

13631358
if (fLogIPs) {
1364-
LogPrint(BCLog::NET_NETCONN, "connection from %s accepted, sock=%d, peer=%d\n", addr.ToString(), hSocket, pnode->GetId());
1359+
LogPrint(BCLog::NET_NETCONN, "connection from %s accepted, sock=%d, peer=%d\n", addr.ToString(), sock->Get(), pnode->GetId());
13651360
} else {
1366-
LogPrint(BCLog::NET_NETCONN, "connection accepted, sock=%d, peer=%d\n", hSocket, pnode->GetId());
1361+
LogPrint(BCLog::NET_NETCONN, "connection accepted, sock=%d, peer=%d\n", sock->Get(), pnode->GetId());
13671362
}
13681363

13691364
{
13701365
LOCK(m_nodes_mutex);
13711366
m_nodes.push_back(pnode);
1372-
WITH_LOCK(cs_mapSocketToNode, mapSocketToNode.emplace(hSocket, pnode));
1367+
}
1368+
{
1369+
LOCK(pnode->cs_hSocket);
1370+
WITH_LOCK(cs_mapSocketToNode, mapSocketToNode.emplace(pnode->hSocket, pnode));
13731371
if (m_edge_trig_events) {
1374-
LOCK(pnode->cs_hSocket);
13751372
if (!m_edge_trig_events->RegisterEvents(pnode->hSocket)) {
13761373
LogPrint(BCLog::NET, "EdgeTriggeredEvents::RegisterEvents() failed\n");
13771374
}
@@ -1656,7 +1653,7 @@ bool CConnman::GenerateSelectSet(const std::vector<CNode*>& nodes,
16561653
std::set<SOCKET>& error_set)
16571654
{
16581655
for (const ListenSocket& hListenSocket : vhListenSocket) {
1659-
recv_set.insert(hListenSocket.socket);
1656+
recv_set.insert(hListenSocket.sock->Get());
16601657
}
16611658

16621659
for (CNode* pnode : nodes)
@@ -2128,7 +2125,7 @@ void CConnman::SocketHandlerListening(const std::set<SOCKET>& recv_set, CMastern
21282125
if (interruptNet) {
21292126
return;
21302127
}
2131-
if (recv_set.count(listen_socket.socket) > 0) {
2128+
if (recv_set.count(listen_socket.sock->Get()) > 0) {
21322129
AcceptConnection(listen_socket, mn_sync);
21332130
}
21342131
}
@@ -3168,7 +3165,7 @@ void CConnman::ThreadI2PAcceptIncoming(CMasternodeSync& mn_sync)
31683165
continue;
31693166
}
31703167

3171-
CreateNodeFromAcceptedSocket(conn.sock->Release(), NetPermissionFlags::None,
3168+
CreateNodeFromAcceptedSocket(std::move(conn.sock), NetPermissionFlags::None,
31723169
CAddress{conn.me, NODE_NONE}, CAddress{conn.peer, NODE_NONE}, mn_sync);
31733170
}
31743171
}
@@ -3235,7 +3232,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
32353232
return false;
32363233
}
32373234

3238-
vhListenSocket.push_back(ListenSocket(sock->Release(), permissions));
3235+
vhListenSocket.emplace_back(std::move(sock), permissions);
32393236

32403237
return true;
32413238
}
@@ -3582,12 +3579,10 @@ void CConnman::StopNodes()
35823579
pnode->CloseSocketDisconnect(this);
35833580
}
35843581
for (ListenSocket& hListenSocket : vhListenSocket) {
3585-
if (hListenSocket.socket != INVALID_SOCKET) {
3586-
if (m_edge_trig_events && !m_edge_trig_events->RemoveSocket(hListenSocket.socket)) {
3582+
if (hListenSocket.sock->Get() != INVALID_SOCKET) {
3583+
if (m_edge_trig_events && !m_edge_trig_events->RemoveSocket(hListenSocket.sock->Get())) {
35873584
LogPrintf("EdgeTriggeredEvents::RemoveSocket() failed\n");
35883585
}
3589-
if (!CloseSocket(hListenSocket.socket))
3590-
LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError()));
35913586
}
35923587
}
35933588

src/net.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <uint256.h>
3030
#include <util/check.h>
3131
#include <util/edge.h>
32+
#include <util/sock.h>
3233
#include <util/system.h>
3334
#include <util/wpipe.h>
3435
#include <consensus/params.h>
@@ -1221,9 +1222,13 @@ friend class CNode;
12211222
private:
12221223
struct ListenSocket {
12231224
public:
1224-
SOCKET socket;
1225+
std::shared_ptr<Sock> sock;
12251226
inline void AddSocketPermissionFlags(NetPermissionFlags& flags) const { NetPermissions::AddFlag(flags, m_permissions); }
1226-
ListenSocket(SOCKET socket_, NetPermissionFlags permissions_) : socket(socket_), m_permissions(permissions_) {}
1227+
ListenSocket(std::shared_ptr<Sock> sock_, NetPermissionFlags permissions_)
1228+
: sock{sock_}, m_permissions{permissions_}
1229+
{
1230+
}
1231+
12271232
private:
12281233
NetPermissionFlags m_permissions;
12291234
};
@@ -1251,12 +1256,12 @@ friend class CNode;
12511256
/**
12521257
* Create a `CNode` object from a socket that has just been accepted and add the node to
12531258
* the `m_nodes` member.
1254-
* @param[in] hSocket Connected socket to communicate with the peer.
1259+
* @param[in] sock Connected socket to communicate with the peer.
12551260
* @param[in] permissionFlags The peer's permissions.
12561261
* @param[in] addr_bind The address and port at our side of the connection.
12571262
* @param[in] addr The address and port at the peer's side of the connection.
12581263
*/
1259-
void CreateNodeFromAcceptedSocket(SOCKET hSocket,
1264+
void CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,
12601265
NetPermissionFlags permissionFlags,
12611266
const CAddress& addr_bind,
12621267
const CAddress& addr,

src/test/fuzz/util.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <util/time.h>
1111
#include <version.h>
1212

13+
#include <memory>
14+
1315
FuzzedSock::FuzzedSock(FuzzedDataProvider& fuzzed_data_provider)
1416
: m_fuzzed_data_provider{fuzzed_data_provider}
1517
{
@@ -155,6 +157,20 @@ int FuzzedSock::Connect(const sockaddr*, socklen_t) const
155157
return 0;
156158
}
157159

160+
std::unique_ptr<Sock> FuzzedSock::Accept(sockaddr* addr, socklen_t* addr_len) const
161+
{
162+
constexpr std::array accept_errnos{
163+
ECONNABORTED,
164+
EINTR,
165+
ENOMEM,
166+
};
167+
if (m_fuzzed_data_provider.ConsumeBool()) {
168+
SetFuzzedErrNo(m_fuzzed_data_provider, accept_errnos);
169+
return std::unique_ptr<FuzzedSock>();
170+
}
171+
return std::make_unique<FuzzedSock>(m_fuzzed_data_provider);
172+
}
173+
158174
int FuzzedSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
159175
{
160176
constexpr std::array getsockopt_errnos{

src/test/fuzz/util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,8 @@ class FuzzedSock : public Sock
560560

561561
int Connect(const sockaddr*, socklen_t) const override;
562562

563+
std::unique_ptr<Sock> Accept(sockaddr* addr, socklen_t* addr_len) const override;
564+
563565
int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override;
564566

565567
bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override;

src/test/util/net.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <array>
1414
#include <cassert>
1515
#include <cstring>
16+
#include <memory>
1617
#include <string>
1718

1819
struct ConnmanTestMsg : public CConnman {
@@ -126,6 +127,23 @@ class StaticContentsSock : public Sock
126127

127128
int Connect(const sockaddr*, socklen_t) const override { return 0; }
128129

130+
std::unique_ptr<Sock> Accept(sockaddr* addr, socklen_t* addr_len) const override
131+
{
132+
if (addr != nullptr) {
133+
// Pretend all connections come from 5.5.5.5:6789
134+
memset(addr, 0x00, *addr_len);
135+
const socklen_t write_len = static_cast<socklen_t>(sizeof(sockaddr_in));
136+
if (*addr_len >= write_len) {
137+
*addr_len = write_len;
138+
sockaddr_in* addr_in = reinterpret_cast<sockaddr_in*>(addr);
139+
addr_in->sin_family = AF_INET;
140+
memset(&addr_in->sin_addr, 0x05, sizeof(addr_in->sin_addr));
141+
addr_in->sin_port = htons(6789);
142+
}
143+
}
144+
return std::make_unique<StaticContentsSock>("");
145+
};
146+
129147
int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override
130148
{
131149
std::memset(opt_val, 0x0, *opt_len);

src/util/sock.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <util/system.h>
1111
#include <util/time.h>
1212

13+
#include <memory>
1314
#include <stdexcept>
1415
#include <string>
1516

@@ -73,6 +74,32 @@ int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
7374
return connect(m_socket, addr, addr_len);
7475
}
7576

77+
std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
78+
{
79+
#ifdef WIN32
80+
static constexpr auto ERR = INVALID_SOCKET;
81+
#else
82+
static constexpr auto ERR = SOCKET_ERROR;
83+
#endif
84+
85+
std::unique_ptr<Sock> sock;
86+
87+
const auto socket = accept(m_socket, addr, addr_len);
88+
if (socket != ERR) {
89+
try {
90+
sock = std::make_unique<Sock>(socket);
91+
} catch (const std::exception&) {
92+
#ifdef WIN32
93+
closesocket(socket);
94+
#else
95+
close(socket);
96+
#endif
97+
}
98+
}
99+
100+
return sock;
101+
}
102+
76103
int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
77104
{
78105
return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);

src/util/sock.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <util/time.h>
1111

1212
#include <chrono>
13+
#include <memory>
1314
#include <string>
1415

1516
/**
@@ -144,6 +145,14 @@ class Sock
144145
*/
145146
[[nodiscard]] virtual int Connect(const sockaddr* addr, socklen_t addr_len) const;
146147

148+
/**
149+
* accept(2) wrapper. Equivalent to `std::make_unique<Sock>(accept(this->Get(), addr, addr_len))`.
150+
* Code that uses this wrapper can be unit tested if this method is overridden by a mock Sock
151+
* implementation.
152+
* The returned unique_ptr is empty if `accept()` failed in which case errno will be set.
153+
*/
154+
[[nodiscard]] virtual std::unique_ptr<Sock> Accept(sockaddr* addr, socklen_t* addr_len) const;
155+
147156
/**
148157
* getsockopt(2) wrapper. Equivalent to
149158
* `getsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this

0 commit comments

Comments
 (0)