@@ -171,7 +171,7 @@ namespace net_utils
171171 return ;
172172 m_state.timers .general .wait_expire = true ;
173173 auto self = connection<T>::shared_from_this ();
174- m_timers. general . async_wait ( [this , self]( const ec_t & ec) {
174+ auto on_wait = [this , self] {
175175 std::lock_guard<std::mutex> guard (m_state.lock );
176176 m_state.timers .general .wait_expire = false ;
177177 if (m_state.timers .general .cancel_expire ) {
@@ -189,6 +189,9 @@ namespace net_utils
189189 interrupt ();
190190 else if (m_state.status == status_t ::INTERRUPTED)
191191 terminate ();
192+ };
193+ m_timers.general .async_wait ([this , self, on_wait](const ec_t & ec){
194+ boost::asio::post (m_strand, on_wait);
192195 });
193196 }
194197
@@ -242,27 +245,7 @@ namespace net_utils
242245 )
243246 ) {
244247 m_state.ssl .enabled = false ;
245- m_state.socket .handle_read = true ;
246- boost::asio::post (
247- connection_basic::strand_,
248- [this , self, bytes_transferred]{
249- bool success = m_handler.handle_recv (
250- reinterpret_cast <char *>(m_state.data .read .buffer .data ()),
251- bytes_transferred
252- );
253- std::lock_guard<std::mutex> guard (m_state.lock );
254- m_state.socket .handle_read = false ;
255- if (m_state.status == status_t ::INTERRUPTED)
256- on_interrupted ();
257- else if (m_state.status == status_t ::TERMINATING)
258- on_terminating ();
259- else if (!success)
260- interrupt ();
261- else {
262- start_read ();
263- }
264- }
265- );
248+ finish_read (bytes_transferred);
266249 }
267250 else {
268251 m_state.ssl .detected = true ;
@@ -322,7 +305,7 @@ namespace net_utils
322305 void connection<T>::start_read()
323306 {
324307 if (m_state.timers .throttle .in .wait_expire || m_state.socket .wait_read ||
325- m_state.socket .handle_read
308+ m_state.socket .handle_read || m_state. socket . shutdown_read
326309 ) {
327310 return ;
328311 }
@@ -346,7 +329,7 @@ namespace net_utils
346329 if (duration > duration_t {}) {
347330 m_timers.throttle .in .expires_after (duration);
348331 m_state.timers .throttle .in .wait_expire = true ;
349- m_timers. throttle . in . async_wait ( [this , self](const ec_t &ec){
332+ auto on_wait = [this , self](const ec_t &ec){
350333 std::lock_guard<std::mutex> guard (m_state.lock );
351334 m_state.timers .throttle .in .wait_expire = false ;
352335 if (m_state.timers .throttle .in .cancel_expire ) {
@@ -355,8 +338,16 @@ namespace net_utils
355338 }
356339 else if (ec.value ())
357340 interrupt ();
358- else
341+ };
342+ m_timers.throttle .in .async_wait ([this , self, on_wait](const ec_t &ec){
343+ std::lock_guard<std::mutex> guard (m_state.lock );
344+ const bool error_status = m_state.timers .throttle .in .cancel_expire || ec.value ();
345+ if (error_status)
346+ boost::asio::post (m_strand, std::bind (on_wait, ec));
347+ else {
348+ m_state.timers .throttle .in .wait_expire = false ;
359349 start_read ();
350+ }
360351 });
361352 return ;
362353 }
@@ -392,33 +383,7 @@ namespace net_utils
392383 m_conn_context.m_recv_cnt += bytes_transferred;
393384 start_timer (get_timeout_from_bytes_read (bytes_transferred), true );
394385 }
395-
396- // Post handle_recv to a separate `strand_`, distinct from `m_strand`
397- // which is listening for reads/writes. This avoids a circular dep.
398- // handle_recv can queue many writes, and `m_strand` will process those
399- // writes until the connection terminates without deadlocking waiting
400- // for handle_recv.
401- m_state.socket .handle_read = true ;
402- boost::asio::post (
403- connection_basic::strand_,
404- [this , self, bytes_transferred]{
405- bool success = m_handler.handle_recv (
406- reinterpret_cast <char *>(m_state.data .read .buffer .data ()),
407- bytes_transferred
408- );
409- std::lock_guard<std::mutex> guard (m_state.lock );
410- m_state.socket .handle_read = false ;
411- if (m_state.status == status_t ::INTERRUPTED)
412- on_interrupted ();
413- else if (m_state.status == status_t ::TERMINATING)
414- on_terminating ();
415- else if (!success)
416- interrupt ();
417- else {
418- start_read ();
419- }
420- }
421- );
386+ finish_read (bytes_transferred);
422387 }
423388 };
424389 if (!m_state.ssl .enabled )
@@ -444,6 +409,62 @@ namespace net_utils
444409 );
445410 }
446411
412+ template <typename T>
413+ void connection<T>::finish_read(size_t bytes_transferred)
414+ {
415+ // Post handle_recv to a separate `strand_`, distinct from `m_strand`
416+ // which is listening for reads/writes. This avoids a circular dep.
417+ // handle_recv can queue many writes, and `m_strand` will process those
418+ // writes until the connection terminates without deadlocking waiting
419+ // for handle_recv.
420+ m_state.socket .handle_read = true ;
421+ auto self = connection<T>::shared_from_this ();
422+ boost::asio::post (
423+ connection_basic::strand_,
424+ [this , self, bytes_transferred]{
425+ bool success = m_handler.handle_recv (
426+ reinterpret_cast <char *>(m_state.data .read .buffer .data ()),
427+ bytes_transferred
428+ );
429+ std::lock_guard<std::mutex> guard (m_state.lock );
430+ const bool error_status = m_state.status == status_t ::INTERRUPTED
431+ || m_state.status == status_t ::TERMINATING
432+ || !success;
433+ if (!error_status) {
434+ m_state.socket .handle_read = false ;
435+ start_read ();
436+ return ;
437+ }
438+ boost::asio::post (
439+ m_strand,
440+ [this , self, success]{
441+ // expect error_status == true
442+ std::lock_guard<std::mutex> guard (m_state.lock );
443+ m_state.socket .handle_read = false ;
444+ if (m_state.status == status_t ::INTERRUPTED)
445+ on_interrupted ();
446+ else if (m_state.status == status_t ::TERMINATING)
447+ on_terminating ();
448+ else if (!success) {
449+ ec_t ec;
450+ if (m_state.socket .wait_write ) {
451+ // Allow the already queued writes time to finish, but no more new reads
452+ connection_basic::socket_.next_layer ().shutdown (
453+ socket_t ::shutdown_receive,
454+ ec
455+ );
456+ m_state.socket .shutdown_read = true ;
457+ }
458+ if (!m_state.socket .wait_write || ec.value ()) {
459+ interrupt ();
460+ }
461+ }
462+ }
463+ );
464+ }
465+ );
466+ }
467+
447468 template <typename T>
448469 void connection<T>::start_write()
449470 {
@@ -475,7 +496,7 @@ namespace net_utils
475496 if (duration > duration_t {}) {
476497 m_timers.throttle .out .expires_after (duration);
477498 m_state.timers .throttle .out .wait_expire = true ;
478- m_timers. throttle . out . async_wait ( [this , self](const ec_t &ec){
499+ auto on_wait = [this , self](const ec_t &ec){
479500 std::lock_guard<std::mutex> guard (m_state.lock );
480501 m_state.timers .throttle .out .wait_expire = false ;
481502 if (m_state.timers .throttle .out .cancel_expire ) {
@@ -484,8 +505,16 @@ namespace net_utils
484505 }
485506 else if (ec.value ())
486507 interrupt ();
487- else
508+ };
509+ m_timers.throttle .out .async_wait ([this , self, on_wait](const ec_t &ec){
510+ std::lock_guard<std::mutex> guard (m_state.lock );
511+ const bool error_status = m_state.timers .throttle .out .cancel_expire || ec.value ();
512+ if (error_status)
513+ boost::asio::post (m_strand, std::bind (on_wait, ec));
514+ else {
515+ m_state.timers .throttle .out .wait_expire = false ;
488516 start_write ();
517+ }
489518 });
490519 }
491520 }
@@ -533,7 +562,12 @@ namespace net_utils
533562 m_state.data .write .total_bytes -=
534563 std::min (m_state.data .write .total_bytes , byte_count);
535564 m_state.condition .notify_all ();
536- start_write ();
565+ if (m_state.data .write .queue .empty () && m_state.socket .shutdown_read ) {
566+ // All writes have been sent and reads shutdown already, connection can be closed
567+ interrupt ();
568+ } else {
569+ start_write ();
570+ }
537571 }
538572 };
539573 if (!m_state.ssl .enabled )
@@ -762,6 +796,17 @@ namespace net_utils
762796 m_state.status = status_t ::WASTED;
763797 }
764798
799+ template <typename T>
800+ void connection<T>::terminate_async()
801+ {
802+ // synchronize with intermediate writes on `m_strand`
803+ auto self = connection<T>::shared_from_this ();
804+ boost::asio::post (m_strand, [this , self] {
805+ std::lock_guard<std::mutex> guard (m_state.lock );
806+ terminate ();
807+ });
808+ }
809+
765810 template <typename T>
766811 bool connection<T>::send(epee::byte_slice message)
767812 {
@@ -814,12 +859,7 @@ namespace net_utils
814859 );
815860 m_state.data .write .wait_consume = false ;
816861 if (!success) {
817- // synchronize with intermediate writes on `m_strand`
818- auto self = connection<T>::shared_from_this ();
819- boost::asio::post (m_strand, [this , self] {
820- std::lock_guard<std::mutex> guard (m_state.lock );
821- terminate ();
822- });
862+ terminate_async ();
823863 return false ;
824864 }
825865 else
@@ -1093,7 +1133,7 @@ namespace net_utils
10931133 std::lock_guard<std::mutex> guard (m_state.lock );
10941134 if (m_state.status != status_t ::RUNNING)
10951135 return false ;
1096- terminate ();
1136+ terminate_async ();
10971137 return true ;
10981138 }
10991139
0 commit comments