    create_pool,
    decode_hosts,
)
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Optional
+
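+# These names are needed only by static type checkers; the TYPE_CHECKING
+# guard keeps them out of the runtime import graph.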
+if TYPE_CHECKING:
+    from redis.asyncio.connection import ConnectionPool
+    from redis.asyncio.client import Redis
+    from .core import RedisChannelLayer
+    from typing_extensions import Buffer

logger = logging.getLogger(__name__)

@@ -32,23 +39,27 @@ class ChannelLock:
    """

    def __init__(self):
-        self.locks = collections.defaultdict(asyncio.Lock)
-        self.wait_counts = collections.defaultdict(int)
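+        # The annotations are quoted: subscripting collections.defaultdict
+        # directly at runtime requires Python 3.9+ (PEP 585).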
+        self.locks: "collections.defaultdict[str, asyncio.Lock]" = (
+            collections.defaultdict(asyncio.Lock)
+        )
+        self.wait_counts: "collections.defaultdict[str, int]" = collections.defaultdict(
+            int
+        )

-    async def acquire(self, channel):
+    async def acquire(self, channel: str) -> bool:
        """
        Acquire the lock for the given channel.
        """
        self.wait_counts[channel] += 1
        return await self.locks[channel].acquire()

-    def locked(self, channel):
+    def locked(self, channel: str) -> bool:
        """
        Return ``True`` if the lock for the given channel is acquired.
        """
        return self.locks[channel].locked()

-    def release(self, channel):
+    def release(self, channel: str):
        """
        Release the lock for the given channel.
        """
@@ -73,12 +84,12 @@ def put_nowait(self, item):


class RedisLoopLayer:
-    def __init__(self, channel_layer):
+    def __init__(self, channel_layer: "RedisChannelLayer"):
        self._lock = asyncio.Lock()
        self.channel_layer = channel_layer
-        self._connections = {}
+        self._connections: "Dict[int, Redis]" = {}

-    def get_connection(self, index):
+    def get_connection(self, index: int) -> "Redis":
        if index not in self._connections:
            pool = self.channel_layer.create_pool(index)
            self._connections[index] = aioredis.Redis(connection_pool=pool)
@@ -134,7 +145,7 @@ def __init__(
            symmetric_encryption_keys=symmetric_encryption_keys,
        )
        # Cached redis connection pools and the event loop they are from
-        self._layers = {}
+        self._layers: "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
        # Normal channels choose a host index by cycling through the available hosts
        self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
        self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -143,33 +154,33 @@ def __init__(
        # Number of coroutines trying to receive right now
        self.receive_count = 0
        # The receive lock
-        self.receive_lock = None
+        self.receive_lock: "Optional[asyncio.Lock]" = None
        # Event loop they are trying to receive on
-        self.receive_event_loop = None
+        self.receive_event_loop: "Optional[asyncio.AbstractEventLoop]" = None
        # Buffered messages by process-local channel name
-        self.receive_buffer = collections.defaultdict(
-            functools.partial(BoundedQueue, self.capacity)
+        self.receive_buffer: "collections.defaultdict[str, BoundedQueue]" = (
+            collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
        )
        # Detached channel cleanup tasks
-        self.receive_cleaners = []
+        self.receive_cleaners: "List[asyncio.Task]" = []
        # Per-channel cleanup locks to prevent a receive starting and moving
        # a message back into the main queue before its cleanup has completed
        self.receive_clean_locks = ChannelLock()

-    def create_pool(self, index):
+    def create_pool(self, index: int) -> "ConnectionPool":
        return create_pool(self.hosts[index])

    ### Channel layer API ###

    extensions = ["groups", "flush"]

-    async def send(self, channel, message):
+    async def send(self, channel: str, message):
        """
        Send a message onto a (general or specific) channel.
        """
        # Typecheck
        assert isinstance(message, dict), "message is not a dict"
-        assert self.valid_channel_name(channel), "Channel name not valid"
+        assert self.require_valid_channel_name(channel), "Channel name not valid"
        # Make sure the message does not contain reserved keys
        assert "__asgi_channel__" not in message
        # If it's a process-local channel, strip off local part and stick full name in message
@@ -203,13 +214,15 @@ async def send(self, channel, message):
        await connection.zadd(channel_key, {self.serialize(message): time.time()})
        await connection.expire(channel_key, int(self.expiry))

-    def _backup_channel_name(self, channel):
+    def _backup_channel_name(self, channel: str) -> str:
        """
        Construct the key used as a backup queue for the given channel.
        """
        return channel + "$inflight"

-    async def _brpop_with_clean(self, index, channel, timeout):
+    async def _brpop_with_clean(
+        self, index: int, channel: str, timeout: "Union[int, float, bytes, str]"
+    ):
        """
        Perform a Redis BRPOP and manage the backup processing queue.
        In case of cancellation, make sure the message is not lost.
@@ -240,23 +253,23 @@ async def _brpop_with_clean(self, index, channel, timeout):

        return member

-    async def _clean_receive_backup(self, index, channel):
+    async def _clean_receive_backup(self, index: int, channel: str):
        """
        Pop the oldest message off the channel backup queue.
        The result isn't interesting as it was already processed.
        """
        connection = self.connection(index)
        await connection.zpopmin(self._backup_channel_name(channel))

-    async def receive(self, channel):
+    async def receive(self, channel: str):
        """
        Receive the first message that arrives on the channel.
        If more than one coroutine waits on the same channel, the first waiter
        will be given the message when it arrives.
        """
        # Make sure the channel name is valid then get the non-local part
        # and thus its index
-        assert self.valid_channel_name(channel)
+        assert self.require_valid_channel_name(channel)
        if "!" in channel:
            real_channel = self.non_local_name(channel)
            assert real_channel.endswith(
@@ -372,12 +385,14 @@ async def receive(self, channel):
            # Do a plain direct receive
            return (await self.receive_single(channel))[1]

-    async def receive_single(self, channel):
+    async def receive_single(self, channel: str) -> "Tuple":
        """
        Receives a single message off of the channel and returns it.
        """
        # Check channel name
-        assert self.valid_channel_name(channel, receive=True), "Channel name invalid"
+        assert self.require_valid_channel_name(
+            channel, receive=True
+        ), "Channel name invalid"
        # Work out the connection to use
        if "!" in channel:
            assert channel.endswith("!")
@@ -408,7 +423,7 @@ async def receive_single(self, channel):
            )
            self.receive_cleaners.append(cleaner)

-            def _cleanup_done(cleaner):
+            def _cleanup_done(cleaner: "asyncio.Task"):
                self.receive_cleaners.remove(cleaner)
                self.receive_clean_locks.release(channel_key)

@@ -427,7 +442,7 @@ def _cleanup_done(cleaner):
            del message["__asgi_channel__"]
        return channel, message

-    async def new_channel(self, prefix="specific"):
+    async def new_channel(self, prefix: str = "specific") -> str:
        """
        Returns a new channel name that can be used by something in our
        process as a specific channel.
@@ -477,13 +492,13 @@ async def wait_received(self):

    ### Groups extension ###

-    async def group_add(self, group, channel):
+    async def group_add(self, group: str, channel: str):
        """
        Adds the channel name to a group.
        """
        # Check the inputs
-        assert self.valid_group_name(group), "Group name not valid"
-        assert self.valid_channel_name(channel), "Channel name not valid"
+        assert self.require_valid_group_name(group), "Group name not valid"
+        assert self.require_valid_channel_name(channel), "Channel name not valid"
        # Get a connection to the right shard
        group_key = self._group_key(group)
        connection = self.connection(self.consistent_hash(group))
@@ -493,22 +508,22 @@ async def group_add(self, group, channel):
        # it at this point is guaranteed to expire before that
        await connection.expire(group_key, self.group_expiry)

-    async def group_discard(self, group, channel):
+    async def group_discard(self, group: str, channel: str):
        """
        Removes the channel from the named group if it is in the group;
        does nothing otherwise (does not error)
        """
-        assert self.valid_group_name(group), "Group name not valid"
-        assert self.valid_channel_name(channel), "Channel name not valid"
+        assert self.require_valid_group_name(group), "Group name not valid"
+        assert self.require_valid_channel_name(channel), "Channel name not valid"
        key = self._group_key(group)
        connection = self.connection(self.consistent_hash(group))
        await connection.zrem(key, channel)

-    async def group_send(self, group, message):
+    async def group_send(self, group: str, message):
        """
        Sends a message to the entire group.
        """
-        assert self.valid_group_name(group), "Group name not valid"
+        assert self.require_valid_group_name(group), "Group name not valid"
        # Retrieve list of all channel names
        key = self._group_key(group)
        connection = self.connection(self.consistent_hash(group))
@@ -573,7 +588,12 @@ async def group_send(self, group, message):
        channels_over_capacity = await connection.eval(
            group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
        )
-        if channels_over_capacity > 0:
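+        # eval() may return the Lua number as a non-int type depending on the
+        # Redis client and response parsing, so coerce before comparing.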
+        _channels_over_capacity = -1
+        try:
+            _channels_over_capacity = float(channels_over_capacity)
+        except Exception:
+            pass
+        if _channels_over_capacity > 0:
            logger.info(
                "%s of %s channels over capacity in group %s",
                channels_over_capacity,
@@ -631,37 +651,35 @@ def _map_channel_keys_to_connection(self, channel_names, message):
            channel_key_to_capacity,
        )

-    def _group_key(self, group):
+    def _group_key(self, group: str) -> bytes:
        """
        Common function to make the storage key for the group.
        """
        return f"{self.prefix}:group:{group}".encode("utf8")

-    ### Serialization ###
-
-    def serialize(self, message):
+    def serialize(self, message) -> bytes:
        """
        Serializes message to a byte string.
        """
        return self._serializer.serialize(message)

-    def deserialize(self, message):
+    def deserialize(self, message: bytes):
        """
        Deserializes from a byte string.
        """
        return self._serializer.deserialize(message)

    ### Internal functions ###

-    def consistent_hash(self, value):
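+    # "Buffer" (PEP 688) matches bytes-like values; it is imported from
+    # typing_extensions under the TYPE_CHECKING guard above.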
+    def consistent_hash(self, value: "Union[str, Buffer]") -> int:
        return _consistent_hash(value, self.ring_size)

    def __str__(self):
        return f"{self.__class__.__name__}(hosts={self.hosts})"

    ### Connection handling ###

-    def connection(self, index):
+    def connection(self, index: int) -> "Redis":
        """
        Returns the correct connection for the index given.
        Lazily instantiates pools.
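A minimal usage sketch (not part of the diff) showing the public API these annotations describe; the host address is illustrative:

    import asyncio
    from channels_redis.core import RedisChannelLayer

    async def main() -> None:
        # Hypothetical local Redis instance; any reachable host works.
        layer = RedisChannelLayer(hosts=[("localhost", 6379)])
        channel: str = await layer.new_channel()  # annotated to return str
        await layer.send(channel, {"type": "hello"})
        message = await layer.receive(channel)
        assert message["type"] == "hello"

    asyncio.run(main())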