@@ -43,14 +43,15 @@ def _wrapper(self, *args, **kwargs):
43
43
class ConnectionPool :
44
44
"""
45
45
Connection pool manager for the channel layer.
46
-
47
46
It manages a set of connections for the given host specification and
48
47
taking into account asyncio event loops.
49
48
"""
50
49
51
50
def __init__ (self , host ):
52
- self .host = host
51
+ self .host = host .copy ()
52
+ self .master_name = self .host .pop ("master_name" , None )
53
53
self .conn_map = {}
54
+ self .sentinel_map = {}
54
55
self .in_use = {}
55
56
56
57
def _ensure_loop (self , loop ):
@@ -68,16 +69,27 @@ def _ensure_loop(self, loop):
68
69
69
70
return self .conn_map [loop ], loop
70
71
72
+ async def create_conn (self , loop ):
73
+ # One connection per pool since we are emulating a single connection
74
+ kwargs = {"minsize" : 1 , "maxsize" : 1 , ** self .host }
75
+ if not (sys .version_info >= (3 , 8 , 0 ) and AIOREDIS_VERSION >= (1 , 3 , 1 )):
76
+ kwargs ["loop" ] = loop
77
+ if self .master_name is None :
78
+ return await aioredis .create_redis_pool (** kwargs )
79
+ else :
80
+ kwargs = {"timeout" : 2 , ** kwargs } # aioredis default is way too low
81
+ sentinel = await aioredis .sentinel .create_sentinel (** kwargs )
82
+ conn = sentinel .master_for (self .master_name )
83
+ self .sentinel_map [conn ] = sentinel
84
+ return conn
85
+
71
86
async def pop (self , loop = None ):
72
87
"""
73
88
Get a connection for the given identifier and loop.
74
89
"""
75
90
conns , loop = self ._ensure_loop (loop )
76
91
if not conns :
77
- if sys .version_info >= (3 , 8 , 0 ) and AIOREDIS_VERSION >= (1 , 3 , 1 ):
78
- conn = await aioredis .create_redis (** self .host )
79
- else :
80
- conn = await aioredis .create_redis (** self .host , loop = loop )
92
+ conn = await self .create_conn (loop )
81
93
conns .append (conn )
82
94
conn = conns .pop ()
83
95
if conn .closed :
@@ -96,48 +108,58 @@ def push(self, conn):
96
108
conns , _ = self ._ensure_loop (loop )
97
109
conns .append (conn )
98
110
99
- def conn_error (self , conn ):
111
+ async def conn_error (self , conn ):
100
112
"""
101
113
Handle a connection that produced an error.
102
114
"""
103
- conn . close ( )
115
+ await self . _close_conn ( conn )
104
116
del self .in_use [conn ]
105
117
106
118
def reset (self ):
107
119
"""
108
120
Clear all connections from the pool.
109
121
"""
110
122
self .conn_map = {}
123
+ self .sentinel_map = {}
111
124
self .in_use = {}
112
125
126
+ async def _close_conn (self , conn , sentinel_map = None ):
127
+ if sentinel_map is None :
128
+ sentinel_map = self .sentinel_map
129
+ if conn in sentinel_map :
130
+ sentinel_map [conn ].close ()
131
+ await sentinel_map [conn ].wait_closed ()
132
+ del sentinel_map [conn ]
133
+ conn .close ()
134
+ await conn .wait_closed ()
135
+
113
136
async def close_loop (self , loop ):
114
137
"""
115
138
Close all connections owned by the pool on the given loop.
116
139
"""
117
140
if loop in self .conn_map :
118
141
for conn in self .conn_map [loop ]:
119
- conn .close ()
120
- await conn .wait_closed ()
142
+ await self ._close_conn (conn )
121
143
del self .conn_map [loop ]
122
144
123
145
for k , v in self .in_use .items ():
124
146
if v is loop :
147
+ await self ._close_conn (k )
125
148
self .in_use [k ] = None
126
149
127
150
async def close (self ):
128
151
"""
129
152
Close all connections owned by the pool.
130
153
"""
131
154
conn_map = self .conn_map
155
+ sentinel_map = self .sentinel_map
132
156
in_use = self .in_use
133
157
self .reset ()
134
158
for conns in conn_map .values ():
135
159
for conn in conns :
136
- conn .close ()
137
- await conn .wait_closed ()
160
+ await self ._close_conn (conn , sentinel_map )
138
161
for conn in in_use :
139
- conn .close ()
140
- await conn .wait_closed ()
162
+ await self ._close_conn (conn , sentinel_map )
141
163
142
164
143
165
class ChannelLock :
@@ -262,6 +284,7 @@ def decode_hosts(self, hosts):
262
284
raise ValueError (
263
285
"You must pass a list of Redis hosts, even if there is only one."
264
286
)
287
+
265
288
# Decode each hosts entry into a kwargs dict
266
289
result = []
267
290
for entry in hosts :
@@ -888,7 +911,7 @@ async def __aenter__(self):
888
911
889
912
async def __aexit__ (self , exc_type , exc , tb ):
890
913
if exc :
891
- self .pool .conn_error (self .conn )
914
+ await self .pool .conn_error (self .conn )
892
915
else :
893
916
self .pool .push (self .conn )
894
917
self .conn = None
0 commit comments