Skip to content

Commit 7cc1142

Browse files
committed
Merge pull request #10298
ee9e4a4 p2p: connection patches (j-berman)
2 parents 027bbf9 + ee9e4a4 commit 7cc1142

File tree

3 files changed

+173
-62
lines changed

3 files changed

+173
-62
lines changed

contrib/epee/include/net/abstract_tcp_server2.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ namespace net_utils
128128

129129
void start_handshake();
130130
void start_read();
131+
void finish_read(size_t bytes_transferred);
131132
void start_write();
132133
void start_shutdown();
133134
void cancel_socket();
@@ -139,6 +140,7 @@ namespace net_utils
139140

140141
void terminate();
141142
void on_terminating();
143+
void terminate_async();
142144

143145
bool send(epee::byte_slice message);
144146
bool start_internal(
@@ -192,6 +194,7 @@ namespace net_utils
192194
bool wait_read;
193195
bool handle_read;
194196
bool cancel_read;
197+
bool shutdown_read;
195198

196199
bool wait_write;
197200
bool handle_write;

contrib/epee/include/net/abstract_tcp_server2.inl

Lines changed: 102 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ namespace net_utils
171171
return;
172172
m_state.timers.general.wait_expire = true;
173173
auto self = connection<T>::shared_from_this();
174-
m_timers.general.async_wait([this, self](const ec_t & ec){
174+
auto on_wait = [this, self] {
175175
std::lock_guard<std::mutex> guard(m_state.lock);
176176
m_state.timers.general.wait_expire = false;
177177
if (m_state.timers.general.cancel_expire) {
@@ -189,6 +189,9 @@ namespace net_utils
189189
interrupt();
190190
else if (m_state.status == status_t::INTERRUPTED)
191191
terminate();
192+
};
193+
m_timers.general.async_wait([this, self, on_wait](const ec_t & ec){
194+
boost::asio::post(m_strand, on_wait);
192195
});
193196
}
194197

@@ -242,27 +245,7 @@ namespace net_utils
242245
)
243246
) {
244247
m_state.ssl.enabled = false;
245-
m_state.socket.handle_read = true;
246-
boost::asio::post(
247-
connection_basic::strand_,
248-
[this, self, bytes_transferred]{
249-
bool success = m_handler.handle_recv(
250-
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
251-
bytes_transferred
252-
);
253-
std::lock_guard<std::mutex> guard(m_state.lock);
254-
m_state.socket.handle_read = false;
255-
if (m_state.status == status_t::INTERRUPTED)
256-
on_interrupted();
257-
else if (m_state.status == status_t::TERMINATING)
258-
on_terminating();
259-
else if (!success)
260-
interrupt();
261-
else {
262-
start_read();
263-
}
264-
}
265-
);
248+
finish_read(bytes_transferred);
266249
}
267250
else {
268251
m_state.ssl.detected = true;
@@ -322,7 +305,7 @@ namespace net_utils
322305
void connection<T>::start_read()
323306
{
324307
if (m_state.timers.throttle.in.wait_expire || m_state.socket.wait_read ||
325-
m_state.socket.handle_read
308+
m_state.socket.handle_read || m_state.socket.shutdown_read
326309
) {
327310
return;
328311
}
@@ -346,7 +329,7 @@ namespace net_utils
346329
if (duration > duration_t{}) {
347330
m_timers.throttle.in.expires_after(duration);
348331
m_state.timers.throttle.in.wait_expire = true;
349-
m_timers.throttle.in.async_wait([this, self](const ec_t &ec){
332+
auto on_wait = [this, self](const ec_t &ec){
350333
std::lock_guard<std::mutex> guard(m_state.lock);
351334
m_state.timers.throttle.in.wait_expire = false;
352335
if (m_state.timers.throttle.in.cancel_expire) {
@@ -355,8 +338,16 @@ namespace net_utils
355338
}
356339
else if (ec.value())
357340
interrupt();
358-
else
341+
};
342+
m_timers.throttle.in.async_wait([this, self, on_wait](const ec_t &ec){
343+
std::lock_guard<std::mutex> guard(m_state.lock);
344+
const bool error_status = m_state.timers.throttle.in.cancel_expire || ec.value();
345+
if (error_status)
346+
boost::asio::post(m_strand, std::bind(on_wait, ec));
347+
else {
348+
m_state.timers.throttle.in.wait_expire = false;
359349
start_read();
350+
}
360351
});
361352
return;
362353
}
@@ -392,33 +383,7 @@ namespace net_utils
392383
m_conn_context.m_recv_cnt += bytes_transferred;
393384
start_timer(get_timeout_from_bytes_read(bytes_transferred), true);
394385
}
395-
396-
// Post handle_recv to a separate `strand_`, distinct from `m_strand`
397-
// which is listening for reads/writes. This avoids a circular dep.
398-
// handle_recv can queue many writes, and `m_strand` will process those
399-
// writes until the connection terminates without deadlocking waiting
400-
// for handle_recv.
401-
m_state.socket.handle_read = true;
402-
boost::asio::post(
403-
connection_basic::strand_,
404-
[this, self, bytes_transferred]{
405-
bool success = m_handler.handle_recv(
406-
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
407-
bytes_transferred
408-
);
409-
std::lock_guard<std::mutex> guard(m_state.lock);
410-
m_state.socket.handle_read = false;
411-
if (m_state.status == status_t::INTERRUPTED)
412-
on_interrupted();
413-
else if (m_state.status == status_t::TERMINATING)
414-
on_terminating();
415-
else if (!success)
416-
interrupt();
417-
else {
418-
start_read();
419-
}
420-
}
421-
);
386+
finish_read(bytes_transferred);
422387
}
423388
};
424389
if (!m_state.ssl.enabled)
@@ -444,6 +409,62 @@ namespace net_utils
444409
);
445410
}
446411

412+
template<typename T>
413+
void connection<T>::finish_read(size_t bytes_transferred)
414+
{
415+
// Post handle_recv to a separate `strand_`, distinct from `m_strand`
416+
// which is listening for reads/writes. This avoids a circular dep.
417+
// handle_recv can queue many writes, and `m_strand` will process those
418+
// writes until the connection terminates without deadlocking waiting
419+
// for handle_recv.
420+
m_state.socket.handle_read = true;
421+
auto self = connection<T>::shared_from_this();
422+
boost::asio::post(
423+
connection_basic::strand_,
424+
[this, self, bytes_transferred]{
425+
bool success = m_handler.handle_recv(
426+
reinterpret_cast<char *>(m_state.data.read.buffer.data()),
427+
bytes_transferred
428+
);
429+
std::lock_guard<std::mutex> guard(m_state.lock);
430+
const bool error_status = m_state.status == status_t::INTERRUPTED
431+
|| m_state.status == status_t::TERMINATING
432+
|| !success;
433+
if (!error_status) {
434+
m_state.socket.handle_read = false;
435+
start_read();
436+
return;
437+
}
438+
boost::asio::post(
439+
m_strand,
440+
[this, self, success]{
441+
// expect error_status == true
442+
std::lock_guard<std::mutex> guard(m_state.lock);
443+
m_state.socket.handle_read = false;
444+
if (m_state.status == status_t::INTERRUPTED)
445+
on_interrupted();
446+
else if (m_state.status == status_t::TERMINATING)
447+
on_terminating();
448+
else if (!success) {
449+
ec_t ec;
450+
if (m_state.socket.wait_write) {
451+
// Allow the already queued writes time to finish, but no more new reads
452+
connection_basic::socket_.next_layer().shutdown(
453+
socket_t::shutdown_receive,
454+
ec
455+
);
456+
m_state.socket.shutdown_read = true;
457+
}
458+
if (!m_state.socket.wait_write || ec.value()) {
459+
interrupt();
460+
}
461+
}
462+
}
463+
);
464+
}
465+
);
466+
}
467+
447468
template<typename T>
448469
void connection<T>::start_write()
449470
{
@@ -475,7 +496,7 @@ namespace net_utils
475496
if (duration > duration_t{}) {
476497
m_timers.throttle.out.expires_after(duration);
477498
m_state.timers.throttle.out.wait_expire = true;
478-
m_timers.throttle.out.async_wait([this, self](const ec_t &ec){
499+
auto on_wait = [this, self](const ec_t &ec){
479500
std::lock_guard<std::mutex> guard(m_state.lock);
480501
m_state.timers.throttle.out.wait_expire = false;
481502
if (m_state.timers.throttle.out.cancel_expire) {
@@ -484,8 +505,16 @@ namespace net_utils
484505
}
485506
else if (ec.value())
486507
interrupt();
487-
else
508+
};
509+
m_timers.throttle.out.async_wait([this, self, on_wait](const ec_t &ec){
510+
std::lock_guard<std::mutex> guard(m_state.lock);
511+
const bool error_status = m_state.timers.throttle.out.cancel_expire || ec.value();
512+
if (error_status)
513+
boost::asio::post(m_strand, std::bind(on_wait, ec));
514+
else {
515+
m_state.timers.throttle.out.wait_expire = false;
488516
start_write();
517+
}
489518
});
490519
}
491520
}
@@ -533,7 +562,12 @@ namespace net_utils
533562
m_state.data.write.total_bytes -=
534563
std::min(m_state.data.write.total_bytes, byte_count);
535564
m_state.condition.notify_all();
536-
start_write();
565+
if (m_state.data.write.queue.empty() && m_state.socket.shutdown_read) {
566+
// All writes have been sent and reads shutdown already, connection can be closed
567+
interrupt();
568+
} else {
569+
start_write();
570+
}
537571
}
538572
};
539573
if (!m_state.ssl.enabled)
@@ -762,6 +796,17 @@ namespace net_utils
762796
m_state.status = status_t::WASTED;
763797
}
764798

799+
template<typename T>
800+
void connection<T>::terminate_async()
801+
{
802+
// synchronize with intermediate writes on `m_strand`
803+
auto self = connection<T>::shared_from_this();
804+
boost::asio::post(m_strand, [this, self] {
805+
std::lock_guard<std::mutex> guard(m_state.lock);
806+
terminate();
807+
});
808+
}
809+
765810
template<typename T>
766811
bool connection<T>::send(epee::byte_slice message)
767812
{
@@ -814,12 +859,7 @@ namespace net_utils
814859
);
815860
m_state.data.write.wait_consume = false;
816861
if (!success) {
817-
// synchronize with intermediate writes on `m_strand`
818-
auto self = connection<T>::shared_from_this();
819-
boost::asio::post(m_strand, [this, self] {
820-
std::lock_guard<std::mutex> guard(m_state.lock);
821-
terminate();
822-
});
862+
terminate_async();
823863
return false;
824864
}
825865
else
@@ -1093,7 +1133,7 @@ namespace net_utils
10931133
std::lock_guard<std::mutex> guard(m_state.lock);
10941134
if (m_state.status != status_t::RUNNING)
10951135
return false;
1096-
terminate();
1136+
terminate_async();
10971137
return true;
10981138
}
10991139

tests/unit_tests/epee_http_server.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,71 @@ TEST(http_server, private_ip_limit)
198198
failed |= bool(error);
199199
EXPECT_TRUE(failed);
200200
}
201+
202+
TEST(http_server, read_then_close)
203+
{
204+
namespace http = boost::beast::http;
205+
206+
http_server server{};
207+
server.dummy_size = 200000;
208+
server.init(nullptr, "8080");
209+
server.run(2, false); // need at least 2 threads to trigger issues
210+
211+
bool failed_read = false;
212+
bool closed_all_connections = true;
213+
for (std::size_t j = 0; j < 1000; ++j)
214+
{
215+
boost::system::error_code error{};
216+
boost::asio::io_context context{};
217+
boost::asio::ip::tcp::socket stream{context};
218+
stream.connect(
219+
boost::asio::ip::tcp::endpoint{
220+
boost::asio::ip::make_address("127.0.0.1"), 8080
221+
},
222+
error
223+
);
224+
EXPECT_FALSE(bool(error));
225+
226+
http::request<http::string_body> req{http::verb::get, "/dummy", 11};
227+
req.set(http::field::host, "127.0.0.1");
228+
req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
229+
req.set(http::field::connection, "close"); // tell server to close connection after sending all data to the client
230+
req.body() = make_payload();
231+
req.prepare_payload();
232+
233+
dummy::response payload{};
234+
boost::beast::flat_buffer buffer;
235+
http::response_parser<http::basic_string_body<char>> parser;
236+
parser.body_limit(server.dummy_size + 1024);
237+
238+
http::write(stream, req, error);
239+
EXPECT_FALSE(bool(error));
240+
241+
http::read(stream, buffer, parser, error);
242+
243+
// If the read fails, continue the loop still just to make sure the server can handle it
244+
failed_read |= bool(error);
245+
if (failed_read)
246+
continue;
247+
failed_read |= !(parser.is_done());
248+
if (failed_read)
249+
continue;
250+
const auto res = parser.release();
251+
failed_read |= res.result_int() != 200u
252+
|| !(epee::serialization::load_t_from_binary(payload, res.body()))
253+
|| (server.dummy_size != std::count(payload.payload.begin(), payload.payload.end(), 'f'));
254+
255+
// See if the server closes the connection after handling the resp
256+
char buf[1];
257+
stream.read_some(boost::asio::buffer(buf), error);
258+
closed_all_connections &= error == boost::asio::error::eof;
259+
}
260+
261+
// The client should have been able to read all data sent by the server across all requests
262+
EXPECT_FALSE(failed_read);
263+
264+
// The server should have closed all connections
265+
EXPECT_TRUE(closed_all_connections);
266+
267+
server.send_stop_signal();
268+
}

0 commit comments

Comments
 (0)