Skip to content

Commit 8f1c28a

Browse files
committed
Merge bitcoin/bitcoin#21879: refactor: wrap accept() and extend usage of Sock
6bf6e9f net: change CreateNodeFromAcceptedSocket() to take Sock (Vasil Dimov) 9e3cbfc net: use Sock in CConnman::ListenSocket (Vasil Dimov) f8bd13f net: add new method Sock::Accept() that wraps accept() (Vasil Dimov) Pull request description: _This is a piece of bitcoin/bitcoin#21878, chopped off to ease review._ Introduce an `accept(2)` wrapper `Sock::Accept()` and extend the usage of `Sock` in `CConnman::ListenSocket` and `CreateNodeFromAcceptedSocket()`. ACKs for top commit: laanwj: Code review ACK 6bf6e9f jamesob: ACK 6bf6e9f ([`jamesob/ackr/21879.2.vasild.wrap_accept_and_extend_u`](https://github.com/jamesob/bitcoin/tree/ackr/21879.2.vasild.wrap_accept_and_extend_u)) jonatack: ACK 6bf6e9f per `git range-diff ea989de 976f6e8 6bf6e9f` -- only change since my last review was `s/listen_socket.socket/listen_socket.sock->Get()/` in `src/net.cpp: CConnman::SocketHandlerListening()` -- re-read the code changes, rebase/debug build/ran units following my previous full review (bitcoin/bitcoin#21879 (review)) w0xlt: tACK 6bf6e9f Tree-SHA512: dc6d1acc4f255f1f7e8cf6dd74e97975cf3d5959e9fc2e689f74812ac3526d5ee8b6a32eca605925d10a4f7b6ff1ce5e900344311e587d19786b48c54d021b64
2 parents 847cf76 + 6bf6e9f commit 8f1c28a

File tree

7 files changed

+93
-30
lines changed

7 files changed

+93
-30
lines changed

src/net.cpp

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,10 +1099,10 @@ bool CConnman::AttemptToEvictConnection()
10991099
void CConnman::AcceptConnection(const ListenSocket& hListenSocket) {
11001100
struct sockaddr_storage sockaddr;
11011101
socklen_t len = sizeof(sockaddr);
1102-
SOCKET hSocket = accept(hListenSocket.socket, (struct sockaddr*)&sockaddr, &len);
1102+
auto sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len);
11031103
CAddress addr;
11041104

1105-
if (hSocket == INVALID_SOCKET) {
1105+
if (!sock) {
11061106
const int nErr = WSAGetLastError();
11071107
if (nErr != WSAEWOULDBLOCK) {
11081108
LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr));
@@ -1116,15 +1116,15 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) {
11161116
addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE};
11171117
}
11181118

1119-
const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(hSocket)), NODE_NONE};
1119+
const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(sock->Get())), NODE_NONE};
11201120

11211121
NetPermissionFlags permissionFlags = NetPermissionFlags::None;
11221122
hListenSocket.AddSocketPermissionFlags(permissionFlags);
11231123

1124-
CreateNodeFromAcceptedSocket(hSocket, permissionFlags, addr_bind, addr);
1124+
CreateNodeFromAcceptedSocket(std::move(sock), permissionFlags, addr_bind, addr);
11251125
}
11261126

1127-
void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
1127+
void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,
11281128
NetPermissionFlags permissionFlags,
11291129
const CAddress& addr_bind,
11301130
const CAddress& addr)
@@ -1150,27 +1150,24 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
11501150

11511151
if (!fNetworkActive) {
11521152
LogPrint(BCLog::NET, "connection from %s dropped: not accepting new connections\n", addr.ToString());
1153-
CloseSocket(hSocket);
11541153
return;
11551154
}
11561155

1157-
if (!IsSelectableSocket(hSocket))
1156+
if (!IsSelectableSocket(sock->Get()))
11581157
{
11591158
LogPrintf("connection from %s dropped: non-selectable socket\n", addr.ToString());
1160-
CloseSocket(hSocket);
11611159
return;
11621160
}
11631161

11641162
// According to the internet TCP_NODELAY is not carried into accepted sockets
11651163
// on all platforms. Set it again here just to be sure.
1166-
SetSocketNoDelay(hSocket);
1164+
SetSocketNoDelay(sock->Get());
11671165

11681166
// Don't accept connections from banned peers.
11691167
bool banned = m_banman && m_banman->IsBanned(addr);
11701168
if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && banned)
11711169
{
11721170
LogPrint(BCLog::NET, "connection from %s dropped (banned)\n", addr.ToString());
1173-
CloseSocket(hSocket);
11741171
return;
11751172
}
11761173

@@ -1179,7 +1176,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
11791176
if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && nInbound + 1 >= nMaxInbound && discouraged)
11801177
{
11811178
LogPrint(BCLog::NET, "connection from %s dropped (discouraged)\n", addr.ToString());
1182-
CloseSocket(hSocket);
11831179
return;
11841180
}
11851181

@@ -1188,7 +1184,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
11881184
if (!AttemptToEvictConnection()) {
11891185
// No connection to evict, disconnect the new connection
11901186
LogPrint(BCLog::NET, "failed to find an eviction candidate - connection dropped (full)\n");
1191-
CloseSocket(hSocket);
11921187
return;
11931188
}
11941189
}
@@ -1202,7 +1197,7 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
12021197
}
12031198

12041199
const bool inbound_onion = std::find(m_onion_binds.begin(), m_onion_binds.end(), addr_bind) != m_onion_binds.end();
1205-
CNode* pnode = new CNode(id, nodeServices, hSocket, addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion);
1200+
CNode* pnode = new CNode(id, nodeServices, sock->Release(), addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion);
12061201
pnode->AddRef();
12071202
pnode->m_permissionFlags = permissionFlags;
12081203
pnode->m_prefer_evict = discouraged;
@@ -1364,7 +1359,7 @@ bool CConnman::GenerateSelectSet(const std::vector<CNode*>& nodes,
13641359
std::set<SOCKET>& error_set)
13651360
{
13661361
for (const ListenSocket& hListenSocket : vhListenSocket) {
1367-
recv_set.insert(hListenSocket.socket);
1362+
recv_set.insert(hListenSocket.sock->Get());
13681363
}
13691364

13701365
for (CNode* pnode : nodes) {
@@ -1646,7 +1641,7 @@ void CConnman::SocketHandlerListening(const std::set<SOCKET>& recv_set)
16461641
if (interruptNet) {
16471642
return;
16481643
}
1649-
if (recv_set.count(listen_socket.socket) > 0) {
1644+
if (recv_set.count(listen_socket.sock->Get()) > 0) {
16501645
AcceptConnection(listen_socket);
16511646
}
16521647
}
@@ -2335,7 +2330,7 @@ void CConnman::ThreadI2PAcceptIncoming()
23352330
continue;
23362331
}
23372332

2338-
CreateNodeFromAcceptedSocket(conn.sock->Release(), NetPermissionFlags::None,
2333+
CreateNodeFromAcceptedSocket(std::move(conn.sock), NetPermissionFlags::None,
23392334
CAddress{conn.me, NODE_NONE}, CAddress{conn.peer, NODE_NONE});
23402335
}
23412336
}
@@ -2397,7 +2392,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
23972392
return false;
23982393
}
23992394

2400-
vhListenSocket.push_back(ListenSocket(sock->Release(), permissions));
2395+
vhListenSocket.emplace_back(std::move(sock), permissions);
24012396
return true;
24022397
}
24032398

@@ -2706,15 +2701,6 @@ void CConnman::StopNodes()
27062701
DeleteNode(pnode);
27072702
}
27082703

2709-
// Close listening sockets.
2710-
for (ListenSocket& hListenSocket : vhListenSocket) {
2711-
if (hListenSocket.socket != INVALID_SOCKET) {
2712-
if (!CloseSocket(hListenSocket.socket)) {
2713-
LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError()));
2714-
}
2715-
}
2716-
}
2717-
27182704
for (CNode* pnode : m_nodes_disconnected) {
27192705
DeleteNode(pnode);
27202706
}

src/net.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <threadinterrupt.h>
2626
#include <uint256.h>
2727
#include <util/check.h>
28+
#include <util/sock.h>
2829

2930
#include <atomic>
3031
#include <condition_variable>
@@ -947,9 +948,13 @@ class CConnman
947948
private:
948949
struct ListenSocket {
949950
public:
950-
SOCKET socket;
951+
std::shared_ptr<Sock> sock;
951952
inline void AddSocketPermissionFlags(NetPermissionFlags& flags) const { NetPermissions::AddFlag(flags, m_permissions); }
952-
ListenSocket(SOCKET socket_, NetPermissionFlags permissions_) : socket(socket_), m_permissions(permissions_) {}
953+
ListenSocket(std::shared_ptr<Sock> sock_, NetPermissionFlags permissions_)
954+
: sock{sock_}, m_permissions{permissions_}
955+
{
956+
}
957+
953958
private:
954959
NetPermissionFlags m_permissions;
955960
};
@@ -969,12 +974,12 @@ class CConnman
969974
/**
970975
* Create a `CNode` object from a socket that has just been accepted and add the node to
971976
* the `m_nodes` member.
972-
* @param[in] hSocket Connected socket to communicate with the peer.
977+
* @param[in] sock Connected socket to communicate with the peer.
973978
* @param[in] permissionFlags The peer's permissions.
974979
* @param[in] addr_bind The address and port at our side of the connection.
975980
* @param[in] addr The address and port at the peer's side of the connection.
976981
*/
977-
void CreateNodeFromAcceptedSocket(SOCKET hSocket,
982+
void CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,
978983
NetPermissionFlags permissionFlags,
979984
const CAddress& addr_bind,
980985
const CAddress& addr);

src/test/fuzz/util.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <util/time.h>
1414
#include <version.h>
1515

16+
#include <memory>
17+
1618
FuzzedSock::FuzzedSock(FuzzedDataProvider& fuzzed_data_provider)
1719
: m_fuzzed_data_provider{fuzzed_data_provider}
1820
{
@@ -158,6 +160,20 @@ int FuzzedSock::Connect(const sockaddr*, socklen_t) const
158160
return 0;
159161
}
160162

163+
std::unique_ptr<Sock> FuzzedSock::Accept(sockaddr* addr, socklen_t* addr_len) const
164+
{
165+
constexpr std::array accept_errnos{
166+
ECONNABORTED,
167+
EINTR,
168+
ENOMEM,
169+
};
170+
if (m_fuzzed_data_provider.ConsumeBool()) {
171+
SetFuzzedErrNo(m_fuzzed_data_provider, accept_errnos);
172+
return std::unique_ptr<FuzzedSock>();
173+
}
174+
return std::make_unique<FuzzedSock>(m_fuzzed_data_provider);
175+
}
176+
161177
int FuzzedSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
162178
{
163179
constexpr std::array getsockopt_errnos{

src/test/fuzz/util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,8 @@ class FuzzedSock : public Sock
401401

402402
int Connect(const sockaddr*, socklen_t) const override;
403403

404+
std::unique_ptr<Sock> Accept(sockaddr* addr, socklen_t* addr_len) const override;
405+
404406
int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override;
405407

406408
bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override;

src/test/util/net.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <array>
1414
#include <cassert>
1515
#include <cstring>
16+
#include <memory>
1617
#include <string>
1718

1819
struct ConnmanTestMsg : public CConnman {
@@ -126,6 +127,23 @@ class StaticContentsSock : public Sock
126127

127128
int Connect(const sockaddr*, socklen_t) const override { return 0; }
128129

130+
std::unique_ptr<Sock> Accept(sockaddr* addr, socklen_t* addr_len) const override
131+
{
132+
if (addr != nullptr) {
133+
// Pretend all connections come from 5.5.5.5:6789
134+
memset(addr, 0x00, *addr_len);
135+
const socklen_t write_len = static_cast<socklen_t>(sizeof(sockaddr_in));
136+
if (*addr_len >= write_len) {
137+
*addr_len = write_len;
138+
sockaddr_in* addr_in = reinterpret_cast<sockaddr_in*>(addr);
139+
addr_in->sin_family = AF_INET;
140+
memset(&addr_in->sin_addr, 0x05, sizeof(addr_in->sin_addr));
141+
addr_in->sin_port = htons(6789);
142+
}
143+
}
144+
return std::make_unique<StaticContentsSock>("");
145+
};
146+
129147
int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override
130148
{
131149
std::memset(opt_val, 0x0, *opt_len);

src/util/sock.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <util/system.h>
1111
#include <util/time.h>
1212

13+
#include <memory>
1314
#include <stdexcept>
1415
#include <string>
1516

@@ -73,6 +74,32 @@ int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
7374
return connect(m_socket, addr, addr_len);
7475
}
7576

77+
std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
78+
{
79+
#ifdef WIN32
80+
static constexpr auto ERR = INVALID_SOCKET;
81+
#else
82+
static constexpr auto ERR = SOCKET_ERROR;
83+
#endif
84+
85+
std::unique_ptr<Sock> sock;
86+
87+
const auto socket = accept(m_socket, addr, addr_len);
88+
if (socket != ERR) {
89+
try {
90+
sock = std::make_unique<Sock>(socket);
91+
} catch (const std::exception&) {
92+
#ifdef WIN32
93+
closesocket(socket);
94+
#else
95+
close(socket);
96+
#endif
97+
}
98+
}
99+
100+
return sock;
101+
}
102+
76103
int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
77104
{
78105
return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);

src/util/sock.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <util/time.h>
1111

1212
#include <chrono>
13+
#include <memory>
1314
#include <string>
1415

1516
/**
@@ -96,6 +97,14 @@ class Sock
9697
*/
9798
[[nodiscard]] virtual int Connect(const sockaddr* addr, socklen_t addr_len) const;
9899

100+
/**
101+
* accept(2) wrapper. Equivalent to `std::make_unique<Sock>(accept(this->Get(), addr, addr_len))`.
102+
* Code that uses this wrapper can be unit tested if this method is overridden by a mock Sock
103+
* implementation.
104+
* The returned unique_ptr is empty if `accept()` failed in which case errno will be set.
105+
*/
106+
[[nodiscard]] virtual std::unique_ptr<Sock> Accept(sockaddr* addr, socklen_t* addr_len) const;
107+
99108
/**
100109
* getsockopt(2) wrapper. Equivalent to
101110
* `getsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this

0 commit comments

Comments
 (0)