@@ -91,6 +91,8 @@ use backoff_std_async::future::retry;
9191#[ cfg( all( not( feature = "tokio-comp" ) , feature = "async-std-comp" ) ) ]
9292use backoff_std_async:: { Error as BackoffError , ExponentialBackoff } ;
9393
94+ #[ cfg( feature = "tokio-comp" ) ]
95+ use async_trait:: async_trait;
9496#[ cfg( feature = "tokio-comp" ) ]
9597use backoff_tokio:: future:: retry;
9698#[ cfg( feature = "tokio-comp" ) ]
@@ -379,20 +381,37 @@ where
379381#[ cfg( feature = "tokio-comp" ) ]
380382#[ derive( Clone ) ]
381383struct TokioDisconnectNotifier {
382- pub disconnect_notifier : Arc < Notify > ,
384+ disconnect_notifier : Arc < Notify > ,
383385}
384386
385387#[ cfg( feature = "tokio-comp" ) ]
388+ #[ async_trait]
386389impl DisconnectNotifier for TokioDisconnectNotifier {
387390 fn notify_disconnect ( & mut self ) {
388391 self . disconnect_notifier . notify_one ( ) ;
389392 }
390393
394+ async fn wait_for_disconnect_with_timeout ( & self , max_wait : & Duration ) {
395+ let _ = timeout ( * max_wait, async {
396+ self . disconnect_notifier . notified ( ) . await ;
397+ } )
398+ . await ;
399+ }
400+
391401 fn clone_box ( & self ) -> Box < dyn DisconnectNotifier > {
392402 Box :: new ( self . clone ( ) )
393403 }
394404}
395405
406+ #[ cfg( feature = "tokio-comp" ) ]
407+ impl TokioDisconnectNotifier {
408+ fn new ( ) -> TokioDisconnectNotifier {
409+ TokioDisconnectNotifier {
410+ disconnect_notifier : Arc :: new ( Notify :: new ( ) ) ,
411+ }
412+ }
413+ }
414+
396415type ConnectionMap < C > = connections_container:: ConnectionsMap < ConnectionFuture < C > > ;
397416type ConnectionsContainer < C > =
398417 self :: connections_container:: ConnectionsContainer < ConnectionFuture < C > > ;
@@ -406,8 +425,6 @@ pub(crate) struct InnerCore<C> {
406425 subscriptions_by_address : RwLock < HashMap < String , PubSubSubscriptionInfo > > ,
407426 unassigned_subscriptions : RwLock < PubSubSubscriptionInfo > ,
408427 glide_connection_options : GlideConnectionOptions ,
409- #[ cfg( feature = "tokio-comp" ) ]
410- tokio_notify : Arc < Notify > ,
411428}
412429
413430pub ( crate ) type Core < C > = Arc < InnerCore < C > > ;
@@ -990,27 +1007,24 @@ where
9901007 cluster_params : ClusterParams ,
9911008 push_sender : Option < mpsc:: UnboundedSender < PushInfo > > ,
9921009 ) -> RedisResult < Disposable < Self > > {
993- #[ cfg( feature = "tokio-comp" ) ]
994- let tokio_notify = Arc :: new ( Notify :: new ( ) ) ;
995-
9961010 let disconnect_notifier = {
9971011 #[ cfg( feature = "tokio-comp" ) ]
9981012 {
999- Some :: < Box < dyn DisconnectNotifier > > ( Box :: new ( TokioDisconnectNotifier {
1000- disconnect_notifier : tokio_notify. clone ( ) ,
1001- } ) )
1013+ Some :: < Box < dyn DisconnectNotifier > > ( Box :: new ( TokioDisconnectNotifier :: new ( ) ) )
10021014 }
10031015 #[ cfg( not( feature = "tokio-comp" ) ) ]
10041016 None
10051017 } ;
10061018
1019+ let glide_connection_options = GlideConnectionOptions {
1020+ push_sender,
1021+ disconnect_notifier,
1022+ } ;
1023+
10071024 let connections = Self :: create_initial_connections (
10081025 initial_nodes,
10091026 & cluster_params,
1010- GlideConnectionOptions {
1011- push_sender : push_sender. clone ( ) ,
1012- disconnect_notifier : disconnect_notifier. clone ( ) ,
1013- } ,
1027+ glide_connection_options. clone ( ) ,
10141028 )
10151029 . await ?;
10161030
@@ -1035,12 +1049,7 @@ where
10351049 } ,
10361050 ) ,
10371051 subscriptions_by_address : RwLock :: new ( Default :: default ( ) ) ,
1038- glide_connection_options : GlideConnectionOptions {
1039- push_sender : push_sender. clone ( ) ,
1040- disconnect_notifier : disconnect_notifier. clone ( ) ,
1041- } ,
1042- #[ cfg( feature = "tokio-comp" ) ]
1043- tokio_notify,
1052+ glide_connection_options,
10441053 } ) ;
10451054 let mut connection = ClusterConnInner {
10461055 inner,
@@ -1227,40 +1236,40 @@ where
12271236 // In addition, the validation is done by peeking at the state of the underlying transport w/o overhead of additional commands to server.
12281237 async fn validate_all_user_connections ( inner : Arc < InnerCore < C > > ) {
12291238 let mut all_valid_conns = HashMap :: new ( ) ;
1230- let mut all_nodes_with_slots = HashSet :: new ( ) ;
12311239 // prep connections and clean out these w/o assigned slots, as we might have established connections to unwanted hosts
1232- {
1233- let mut nodes_to_delete = Vec :: new ( ) ;
1234- let connections_container = inner. conn_lock . read ( ) . await ;
1235-
1236- connections_container
1237- . slot_map
1238- . addresses_for_all_nodes ( )
1239- . iter ( )
1240- . for_each ( |addr| {
1241- all_nodes_with_slots. insert ( String :: from ( * addr) ) ;
1242- } ) ;
1240+ let mut nodes_to_delete = Vec :: new ( ) ;
1241+ let connections_container = inner. conn_lock . read ( ) . await ;
12431242
1244- connections_container
1245- . all_node_connections ( )
1246- . for_each ( |( addr, con) | {
1247- if all_nodes_with_slots. contains ( & addr) {
1248- all_valid_conns. insert ( addr. clone ( ) , con. clone ( ) ) ;
1249- } else {
1250- nodes_to_delete. push ( addr. clone ( ) ) ;
1251- }
1252- } ) ;
1243+ let all_nodes_with_slots: HashSet < String > = connections_container
1244+ . slot_map
1245+ . addresses_for_all_nodes ( )
1246+ . iter ( )
1247+ . map ( |addr| String :: from ( * addr) )
1248+ . collect ( ) ;
1249+
1250+ connections_container
1251+ . all_node_connections ( )
1252+ . for_each ( |( addr, con) | {
1253+ if all_nodes_with_slots. contains ( & addr) {
1254+ all_valid_conns. insert ( addr. clone ( ) , con. clone ( ) ) ;
1255+ } else {
1256+ nodes_to_delete. push ( addr. clone ( ) ) ;
1257+ }
1258+ } ) ;
12531259
1254- for addr in & nodes_to_delete {
1255- connections_container. remove_node ( addr) ;
1256- }
1260+ for addr in & nodes_to_delete {
1261+ connections_container. remove_node ( addr) ;
12571262 }
12581263
1264+ drop ( connections_container) ;
1265+
12591266 // identify nodes with closed connection
12601267 let mut addrs_to_refresh = Vec :: new ( ) ;
12611268 for ( addr, con_fut) in & all_valid_conns {
12621269 let con = con_fut. clone ( ) . await ;
1270+ // connection object might be present despite the transport being closed
12631271 if con. is_closed ( ) {
1272+ // transport is closed, need to refresh
12641273 addrs_to_refresh. push ( addr. clone ( ) ) ;
12651274 }
12661275 }
@@ -1289,7 +1298,7 @@ where
12891298 inner : Arc < InnerCore < C > > ,
12901299 addresses : Vec < String > ,
12911300 conn_type : RefreshConnectionType ,
1292- try_existing_node : bool ,
1301+ check_existing_conn : bool ,
12931302 ) {
12941303 info ! ( "Started refreshing connections to {:?}" , addresses) ;
12951304 let connections_container = inner. conn_lock . read ( ) . await ;
@@ -1301,10 +1310,10 @@ where
13011310 . fold (
13021311 & * connections_container,
13031312 |connections_container, address| async move {
1304- let node_option = if try_existing_node {
1313+ let node_option = if check_existing_conn {
13051314 connections_container. remove_node ( & address)
13061315 } else {
1307- Option :: None
1316+ None
13081317 } ;
13091318
13101319 // override subscriptions for this connection
@@ -1541,13 +1550,15 @@ where
15411550
15421551 async fn connections_validation_task ( inner : Arc < InnerCore < C > > , interval_duration : Duration ) {
15431552 loop {
1544- #[ cfg( feature = "tokio-comp" ) ]
1545- let _ = timeout ( interval_duration, async {
1546- inner. tokio_notify . notified ( ) . await ;
1547- } )
1548- . await ;
1549- #[ cfg( not( feature = "tokio-comp" ) ) ]
1550- let _ = boxed_sleep ( interval_duration) . await ;
1553+ if let Some ( disconnect_notifier) =
1554+ inner. glide_connection_options . disconnect_notifier . clone ( )
1555+ {
1556+ disconnect_notifier
1557+ . wait_for_disconnect_with_timeout ( & interval_duration)
1558+ . await ;
1559+ } else {
1560+ let _ = boxed_sleep ( interval_duration) . await ;
1561+ }
15511562
15521563 Self :: validate_all_user_connections ( inner. clone ( ) ) . await ;
15531564 }
0 commit comments