Skip to content

Commit c41a116

Browse files
committed
net: use Sock in CNode
Change `CNode` to use a pointer to `Sock` instead of a bare `SOCKET`. This will help mocking / testing / fuzzing more code.
1 parent c5dd72e commit c41a116

File tree

5 files changed

+166
-69
lines changed

5 files changed

+166
-69
lines changed

src/net.cpp

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
505505
if (!addr_bind.IsValid()) {
506506
addr_bind = GetBindAddress(sock->Get());
507507
}
508-
CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type, /* inbound_onion */ false);
508+
CNode* pnode = new CNode(id, nLocalServices, std::move(sock), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type, /* inbound_onion */ false);
509509
pnode->AddRef();
510510

511511
// We're making a new connection, harvest entropy from the time (and our peer count)
@@ -518,10 +518,9 @@ void CNode::CloseSocketDisconnect()
518518
{
519519
fDisconnect = true;
520520
LOCK(cs_hSocket);
521-
if (hSocket != INVALID_SOCKET)
522-
{
521+
if (m_sock) {
523522
LogPrint(BCLog::NET, "disconnecting peer=%d\n", id);
524-
CloseSocket(hSocket);
523+
m_sock.reset();
525524
}
526525
}
527526

@@ -800,9 +799,10 @@ size_t CConnman::SocketSendData(CNode& node) const
800799
int nBytes = 0;
801800
{
802801
LOCK(node.cs_hSocket);
803-
if (node.hSocket == INVALID_SOCKET)
802+
if (!node.m_sock) {
804803
break;
805-
nBytes = send(node.hSocket, reinterpret_cast<const char*>(data.data()) + node.nSendOffset, data.size() - node.nSendOffset, MSG_NOSIGNAL | MSG_DONTWAIT);
804+
}
805+
nBytes = node.m_sock->Send(reinterpret_cast<const char*>(data.data()) + node.nSendOffset, data.size() - node.nSendOffset, MSG_NOSIGNAL | MSG_DONTWAIT);
806806
}
807807
if (nBytes > 0) {
808808
node.m_last_send = GetTime<std::chrono::seconds>();
@@ -1197,7 +1197,7 @@ void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,
11971197
}
11981198

11991199
const bool inbound_onion = std::find(m_onion_binds.begin(), m_onion_binds.end(), addr_bind) != m_onion_binds.end();
1200-
CNode* pnode = new CNode(id, nodeServices, sock->Release(), addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion);
1200+
CNode* pnode = new CNode(id, nodeServices, std::move(sock), addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion);
12011201
pnode->AddRef();
12021202
pnode->m_permissionFlags = permissionFlags;
12031203
pnode->m_prefer_evict = discouraged;
@@ -1382,16 +1382,17 @@ bool CConnman::GenerateSelectSet(const std::vector<CNode*>& nodes,
13821382
}
13831383

13841384
LOCK(pnode->cs_hSocket);
1385-
if (pnode->hSocket == INVALID_SOCKET)
1385+
if (!pnode->m_sock) {
13861386
continue;
1387+
}
13871388

1388-
error_set.insert(pnode->hSocket);
1389+
error_set.insert(pnode->m_sock->Get());
13891390
if (select_send) {
1390-
send_set.insert(pnode->hSocket);
1391+
send_set.insert(pnode->m_sock->Get());
13911392
continue;
13921393
}
13931394
if (select_recv) {
1394-
recv_set.insert(pnode->hSocket);
1395+
recv_set.insert(pnode->m_sock->Get());
13951396
}
13961397
}
13971398

@@ -1562,11 +1563,12 @@ void CConnman::SocketHandlerConnected(const std::vector<CNode*>& nodes,
15621563
bool errorSet = false;
15631564
{
15641565
LOCK(pnode->cs_hSocket);
1565-
if (pnode->hSocket == INVALID_SOCKET)
1566+
if (!pnode->m_sock) {
15661567
continue;
1567-
recvSet = recv_set.count(pnode->hSocket) > 0;
1568-
sendSet = send_set.count(pnode->hSocket) > 0;
1569-
errorSet = error_set.count(pnode->hSocket) > 0;
1568+
}
1569+
recvSet = recv_set.count(pnode->m_sock->Get()) > 0;
1570+
sendSet = send_set.count(pnode->m_sock->Get()) > 0;
1571+
errorSet = error_set.count(pnode->m_sock->Get()) > 0;
15701572
}
15711573
if (recvSet || errorSet)
15721574
{
@@ -1575,9 +1577,10 @@ void CConnman::SocketHandlerConnected(const std::vector<CNode*>& nodes,
15751577
int nBytes = 0;
15761578
{
15771579
LOCK(pnode->cs_hSocket);
1578-
if (pnode->hSocket == INVALID_SOCKET)
1580+
if (!pnode->m_sock) {
15791581
continue;
1580-
nBytes = recv(pnode->hSocket, (char*)pchBuf, sizeof(pchBuf), MSG_DONTWAIT);
1582+
}
1583+
nBytes = pnode->m_sock->Recv(pchBuf, sizeof(pchBuf), MSG_DONTWAIT);
15811584
}
15821585
if (nBytes > 0)
15831586
{
@@ -2962,8 +2965,9 @@ ServiceFlags CConnman::GetLocalServices() const
29622965

29632966
unsigned int CConnman::GetReceiveFloodSize() const { return nReceiveFloodSize; }
29642967

2965-
CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion)
2966-
: m_connected{GetTime<std::chrono::seconds>()},
2968+
CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, std::shared_ptr<Sock> sock, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion)
2969+
: m_sock{sock},
2970+
m_connected{GetTime<std::chrono::seconds>()},
29672971
addr(addrIn),
29682972
addrBind(addrBindIn),
29692973
m_addr_name{addrNameIn.empty() ? addr.ToStringIPPort() : addrNameIn},
@@ -2975,7 +2979,6 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const
29752979
nLocalServices(nLocalServicesIn)
29762980
{
29772981
if (inbound_onion) assert(conn_type_in == ConnectionType::INBOUND);
2978-
hSocket = hSocketIn;
29792982
if (conn_type_in != ConnectionType::BLOCK_RELAY) {
29802983
m_tx_relay = std::make_unique<TxRelay>();
29812984
}
@@ -2994,11 +2997,6 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const
29942997
m_serializer = std::make_unique<V1TransportSerializer>(V1TransportSerializer());
29952998
}
29962999

2997-
CNode::~CNode()
2998-
{
2999-
CloseSocket(hSocket);
3000-
}
3001-
30023000
bool CConnman::NodeFullyConnected(const CNode* pnode)
30033001
{
30043002
return pnode && pnode->fSuccessfullyConnected && !pnode->fDisconnect;

src/net.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,17 @@ class CNode
402402

403403
NetPermissionFlags m_permissionFlags{NetPermissionFlags::None};
404404
std::atomic<ServiceFlags> nServices{NODE_NONE};
405-
SOCKET hSocket GUARDED_BY(cs_hSocket);
405+
406+
/**
407+
* Socket used for communication with the node.
408+
* May not own a Sock object (after `CloseSocketDisconnect()` or during tests).
409+
* `shared_ptr` (instead of `unique_ptr`) is used to avoid premature close of
410+
* the underlying file descriptor by one thread while another thread is
411+
* poll(2)-ing it for activity.
412+
* @see https://github.com/bitcoin/bitcoin/issues/21744 for details.
413+
*/
414+
std::shared_ptr<Sock> m_sock GUARDED_BY(cs_hSocket);
415+
406416
/** Total size of all vSendMsg entries */
407417
size_t nSendSize GUARDED_BY(cs_vSend){0};
408418
/** Offset inside the first vSendMsg already sent */
@@ -578,8 +588,7 @@ class CNode
578588
* criterium in CConnman::AttemptToEvictConnection. */
579589
std::atomic<std::chrono::microseconds> m_min_ping_time{std::chrono::microseconds::max()};
580590

581-
CNode(NodeId id, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion);
582-
~CNode();
591+
CNode(NodeId id, ServiceFlags nLocalServicesIn, std::shared_ptr<Sock> sock, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion);
583592
CNode(const CNode&) = delete;
584593
CNode& operator=(const CNode&) = delete;
585594

src/test/denialofservice_tests.cpp

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,16 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction)
5959

6060
// Mock an outbound peer
6161
CAddress addr1(ip(0xa0b0c001), NODE_NONE);
62-
CNode dummyNode1(id++, ServiceFlags(NODE_NETWORK | NODE_WITNESS), INVALID_SOCKET, addr1, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, CAddress(), /*addrNameIn=*/"", ConnectionType::OUTBOUND_FULL_RELAY, /*inbound_onion=*/false);
62+
CNode dummyNode1{id++,
63+
ServiceFlags(NODE_NETWORK | NODE_WITNESS),
64+
/*sock=*/nullptr,
65+
addr1,
66+
/*nKeyedNetGroupIn=*/0,
67+
/*nLocalHostNonceIn=*/0,
68+
CAddress(),
69+
/*addrNameIn=*/"",
70+
ConnectionType::OUTBOUND_FULL_RELAY,
71+
/*inbound_onion=*/false};
6372
dummyNode1.SetCommonVersion(PROTOCOL_VERSION);
6473

6574
peerLogic->InitializeNode(&dummyNode1);
@@ -108,7 +117,16 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction)
108117
static void AddRandomOutboundPeer(std::vector<CNode*>& vNodes, PeerManager& peerLogic, ConnmanTestMsg& connman, ConnectionType connType)
109118
{
110119
CAddress addr(ip(g_insecure_rand_ctx.randbits(32)), NODE_NONE);
111-
vNodes.emplace_back(new CNode(id++, ServiceFlags(NODE_NETWORK | NODE_WITNESS), INVALID_SOCKET, addr, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, CAddress(), /*addrNameIn=*/"", connType, /*inbound_onion=*/false));
120+
vNodes.emplace_back(new CNode{id++,
121+
ServiceFlags(NODE_NETWORK | NODE_WITNESS),
122+
/*sock=*/nullptr,
123+
addr,
124+
/*nKeyedNetGroupIn=*/0,
125+
/*nLocalHostNonceIn=*/0,
126+
CAddress(),
127+
/*addrNameIn=*/"",
128+
connType,
129+
/*inbound_onion=*/false});
112130
CNode &node = *vNodes.back();
113131
node.SetCommonVersion(PROTOCOL_VERSION);
114132

@@ -279,9 +297,16 @@ BOOST_AUTO_TEST_CASE(peer_discouragement)
279297
std::array<CNode*, 3> nodes;
280298

281299
banman->ClearBanned();
282-
nodes[0] = new CNode{id++, NODE_NETWORK, INVALID_SOCKET, addr[0], /*nKeyedNetGroupIn=*/0,
283-
/*nLocalHostNonceIn=*/0, CAddress(), /*addrNameIn=*/"",
284-
ConnectionType::INBOUND, /*inbound_onion=*/false};
300+
nodes[0] = new CNode{id++,
301+
NODE_NETWORK,
302+
/*sock=*/nullptr,
303+
addr[0],
304+
/*nKeyedNetGroupIn=*/0,
305+
/*nLocalHostNonceIn=*/0,
306+
CAddress(),
307+
/*addrNameIn=*/"",
308+
ConnectionType::INBOUND,
309+
/*inbound_onion=*/false};
285310
nodes[0]->SetCommonVersion(PROTOCOL_VERSION);
286311
peerLogic->InitializeNode(nodes[0]);
287312
nodes[0]->fSuccessfullyConnected = true;
@@ -295,9 +320,16 @@ BOOST_AUTO_TEST_CASE(peer_discouragement)
295320
BOOST_CHECK(nodes[0]->fDisconnect);
296321
BOOST_CHECK(!banman->IsDiscouraged(other_addr)); // Different address, not discouraged
297322

298-
nodes[1] = new CNode{id++, NODE_NETWORK, INVALID_SOCKET, addr[1], /*nKeyedNetGroupIn=*/1,
299-
/*nLocalHostNonceIn=*/1, CAddress(), /*addrNameIn=*/"",
300-
ConnectionType::INBOUND, /*inbound_onion=*/false};
323+
nodes[1] = new CNode{id++,
324+
NODE_NETWORK,
325+
/*sock=*/nullptr,
326+
addr[1],
327+
/*nKeyedNetGroupIn=*/1,
328+
/*nLocalHostNonceIn=*/1,
329+
CAddress(),
330+
/*addrNameIn=*/"",
331+
ConnectionType::INBOUND,
332+
/*inbound_onion=*/false};
301333
nodes[1]->SetCommonVersion(PROTOCOL_VERSION);
302334
peerLogic->InitializeNode(nodes[1]);
303335
nodes[1]->fSuccessfullyConnected = true;
@@ -326,9 +358,16 @@ BOOST_AUTO_TEST_CASE(peer_discouragement)
326358

327359
// Make sure non-IP peers are discouraged and disconnected properly.
328360

329-
nodes[2] = new CNode{id++, NODE_NETWORK, INVALID_SOCKET, addr[2], /*nKeyedNetGroupIn=*/1,
330-
/*nLocalHostNonceIn=*/1, CAddress(), /*addrNameIn=*/"",
331-
ConnectionType::OUTBOUND_FULL_RELAY, /*inbound_onion=*/false};
361+
nodes[2] = new CNode{id++,
362+
NODE_NETWORK,
363+
/*sock=*/nullptr,
364+
addr[2],
365+
/*nKeyedNetGroupIn=*/1,
366+
/*nLocalHostNonceIn=*/1,
367+
CAddress(),
368+
/*addrNameIn=*/"",
369+
ConnectionType::OUTBOUND_FULL_RELAY,
370+
/*inbound_onion=*/false};
332371
nodes[2]->SetCommonVersion(PROTOCOL_VERSION);
333372
peerLogic->InitializeNode(nodes[2]);
334373
nodes[2]->fSuccessfullyConnected = true;
@@ -364,7 +403,16 @@ BOOST_AUTO_TEST_CASE(DoS_bantime)
364403
SetMockTime(nStartTime); // Overrides future calls to GetTime()
365404

366405
CAddress addr(ip(0xa0b0c001), NODE_NONE);
367-
CNode dummyNode(id++, NODE_NETWORK, INVALID_SOCKET, addr, /*nKeyedNetGroupIn=*/4, /*nLocalHostNonceIn=*/4, CAddress(), /*addrNameIn=*/"", ConnectionType::INBOUND, /*inbound_onion=*/false);
406+
CNode dummyNode{id++,
407+
NODE_NETWORK,
408+
/*sock=*/nullptr,
409+
addr,
410+
/*nKeyedNetGroupIn=*/4,
411+
/*nLocalHostNonceIn=*/4,
412+
CAddress(),
413+
/*addrNameIn=*/"",
414+
ConnectionType::INBOUND,
415+
/*inbound_onion=*/false};
368416
dummyNode.SetCommonVersion(PROTOCOL_VERSION);
369417
peerLogic->InitializeNode(&dummyNode);
370418
dummyNode.fSuccessfullyConnected = true;

src/test/fuzz/util.h

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ auto ConsumeNode(FuzzedDataProvider& fuzzed_data_provider, const std::optional<N
290290
{
291291
const NodeId node_id = node_id_in.value_or(fuzzed_data_provider.ConsumeIntegralInRange<NodeId>(0, std::numeric_limits<NodeId>::max()));
292292
const ServiceFlags local_services = ConsumeWeakEnum(fuzzed_data_provider, ALL_SERVICE_FLAGS);
293-
const SOCKET socket = INVALID_SOCKET;
293+
const auto sock = std::make_shared<FuzzedSock>(fuzzed_data_provider);
294294
const CAddress address = ConsumeAddress(fuzzed_data_provider);
295295
const uint64_t keyed_net_group = fuzzed_data_provider.ConsumeIntegral<uint64_t>();
296296
const uint64_t local_host_nonce = fuzzed_data_provider.ConsumeIntegral<uint64_t>();
@@ -299,9 +299,27 @@ auto ConsumeNode(FuzzedDataProvider& fuzzed_data_provider, const std::optional<N
299299
const ConnectionType conn_type = fuzzed_data_provider.PickValueInArray(ALL_CONNECTION_TYPES);
300300
const bool inbound_onion{conn_type == ConnectionType::INBOUND ? fuzzed_data_provider.ConsumeBool() : false};
301301
if constexpr (ReturnUniquePtr) {
302-
return std::make_unique<CNode>(node_id, local_services, socket, address, keyed_net_group, local_host_nonce, addr_bind, addr_name, conn_type, inbound_onion);
302+
return std::make_unique<CNode>(node_id,
303+
local_services,
304+
sock,
305+
address,
306+
keyed_net_group,
307+
local_host_nonce,
308+
addr_bind,
309+
addr_name,
310+
conn_type,
311+
inbound_onion);
303312
} else {
304-
return CNode{node_id, local_services, socket, address, keyed_net_group, local_host_nonce, addr_bind, addr_name, conn_type, inbound_onion};
313+
return CNode{node_id,
314+
local_services,
315+
sock,
316+
address,
317+
keyed_net_group,
318+
local_host_nonce,
319+
addr_bind,
320+
addr_name,
321+
conn_type,
322+
inbound_onion};
305323
}
306324
}
307325
inline std::unique_ptr<CNode> ConsumeNodeAsUniquePtr(FuzzedDataProvider& fdp, const std::optional<NodeId>& node_id_in = std::nullopt) { return ConsumeNode<true>(fdp, node_id_in); }

0 commit comments

Comments
 (0)