2929from redis .asyncio .connection import Connection , DefaultParser , SSLConnection , parse_url
3030from redis .asyncio .lock import Lock
3131from redis .asyncio .retry import Retry
32+ from redis .auth .token import TokenInterface
3233from redis .backoff import default_backoff
3334from redis .client import EMPTY_RESPONSE , NEVER_DECODE , AbstractRedis
3435from redis .cluster import (
4546from redis .commands import READ_COMMANDS , AsyncRedisClusterCommands
4647from redis .crc import REDIS_CLUSTER_HASH_SLOTS , key_slot
4748from redis .credentials import CredentialProvider
49+ from redis .event import AfterAsyncClusterInstantiationEvent , EventDispatcher
4850from redis .exceptions import (
4951 AskError ,
5052 BusyLoadingError ,
5759 MaxConnectionsError ,
5860 MovedError ,
5961 RedisClusterException ,
62+ RedisError ,
6063 ResponseError ,
6164 SlotNotCoveredError ,
6265 TimeoutError ,
@@ -270,6 +273,7 @@ def __init__(
270273 ssl_ciphers : Optional [str ] = None ,
271274 protocol : Optional [int ] = 2 ,
272275 address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
276+ event_dispatcher : Optional [EventDispatcher ] = None ,
273277 ) -> None :
274278 if db :
275279 raise RedisClusterException (
@@ -366,11 +370,17 @@ def __init__(
366370 if host and port :
367371 startup_nodes .append (ClusterNode (host , port , ** self .connection_kwargs ))
368372
373+ if event_dispatcher is None :
374+ self ._event_dispatcher = EventDispatcher ()
375+ else :
376+ self ._event_dispatcher = event_dispatcher
377+
369378 self .nodes_manager = NodesManager (
370379 startup_nodes ,
371380 require_full_coverage ,
372381 kwargs ,
373382 address_remap = address_remap ,
383+ event_dispatcher = self ._event_dispatcher ,
374384 )
375385 self .encoder = Encoder (encoding , encoding_errors , decode_responses )
376386 self .read_from_replicas = read_from_replicas
@@ -929,6 +939,8 @@ class ClusterNode:
929939 __slots__ = (
930940 "_connections" ,
931941 "_free" ,
942+ "_lock" ,
943+ "_event_dispatcher" ,
932944 "connection_class" ,
933945 "connection_kwargs" ,
934946 "host" ,
@@ -966,6 +978,9 @@ def __init__(
966978
967979 self ._connections : List [Connection ] = []
968980 self ._free : Deque [Connection ] = collections .deque (maxlen = self .max_connections )
981+ self ._event_dispatcher = self .connection_kwargs .get ("event_dispatcher" , None )
982+ if self ._event_dispatcher is None :
983+ self ._event_dispatcher = EventDispatcher ()
969984
970985 def __repr__ (self ) -> str :
971986 return (
@@ -1082,10 +1097,38 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10821097
10831098 return ret
10841099
1100+ async def re_auth_callback (self , token : TokenInterface ):
1101+ tmp_queue = collections .deque ()
1102+ while self ._free :
1103+ conn = self ._free .popleft ()
1104+ await conn .retry .call_with_retry (
1105+ lambda : conn .send_command (
1106+ "AUTH" , token .try_get ("oid" ), token .get_value ()
1107+ ),
1108+ lambda error : self ._mock (error ),
1109+ )
1110+ await conn .retry .call_with_retry (
1111+ lambda : conn .read_response (), lambda error : self ._mock (error )
1112+ )
1113+ tmp_queue .append (conn )
1114+
1115+ while tmp_queue :
1116+ conn = tmp_queue .popleft ()
1117+ self ._free .append (conn )
1118+
1119+ async def _mock (self , error : RedisError ):
1120+ """
1121+ Dummy functions, needs to be passed as error callback to retry object.
1122+ :param error:
1123+ :return:
1124+ """
1125+ pass
1126+
10851127
10861128class NodesManager :
10871129 __slots__ = (
10881130 "_moved_exception" ,
1131+ "_event_dispatcher" ,
10891132 "connection_kwargs" ,
10901133 "default_node" ,
10911134 "nodes_cache" ,
@@ -1102,6 +1145,7 @@ def __init__(
11021145 require_full_coverage : bool ,
11031146 connection_kwargs : Dict [str , Any ],
11041147 address_remap : Optional [Callable [[Tuple [str , int ]], Tuple [str , int ]]] = None ,
1148+ event_dispatcher : Optional [EventDispatcher ] = None ,
11051149 ) -> None :
11061150 self .startup_nodes = {node .name : node for node in startup_nodes }
11071151 self .require_full_coverage = require_full_coverage
@@ -1113,6 +1157,10 @@ def __init__(
11131157 self .slots_cache : Dict [int , List ["ClusterNode" ]] = {}
11141158 self .read_load_balancer = LoadBalancer ()
11151159 self ._moved_exception : MovedError = None
1160+ if event_dispatcher is None :
1161+ self ._event_dispatcher = EventDispatcher ()
1162+ else :
1163+ self ._event_dispatcher = event_dispatcher
11161164
11171165 def get_node (
11181166 self ,
@@ -1230,6 +1278,12 @@ async def initialize(self) -> None:
12301278 try :
12311279 # Make sure cluster mode is enabled on this node
12321280 try :
1281+ self ._event_dispatcher .dispatch (
1282+ AfterAsyncClusterInstantiationEvent (
1283+ self .nodes_cache ,
1284+ self .connection_kwargs .get ("credential_provider" , None ),
1285+ )
1286+ )
12331287 cluster_slots = await startup_node .execute_command ("CLUSTER SLOTS" )
12341288 except ResponseError :
12351289 raise RedisClusterException (
0 commit comments