     create_pool,
     decode_hosts,
 )
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Optional
+
+if TYPE_CHECKING:
+    from redis.asyncio.connection import ConnectionPool
+    from redis.asyncio.client import Redis
+    from .core import RedisChannelLayer
+    from _typeshed import ReadableBuffer
+    from redis.typing import TimeoutSecT
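Note: everything in this added block is needed only by the type checker. Guarding the imports behind TYPE_CHECKING keeps them out of the runtime import graph, which matters here because importing .core at runtime would be circular, and _typeshed does not exist at runtime at all. A minimal sketch of the pattern (the describe function is invented for illustration):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Resolved only during type checking, never at runtime.
        from .core import RedisChannelLayer

    def describe(layer: "RedisChannelLayer") -> str:
        # The quoted annotation is evaluated lazily, so no runtime import is needed.
        return str(layer)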
 
 logger = logging.getLogger(__name__)
 
@@ -32,23 +40,27 @@ class ChannelLock:
     """
 
     def __init__(self):
-        self.locks = collections.defaultdict(asyncio.Lock)
-        self.wait_counts = collections.defaultdict(int)
+        self.locks: "collections.defaultdict[str, asyncio.Lock]" = (
+            collections.defaultdict(asyncio.Lock)
+        )
+        self.wait_counts: "collections.defaultdict[str, int]" = collections.defaultdict(
+            int
+        )
 
-    async def acquire(self, channel):
+    async def acquire(self, channel: str) -> bool:
         """
         Acquire the lock for the given channel.
         """
         self.wait_counts[channel] += 1
         return await self.locks[channel].acquire()
 
-    def locked(self, channel):
+    def locked(self, channel: str) -> bool:
         """
         Return ``True`` if the lock for the given channel is acquired.
         """
         return self.locks[channel].locked()
 
-    def release(self, channel):
+    def release(self, channel: str):
         """
         Release the lock for the given channel.
         """
@@ -73,12 +85,12 @@ def put_nowait(self, item):
 
 
 class RedisLoopLayer:
-    def __init__(self, channel_layer):
+    def __init__(self, channel_layer: "RedisChannelLayer"):
         self._lock = asyncio.Lock()
         self.channel_layer = channel_layer
-        self._connections = {}
+        self._connections: "Dict[int, Redis]" = {}
 
-    def get_connection(self, index):
+    def get_connection(self, index: int) -> "Redis":
         if index not in self._connections:
             pool = self.channel_layer.create_pool(index)
             self._connections[index] = aioredis.Redis(connection_pool=pool)
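The comment later in this diff ("Cached redis connection pools and the event loop they are from") describes how the pieces compose: the channel layer keeps one RedisLoopLayer per event loop, and each RedisLoopLayer caches one client per host index. A rough sketch of that lookup, assuming the _layers mapping annotated below:

    # Sketch: pick the loop-local cache first, then the per-index client.
    loop = asyncio.get_running_loop()
    if loop not in self._layers:
        self._layers[loop] = RedisLoopLayer(self)
    connection = self._layers[loop].get_connection(index)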
@@ -134,7 +146,7 @@ def __init__(
             symmetric_encryption_keys=symmetric_encryption_keys,
         )
         # Cached redis connection pools and the event loop they are from
-        self._layers = {}
+        self._layers: "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
         # Normal channels choose a host index by cycling through the available hosts
         self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
         self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -143,33 +155,33 @@ def __init__(
         # Number of coroutines trying to receive right now
         self.receive_count = 0
         # The receive lock
-        self.receive_lock = None
+        self.receive_lock: "Optional[asyncio.Lock]" = None
         # Event loop they are trying to receive on
-        self.receive_event_loop = None
+        self.receive_event_loop: "Optional[asyncio.AbstractEventLoop]" = None
         # Buffered messages by process-local channel name
-        self.receive_buffer = collections.defaultdict(
-            functools.partial(BoundedQueue, self.capacity)
+        self.receive_buffer: "collections.defaultdict[str, BoundedQueue]" = (
+            collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
         )
         # Detached channel cleanup tasks
-        self.receive_cleaners = []
+        self.receive_cleaners: "List[asyncio.Task]" = []
         # Per-channel cleanup locks to prevent a receive starting and moving
         # a message back into the main queue before its cleanup has completed
         self.receive_clean_locks = ChannelLock()
 
-    def create_pool(self, index):
+    def create_pool(self, index: int) -> "ConnectionPool":
         return create_pool(self.hosts[index])
 
     ### Channel layer API ###
 
     extensions = ["groups", "flush"]
 
-    async def send(self, channel, message):
+    async def send(self, channel: str, message):
         """
         Send a message onto a (general or specific) channel.
         """
         # Typecheck
         assert isinstance(message, dict), "message is not a dict"
-        assert self.valid_channel_name(channel), "Channel name not valid"
+        assert self.require_valid_channel_name(channel), "Channel name not valid"
         # Make sure the message does not contain reserved keys
         assert "__asgi_channel__" not in message
         # If it's a process-local channel, strip off local part and stick full name in message
@@ -203,13 +215,13 @@ async def send(self, channel, message):
         await connection.zadd(channel_key, {self.serialize(message): time.time()})
         await connection.expire(channel_key, int(self.expiry))
 
-    def _backup_channel_name(self, channel):
+    def _backup_channel_name(self, channel: str) -> str:
         """
         Construct the key used as a backup queue for the given channel.
         """
         return channel + "$inflight"
 
-    async def _brpop_with_clean(self, index, channel, timeout):
+    async def _brpop_with_clean(self, index: int, channel: str, timeout: "TimeoutSecT"):
         """
         Perform a Redis BRPOP and manage the backup processing queue.
         In case of cancellation, make sure the message is not lost.
@@ -240,23 +252,23 @@ async def _brpop_with_clean(self, index, channel, timeout):
 
         return member
 
-    async def _clean_receive_backup(self, index, channel):
+    async def _clean_receive_backup(self, index: int, channel: str):
         """
         Pop the oldest message off the channel backup queue.
         The result isn't interesting as it was already processed.
         """
         connection = self.connection(index)
         await connection.zpopmin(self._backup_channel_name(channel))
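Together, _brpop_with_clean and _clean_receive_backup form a take-then-acknowledge pattern: a popped message is parked in the "$inflight" sorted set until it has been handled, so a cancelled receive cannot lose it. A simplified sketch of the idea only; the actual Redis calls in the method above differ and handle the take-and-park step safely:

    backup = self._backup_channel_name(channel)
    member = await connection.brpop(channel, timeout)        # take
    if member is not None:
        # Park the payload in-flight, scored by time for oldest-first cleanup.
        await connection.zadd(backup, {member[1]: time.time()})
        ...                                                  # process the message
        await connection.zpopmin(backup)                     # acknowledge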
 
-    async def receive(self, channel):
+    async def receive(self, channel: str):
         """
         Receive the first message that arrives on the channel.
         If more than one coroutine waits on the same channel, the first waiter
         will be given the message when it arrives.
         """
         # Make sure the channel name is valid then get the non-local part
         # and thus its index
-        assert self.valid_channel_name(channel)
+        assert self.require_valid_channel_name(channel)
         if "!" in channel:
             real_channel = self.non_local_name(channel)
             assert real_channel.endswith(
@@ -372,12 +384,14 @@ async def receive(self, channel):
             # Do a plain direct receive
             return (await self.receive_single(channel))[1]
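For orientation, a usage sketch of the send/receive API being annotated here; layer is an instance of this channel layer and the channel name is illustrative:

    await layer.send("chat", {"type": "chat.message", "text": "hi"})
    message = await layer.receive("chat")
    assert message["type"] == "chat.message"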
 
-    async def receive_single(self, channel):
+    async def receive_single(self, channel: str) -> "Tuple":
         """
         Receives a single message off of the channel and returns it.
         """
         # Check channel name
-        assert self.valid_channel_name(channel, receive=True), "Channel name invalid"
+        assert self.require_valid_channel_name(
+            channel, receive=True
+        ), "Channel name invalid"
         # Work out the connection to use
         if "!" in channel:
             assert channel.endswith("!")
@@ -408,7 +422,7 @@ async def receive_single(self, channel):
         )
         self.receive_cleaners.append(cleaner)
 
-        def _cleanup_done(cleaner):
+        def _cleanup_done(cleaner: "asyncio.Task"):
             self.receive_cleaners.remove(cleaner)
             self.receive_clean_locks.release(channel_key)
 
@@ -427,7 +441,7 @@ def _cleanup_done(cleaner):
             del message["__asgi_channel__"]
         return channel, message
 
-    async def new_channel(self, prefix="specific"):
+    async def new_channel(self, prefix: str = "specific") -> str:
         """
         Returns a new channel name that can be used by something in our
         process as a specific channel.
@@ -477,13 +491,13 @@ async def wait_received(self):
 
     ### Groups extension ###
 
-    async def group_add(self, group, channel):
+    async def group_add(self, group: str, channel: str):
         """
         Adds the channel name to a group.
         """
         # Check the inputs
-        assert self.valid_group_name(group), "Group name not valid"
-        assert self.valid_channel_name(channel), "Channel name not valid"
+        assert self.require_valid_group_name(group), "Group name not valid"
+        assert self.require_valid_channel_name(channel), "Channel name not valid"
         # Get a connection to the right shard
         group_key = self._group_key(group)
         connection = self.connection(self.consistent_hash(group))
@@ -493,22 +507,22 @@ async def group_add(self, group, channel):
         # it at this point is guaranteed to expire before that
         await connection.expire(group_key, self.group_expiry)
 
-    async def group_discard(self, group, channel):
+    async def group_discard(self, group: str, channel: str):
         """
         Removes the channel from the named group if it is in the group;
         does nothing otherwise (does not error)
         """
-        assert self.valid_group_name(group), "Group name not valid"
-        assert self.valid_channel_name(channel), "Channel name not valid"
+        assert self.require_valid_group_name(group), "Group name not valid"
+        assert self.require_valid_channel_name(channel), "Channel name not valid"
         key = self._group_key(group)
         connection = self.connection(self.consistent_hash(group))
         await connection.zrem(key, channel)
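A usage sketch for the groups extension as annotated above; the group name is illustrative:

    channel = await layer.new_channel()
    await layer.group_add("broadcast", channel)
    await layer.group_send("broadcast", {"type": "notify"})
    # Every member channel, including ours, now has the message queued.
    await layer.group_discard("broadcast", channel)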
 
-    async def group_send(self, group, message):
+    async def group_send(self, group: str, message):
         """
         Sends a message to the entire group.
         """
-        assert self.valid_group_name(group), "Group name not valid"
+        assert self.require_valid_group_name(group), "Group name not valid"
         # Retrieve list of all channel names
         key = self._group_key(group)
         connection = self.connection(self.consistent_hash(group))
@@ -573,7 +587,12 @@ async def group_send(self, group, message):
         channels_over_capacity = await connection.eval(
             group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
         )
-        if channels_over_capacity > 0:
+        _channels_over_capacity = -1
+        try:
+            _channels_over_capacity = float(channels_over_capacity)
+        except Exception:
+            pass
+        if _channels_over_capacity > 0:
             logger.info(
                 "%s of %s channels over capacity in group %s",
                 channels_over_capacity,
@@ -631,37 +650,35 @@ def _map_channel_keys_to_connection(self, channel_names, message):
             channel_key_to_capacity,
         )
 
-    def _group_key(self, group):
+    def _group_key(self, group: str) -> bytes:
         """
         Common function to make the storage key for the group.
         """
         return f"{self.prefix}:group:{group}".encode("utf8")
 
-    ### Serialization ###
-
-    def serialize(self, message):
+    def serialize(self, message) -> bytes:
         """
         Serializes message to a byte string.
         """
         return self._serializer.serialize(message)
 
-    def deserialize(self, message):
+    def deserialize(self, message: bytes):
         """
         Deserializes from a byte string.
         """
         return self._serializer.deserialize(message)
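The new annotations pin down the serializer contract: serialize produces bytes and deserialize consumes them. A round-trip sketch:

    payload = layer.serialize({"type": "ping"})      # -> bytes
    assert isinstance(payload, bytes)
    assert layer.deserialize(payload) == {"type": "ping"}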
 
     ### Internal functions ###
 
-    def consistent_hash(self, value):
+    def consistent_hash(self, value: "Union[str, ReadableBuffer]") -> int:
         return _consistent_hash(value, self.ring_size)
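consistent_hash maps a str or bytes-like value onto a host index in range(ring_size), so a given group or channel always lands on the same shard. One way such a helper can be written, as a CRC32-based sketch (the project's actual _consistent_hash may differ):

    import binascii

    def _consistent_hash(value, ring_size: int) -> int:
        if ring_size == 1:
            return 0
        if isinstance(value, str):
            value = value.encode("utf8")
        # Scale a 12-bit CRC onto [0, ring_size).
        bigval = binascii.crc32(value) & 0xFFF
        return int(bigval / (4096 / float(ring_size)))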
 
     def __str__(self):
         return f"{self.__class__.__name__}(hosts={self.hosts})"
 
     ### Connection handling ###
 
-    def connection(self, index):
+    def connection(self, index: int) -> "Redis":
         """
         Returns the correct connection for the index given.
         Lazily instantiates pools.