 from redis import asyncio as aioredis
 
-from channels.exceptions import ChannelFull
-from channels.layers import BaseChannelLayer
+from channels.exceptions import ChannelFull  # type: ignore[import-untyped]
+from channels.layers import BaseChannelLayer  # type: ignore[import-untyped]
 
 from .serializers import registry
 from .utils import (
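Note on the ignores above (a reviewer aside, not part of the diff): `import-untyped` is the mypy error code for importing an installed package that ships no type information, which is why `channels` trips it here. Scoping the suppression to that one code is deliberate; a sketch of the difference:

```python
# A bare ignore silences every error mypy finds on the line:
from channels.layers import BaseChannelLayer  # type: ignore
# A scoped ignore silences only the named code, so any unrelated
# mistake on the same line is still reported:
from channels.exceptions import ChannelFull  # type: ignore[import-untyped]
```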
@@ -37,13 +37,11 @@ class ChannelLock:
     to mitigate multi-event loop problems.
     """
 
-    def __init__(self):
-        self.locks: collections.defaultdict[str, asyncio.Lock] = (
-            collections.defaultdict(asyncio.Lock)
-        )
-        self.wait_counts: collections.defaultdict[str, int] = collections.defaultdict(
-            int
+    def __init__(self) -> None:
+        self.locks: typing.DefaultDict[str, asyncio.Lock] = collections.defaultdict(
+            asyncio.Lock
         )
+        self.wait_counts: typing.DefaultDict[str, int] = collections.defaultdict(int)
 
     async def acquire(self, channel: str) -> bool:
         """
@@ -58,7 +56,7 @@ def locked(self, channel: str) -> bool:
         """
         return self.locks[channel].locked()
 
-    def release(self, channel: str):
+    def release(self, channel: str) -> None:
         """
         Release the lock for the given channel.
         """
@@ -70,7 +68,7 @@ def release(self, channel: str):
 
 
 class BoundedQueue(asyncio.Queue):
-    def put_nowait(self, item):
+    def put_nowait(self, item: typing.Any) -> None:
         if self.full():
             # see: https://github.com/django/channels_redis/issues/212
             # if we actually get into this code block, it likely means that
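Note: per the comment and the linked issue, `BoundedQueue` drops the oldest buffered message rather than raising when full. A rough standalone sketch of that behaviour (assumed from the comment, not copied from the module):

```python
import asyncio
import typing

class DropOldestQueue(asyncio.Queue):
    def put_nowait(self, item: typing.Any) -> None:
        if self.full():
            self.get_nowait()  # evict the oldest item to make room
        super().put_nowait(item)

q = DropOldestQueue(maxsize=2)
for n in (1, 2, 3):
    q.put_nowait(n)
assert q.get_nowait() == 2  # 1 was silently dropped
```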
@@ -83,7 +81,7 @@ def put_nowait(self, item):
 
 
 class RedisLoopLayer:
-    def __init__(self, channel_layer: "RedisChannelLayer"):
+    def __init__(self, channel_layer: "RedisChannelLayer") -> None:
         self._lock = asyncio.Lock()
         self.channel_layer = channel_layer
         self._connections: typing.Dict[int, "Redis"] = {}
@@ -95,7 +93,7 @@ def get_connection(self, index: int) -> "Redis":
 
         return self._connections[index]
 
-    async def flush(self):
+    async def flush(self) -> None:
         async with self._lock:
             for index in list(self._connections):
                 connection = self._connections.pop(index)
@@ -116,15 +114,15 @@ class RedisChannelLayer(BaseChannelLayer):
     def __init__(
         self,
         hosts=None,
-        prefix="asgi",
+        prefix: str = "asgi",
         expiry=60,
-        group_expiry=86400,
+        group_expiry: int = 86400,
         capacity=100,
         channel_capacity=None,
         symmetric_encryption_keys=None,
         random_prefix_length=12,
         serializer_format="msgpack",
-    ):
+    ) -> None:
         # Store basic information
         self.expiry = expiry
         self.group_expiry = group_expiry
@@ -157,11 +155,11 @@ def __init__(
         # Event loop they are trying to receive on
         self.receive_event_loop: typing.Optional[asyncio.AbstractEventLoop] = None
         # Buffered messages by process-local channel name
-        self.receive_buffer: collections.defaultdict[str, BoundedQueue] = (
+        self.receive_buffer: typing.DefaultDict[str, BoundedQueue] = (
             collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
         )
         # Detached channel cleanup tasks
-        self.receive_cleaners: typing.List[asyncio.Task] = []
+        self.receive_cleaners: typing.List["asyncio.Task[typing.Any]"] = []
         # Per-channel cleanup locks to prevent a receive starting and moving
         # a message back into the main queue before its cleanup has completed
         self.receive_clean_locks = ChannelLock()
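Note: `defaultdict` needs a zero-argument factory, which is why the (unchanged) line above binds the capacity with `functools.partial(BoundedQueue, self.capacity)`. In miniature, with a plain `asyncio.Queue`:

```python
import asyncio
import collections
import functools
import typing

buffers: typing.DefaultDict[str, asyncio.Queue] = collections.defaultdict(
    functools.partial(asyncio.Queue, 100)  # each new key gets Queue(maxsize=100)
)
assert buffers["some.channel"].maxsize == 100
```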
@@ -173,7 +171,7 @@ def create_pool(self, index: int) -> "ConnectionPool":
 
     extensions = ["groups", "flush"]
 
-    async def send(self, channel: str, message):
+    async def send(self, channel: str, message: typing.Any) -> None:
         """
         Send a message onto a (general or specific) channel.
         """
@@ -221,7 +219,7 @@ def _backup_channel_name(self, channel: str) -> str:
 
     async def _brpop_with_clean(
         self, index: int, channel: str, timeout: typing.Union[int, float, bytes, str]
-    ):
+    ) -> typing.Any:
         """
         Perform a Redis BRPOP and manage the backup processing queue.
         In case of cancellation, make sure the message is not lost.
@@ -240,7 +238,7 @@ async def _brpop_with_clean(
         connection = self.connection(index)
         # Cancellation here doesn't matter, we're not doing anything destructive
         # and the script executes atomically...
-        await connection.eval(cleanup_script, 0, channel, backup_queue)
+        await connection.eval(cleanup_script, 0, channel, backup_queue)  # type: ignore[misc]
         # ...and it doesn't matter here either, the message will be safe in the backup.
         result = await connection.bzpopmin(channel, timeout=timeout)
 
@@ -252,15 +250,15 @@ async def _brpop_with_clean(
 
         return member
 
-    async def _clean_receive_backup(self, index: int, channel: str):
+    async def _clean_receive_backup(self, index: int, channel: str) -> None:
         """
         Pop the oldest message off the channel backup queue.
         The result isn't interesting as it was already processed.
         """
         connection = self.connection(index)
         await connection.zpopmin(self._backup_channel_name(channel))
 
-    async def receive(self, channel: str):
+    async def receive(self, channel: str) -> typing.Any:
         """
         Receive the first message that arrives on the channel.
         If more than one coroutine waits on the same channel, the first waiter
@@ -292,11 +290,11 @@ async def receive(self, channel: str):
             # Wait for our message to appear
             message = None
             while self.receive_buffer[channel].empty():
-                tasks = [
-                    self.receive_lock.acquire(),
+                _tasks = [
+                    self.receive_lock.acquire(),  # type: ignore[union-attr]
                     self.receive_buffer[channel].get(),
                 ]
-                tasks = [asyncio.ensure_future(task) for task in tasks]
+                tasks = [asyncio.ensure_future(task) for task in _tasks]
                 try:
                     done, pending = await asyncio.wait(
                         tasks, return_when=asyncio.FIRST_COMPLETED
@@ -312,7 +310,7 @@ async def receive(self, channel: str):
                         if not task.cancel():
                             assert task.done()
                             if task.result() is True:
-                                self.receive_lock.release()
+                                self.receive_lock.release()  # type: ignore[union-attr]
 
                     raise
 
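Note: the two hunks above are the heart of `receive()`'s race: wait on either the shared receive lock or the channel's buffer, take whichever finishes first, and cancel or release the loser on the way out. The `_tasks`/`tasks` rename only exists so mypy sees one list of awaitables and one list of futures. The bare pattern, stripped of the layer's bookkeeping (illustrative, not the layer's code):

```python
import asyncio
import typing

async def first_completed(*coros: typing.Awaitable) -> typing.Any:
    tasks = [asyncio.ensure_future(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()  # the loser must be cancelled or it leaks
    return next(iter(done)).result()

async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    q.put_nowait("hello")
    print(await first_completed(asyncio.sleep(10), q.get()))  # -> hello

asyncio.run(main())
```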
@@ -335,7 +333,7 @@ async def receive(self, channel: str):
                 if message or exception:
                     if token:
                         # We will not be receving as we already have the message.
-                        self.receive_lock.release()
+                        self.receive_lock.release()  # type: ignore[union-attr]
 
                     if exception:
                         raise exception
@@ -362,7 +360,7 @@ async def receive(self, channel: str):
                         del self.receive_buffer[channel]
                     raise
                 finally:
-                    self.receive_lock.release()
+                    self.receive_lock.release()  # type: ignore[union-attr]
 
             # We know there's a message available, because there
             # couldn't have been any interruption between empty() and here
@@ -377,14 +375,16 @@ async def receive(self, channel: str):
                 self.receive_count -= 1
                 # If we were the last out, drop the receive lock
                 if self.receive_count == 0:
-                    assert not self.receive_lock.locked()
+                    assert not self.receive_lock.locked()  # type: ignore[union-attr]
                     self.receive_lock = None
                     self.receive_event_loop = None
         else:
             # Do a plain direct receive
             return (await self.receive_single(channel))[1]
 
-    async def receive_single(self, channel: str) -> typing.Tuple:
+    async def receive_single(
+        self, channel: str
+    ) -> typing.Tuple[typing.Any, typing.Any]:
         """
         Receives a single message off of the channel and returns it.
         """
@@ -420,7 +420,7 @@ async def receive_single(self, channel: str) -> typing.Tuple:
             )
             self.receive_cleaners.append(cleaner)
 
-            def _cleanup_done(cleaner: asyncio.Task):
+            def _cleanup_done(cleaner: "asyncio.Task") -> None:
                 self.receive_cleaners.remove(cleaner)
                 self.receive_clean_locks.release(channel_key)
 
@@ -448,7 +448,7 @@ async def new_channel(self, prefix: str = "specific") -> str:
 
     ### Flush extension ###
 
-    async def flush(self):
+    async def flush(self) -> None:
         """
         Deletes all messages and groups on all shards.
         """
@@ -466,11 +466,11 @@ async def flush(self):
         # Go through each connection and remove all with prefix
         for i in range(self.ring_size):
             connection = self.connection(i)
-            await connection.eval(delete_prefix, 0, self.prefix + "*")
+            await connection.eval(delete_prefix, 0, self.prefix + "*")  # type: ignore[union-attr,misc]
         # Now clear the pools as well
         await self.close_pools()
 
-    async def close_pools(self):
+    async def close_pools(self) -> None:
         """
         Close all connections in the event loop pools.
         """
@@ -480,7 +480,7 @@ async def close_pools(self):
         for layer in self._layers.values():
             await layer.flush()
 
-    async def wait_received(self):
+    async def wait_received(self) -> None:
         """
         Wait for all channel cleanup functions to finish.
         """
@@ -489,13 +489,13 @@ async def wait_received(self):
 
     ### Groups extension ###
 
-    async def group_add(self, group: str, channel: str):
+    async def group_add(self, group: str, channel: str) -> None:
         """
         Adds the channel name to a group.
         """
         # Check the inputs
-        assert self.valid_group_name(group), True
-        assert self.valid_channel_name(channel), True
+        assert self.valid_group_name(group), "Group name not valid"
+        assert self.valid_channel_name(channel), "Channel name not valid"
         # Get a connection to the right shard
         group_key = self._group_key(group)
         connection = self.connection(self.consistent_hash(group))
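Note: the assert change here is a real fix, not just typing. The expression after the comma in an `assert` is the failure *message*, so the old `assert cond, True` could only ever fail with a useless `AssertionError: True`:

```python
assert 2 + 2 == 4, "arithmetic is broken"  # message is shown only on failure
# Old form:  assert False, True                   -> AssertionError: True
# New form:  assert False, "Group name not valid" -> AssertionError: Group name not valid
```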
@@ -505,7 +505,7 @@ async def group_add(self, group: str, channel: str):
         # it at this point is guaranteed to expire before that
         await connection.expire(group_key, self.group_expiry)
 
-    async def group_discard(self, group: str, channel: str):
+    async def group_discard(self, group: str, channel: str) -> None:
         """
         Removes the channel from the named group if it is in the group;
         does nothing otherwise (does not error)
@@ -516,7 +516,7 @@ async def group_discard(self, group: str, channel: str):
         connection = self.connection(self.consistent_hash(group))
         await connection.zrem(key, channel)
 
-    async def group_send(self, group: str, message):
+    async def group_send(self, group: str, message: typing.Any) -> None:
         """
         Sends a message to the entire group.
         """
@@ -540,9 +540,9 @@ async def group_send(self, group: str, message):
         for connection_index, channel_redis_keys in connection_to_channel_keys.items():
             # Discard old messages based on expiry
             pipe = connection.pipeline()
-            for key in channel_redis_keys:
+            for _key in channel_redis_keys:
                 pipe.zremrangebyscore(
-                    key, min=0, max=int(time.time()) - int(self.expiry)
+                    _key, min=0, max=int(time.time()) - int(self.expiry)
                 )
             await pipe.execute()
 
@@ -582,10 +582,10 @@ async def group_send(self, group: str, message):
 
             # channel_keys does not contain a single redis key more than once
             connection = self.connection(connection_index)
-            channels_over_capacity = await connection.eval(
+            channels_over_capacity = await connection.eval(  # type: ignore[misc]
                 group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
             )
-            _channels_over_capacity = -1
+            _channels_over_capacity = -1.0
             try:
                 _channels_over_capacity = float(channels_over_capacity)
             except Exception:
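Note: the `-1` to `-1.0` tweak is for mypy, which infers a variable's type from its first assignment; an `int` sentinel later reassigned from `float(...)` would be a type error. In miniature:

```python
x = -1          # inferred as int
x = float("3")  # mypy error: assigning float to int
y = -1.0        # inferred as float
y = float("3")  # fine
```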
@@ -598,7 +598,13 @@ async def group_send(self, group: str, message):
                 group,
             )
 
-    def _map_channel_keys_to_connection(self, channel_names, message):
+    def _map_channel_keys_to_connection(
+        self, channel_names: typing.Iterable[str], message: typing.Any
+    ) -> typing.Tuple[
+        typing.Dict[int, typing.List[str]],
+        typing.Dict[str, typing.Any],
+        typing.Dict[str, int],
+    ]:
         """
         For a list of channel names, GET
@@ -611,19 +617,21 @@ def _map_channel_keys_to_connection(self, channel_names, message):
         """
 
         # Connection dict keyed by index to list of redis keys mapped on that index
-        connection_to_channel_keys = collections.defaultdict(list)
+        connection_to_channel_keys: typing.Dict[int, typing.List[str]] = (
+            collections.defaultdict(list)
+        )
         # Message dict maps redis key to the message that needs to be send on that key
-        channel_key_to_message = dict()
+        channel_key_to_message: typing.Dict[str, typing.Any] = dict()
         # Channel key mapped to its capacity
-        channel_key_to_capacity = dict()
+        channel_key_to_capacity: typing.Dict[str, int] = dict()
 
         # For each channel
         for channel in channel_names:
             channel_non_local_name = channel
             if "!" in channel:
                 channel_non_local_name = self.non_local_name(channel)
             # Get its redis key
-            channel_key = self.prefix + channel_non_local_name
+            channel_key: str = self.prefix + channel_non_local_name
             # Have we come across the same redis key?
             if channel_key not in channel_key_to_message:
                 # If not, fill the corresponding dicts
@@ -654,13 +662,15 @@ def _group_key(self, group: str) -> bytes:
         """
         return f"{self.prefix}:group:{group}".encode("utf8")
 
-    def serialize(self, message) -> bytes:
+    ### Serialization ###
+
+    def serialize(self, message: typing.Any) -> bytes:
         """
         Serializes message to a byte string.
         """
         return self._serializer.serialize(message)
 
-    def deserialize(self, message: bytes):
+    def deserialize(self, message: bytes) -> typing.Any:
         """
         Deserializes from a byte string.
         """
@@ -671,7 +681,7 @@ def deserialize(self, message: bytes):
     def consistent_hash(self, value: typing.Union[str, "Buffer"]) -> int:
         return _consistent_hash(value, self.ring_size)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"{self.__class__.__name__}(hosts={self.hosts})"
 
     ### Connection handling ###
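Note: `consistent_hash` simply defers to `_consistent_hash` from `.utils`; its contract is to map any channel or group name deterministically onto one of `ring_size` shards, so the same name always lands on the same Redis host. A naive stand-in showing the contract, not the real implementation:

```python
import binascii

def shard_for(value: str, ring_size: int) -> int:
    """Deterministically map a name onto [0, ring_size)."""
    return binascii.crc32(value.encode("utf8")) % ring_size

assert shard_for("asgi:chat", 4) == shard_for("asgi:chat", 4)  # stable
```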