@@ -37,7 +37,7 @@ class ChannelLock:
37
37
to mitigate multi-event loop problems.
38
38
"""
39
39
40
- def __init__ (self ):
40
+ def __init__ (self ) -> None :
41
41
self .locks : collections .defaultdict [str , asyncio .Lock ] = (
42
42
collections .defaultdict (asyncio .Lock )
43
43
)
@@ -58,7 +58,7 @@ def locked(self, channel: str) -> bool:
58
58
"""
59
59
return self .locks [channel ].locked ()
60
60
61
- def release (self , channel : str ):
61
+ def release (self , channel : str ) -> None :
62
62
"""
63
63
Release the lock for the given channel.
64
64
"""
@@ -69,8 +69,8 @@ def release(self, channel: str):
69
69
del self .wait_counts [channel ]
70
70
71
71
72
- class BoundedQueue (asyncio .Queue ):
73
- def put_nowait (self , item ) :
72
+ class BoundedQueue (asyncio .Queue [ typing . Any ] ):
73
+ def put_nowait (self , item : typing . Any ) -> None :
74
74
if self .full ():
75
75
# see: https://github.com/django/channels_redis/issues/212
76
76
# if we actually get into this code block, it likely means that
@@ -83,7 +83,7 @@ def put_nowait(self, item):
83
83
84
84
85
85
class RedisLoopLayer :
86
- def __init__ (self , channel_layer : "RedisChannelLayer" ):
86
+ def __init__ (self , channel_layer : "RedisChannelLayer" ) -> None :
87
87
self ._lock = asyncio .Lock ()
88
88
self .channel_layer = channel_layer
89
89
self ._connections : typing .Dict [int , "Redis" ] = {}
@@ -95,7 +95,7 @@ def get_connection(self, index: int) -> "Redis":
95
95
96
96
return self ._connections [index ]
97
97
98
- async def flush (self ):
98
+ async def flush (self ) -> None :
99
99
async with self ._lock :
100
100
for index in list (self ._connections ):
101
101
connection = self ._connections .pop (index )
@@ -116,15 +116,15 @@ class RedisChannelLayer(BaseChannelLayer):
116
116
def __init__ (
117
117
self ,
118
118
hosts = None ,
119
- prefix = "asgi" ,
119
+ prefix : str = "asgi" ,
120
120
expiry = 60 ,
121
- group_expiry = 86400 ,
121
+ group_expiry : int = 86400 ,
122
122
capacity = 100 ,
123
123
channel_capacity = None ,
124
124
symmetric_encryption_keys = None ,
125
125
random_prefix_length = 12 ,
126
126
serializer_format = "msgpack" ,
127
- ):
127
+ ) -> None :
128
128
# Store basic information
129
129
self .expiry = expiry
130
130
self .group_expiry = group_expiry
@@ -161,7 +161,7 @@ def __init__(
161
161
collections .defaultdict (functools .partial (BoundedQueue , self .capacity ))
162
162
)
163
163
# Detached channel cleanup tasks
164
- self .receive_cleaners : typing .List [asyncio .Task ] = []
164
+ self .receive_cleaners : typing .List [asyncio .Task [ typing . Any ] ] = []
165
165
# Per-channel cleanup locks to prevent a receive starting and moving
166
166
# a message back into the main queue before its cleanup has completed
167
167
self .receive_clean_locks = ChannelLock ()
@@ -173,7 +173,7 @@ def create_pool(self, index: int) -> "ConnectionPool":
173
173
174
174
extensions = ["groups" , "flush" ]
175
175
176
- async def send (self , channel : str , message ) :
176
+ async def send (self , channel : str , message : typing . Any ) -> None :
177
177
"""
178
178
Send a message onto a (general or specific) channel.
179
179
"""
@@ -221,7 +221,7 @@ def _backup_channel_name(self, channel: str) -> str:
221
221
222
222
async def _brpop_with_clean (
223
223
self , index : int , channel : str , timeout : typing .Union [int , float , bytes , str ]
224
- ):
224
+ ) -> typing . Any :
225
225
"""
226
226
Perform a Redis BRPOP and manage the backup processing queue.
227
227
In case of cancellation, make sure the message is not lost.
@@ -252,15 +252,15 @@ async def _brpop_with_clean(
252
252
253
253
return member
254
254
255
- async def _clean_receive_backup (self , index : int , channel : str ):
255
+ async def _clean_receive_backup (self , index : int , channel : str ) -> None :
256
256
"""
257
257
Pop the oldest message off the channel backup queue.
258
258
The result isn't interesting as it was already processed.
259
259
"""
260
260
connection = self .connection (index )
261
261
await connection .zpopmin (self ._backup_channel_name (channel ))
262
262
263
- async def receive (self , channel : str ):
263
+ async def receive (self , channel : str ) -> typing . Any :
264
264
"""
265
265
Receive the first message that arrives on the channel.
266
266
If more than one coroutine waits on the same channel, the first waiter
@@ -271,9 +271,9 @@ async def receive(self, channel: str):
271
271
assert self .valid_channel_name (channel )
272
272
if "!" in channel :
273
273
real_channel = self .non_local_name (channel )
274
- assert real_channel .endswith (
275
- self . client_prefix + "! "
276
- ), "Wrong client prefix"
274
+ assert real_channel .endswith (self . client_prefix + "!" ), (
275
+ "Wrong client prefix "
276
+ )
277
277
# Enter receiving section
278
278
loop = asyncio .get_running_loop ()
279
279
self .receive_count += 1
@@ -292,11 +292,11 @@ async def receive(self, channel: str):
292
292
# Wait for our message to appear
293
293
message = None
294
294
while self .receive_buffer [channel ].empty ():
295
- tasks = [
295
+ _tasks = [
296
296
self .receive_lock .acquire (),
297
297
self .receive_buffer [channel ].get (),
298
298
]
299
- tasks = [asyncio .ensure_future (task ) for task in tasks ]
299
+ tasks = [asyncio .ensure_future (task ) for task in _tasks ]
300
300
try :
301
301
done , pending = await asyncio .wait (
302
302
tasks , return_when = asyncio .FIRST_COMPLETED
@@ -384,7 +384,9 @@ async def receive(self, channel: str):
384
384
# Do a plain direct receive
385
385
return (await self .receive_single (channel ))[1 ]
386
386
387
- async def receive_single (self , channel : str ) -> typing .Tuple :
387
+ async def receive_single (
388
+ self , channel : str
389
+ ) -> typing .Tuple [typing .Any , typing .Any ]:
388
390
"""
389
391
Receives a single message off of the channel and returns it.
390
392
"""
@@ -420,7 +422,7 @@ async def receive_single(self, channel: str) -> typing.Tuple:
420
422
)
421
423
self .receive_cleaners .append (cleaner )
422
424
423
- def _cleanup_done (cleaner : asyncio .Task ) :
425
+ def _cleanup_done (cleaner : asyncio .Task [ typing . Any ]) -> None :
424
426
self .receive_cleaners .remove (cleaner )
425
427
self .receive_clean_locks .release (channel_key )
426
428
@@ -448,7 +450,7 @@ async def new_channel(self, prefix: str = "specific") -> str:
448
450
449
451
### Flush extension ###
450
452
451
- async def flush (self ):
453
+ async def flush (self ) -> None :
452
454
"""
453
455
Deletes all messages and groups on all shards.
454
456
"""
@@ -470,7 +472,7 @@ async def flush(self):
470
472
# Now clear the pools as well
471
473
await self .close_pools ()
472
474
473
- async def close_pools (self ):
475
+ async def close_pools (self ) -> None :
474
476
"""
475
477
Close all connections in the event loop pools.
476
478
"""
@@ -480,7 +482,7 @@ async def close_pools(self):
480
482
for layer in self ._layers .values ():
481
483
await layer .flush ()
482
484
483
- async def wait_received (self ):
485
+ async def wait_received (self ) -> None :
484
486
"""
485
487
Wait for all channel cleanup functions to finish.
486
488
"""
@@ -489,13 +491,13 @@ async def wait_received(self):
489
491
490
492
### Groups extension ###
491
493
492
- async def group_add (self , group : str , channel : str ):
494
+ async def group_add (self , group : str , channel : str ) -> None :
493
495
"""
494
496
Adds the channel name to a group.
495
497
"""
496
498
# Check the inputs
497
- assert self .valid_group_name (group ), True
498
- assert self .valid_channel_name (channel ), True
499
+ assert self .valid_group_name (group ), "Group name not valid"
500
+ assert self .valid_channel_name (channel ), "Channel name not valid"
499
501
# Get a connection to the right shard
500
502
group_key = self ._group_key (group )
501
503
connection = self .connection (self .consistent_hash (group ))
@@ -505,7 +507,7 @@ async def group_add(self, group: str, channel: str):
505
507
# it at this point is guaranteed to expire before that
506
508
await connection .expire (group_key , self .group_expiry )
507
509
508
- async def group_discard (self , group : str , channel : str ):
510
+ async def group_discard (self , group : str , channel : str ) -> None :
509
511
"""
510
512
Removes the channel from the named group if it is in the group;
511
513
does nothing otherwise (does not error)
@@ -516,7 +518,7 @@ async def group_discard(self, group: str, channel: str):
516
518
connection = self .connection (self .consistent_hash (group ))
517
519
await connection .zrem (key , channel )
518
520
519
- async def group_send (self , group : str , message ) :
521
+ async def group_send (self , group : str , message : typing . Any ) -> None :
520
522
"""
521
523
Sends a message to the entire group.
522
524
"""
@@ -540,9 +542,9 @@ async def group_send(self, group: str, message):
540
542
for connection_index , channel_redis_keys in connection_to_channel_keys .items ():
541
543
# Discard old messages based on expiry
542
544
pipe = connection .pipeline ()
543
- for key in channel_redis_keys :
545
+ for _key in channel_redis_keys :
544
546
pipe .zremrangebyscore (
545
- key , min = 0 , max = int (time .time ()) - int (self .expiry )
547
+ _key , min = 0 , max = int (time .time ()) - int (self .expiry )
546
548
)
547
549
await pipe .execute ()
548
550
@@ -585,7 +587,7 @@ async def group_send(self, group: str, message):
585
587
channels_over_capacity = await connection .eval (
586
588
group_send_lua , len (channel_redis_keys ), * channel_redis_keys , * args
587
589
)
588
- _channels_over_capacity = - 1
590
+ _channels_over_capacity = - 1.0
589
591
try :
590
592
_channels_over_capacity = float (channels_over_capacity )
591
593
except Exception :
@@ -598,7 +600,13 @@ async def group_send(self, group: str, message):
598
600
group ,
599
601
)
600
602
601
- def _map_channel_keys_to_connection (self , channel_names , message ):
603
+ def _map_channel_keys_to_connection (
604
+ self , channel_names : typing .Iterable [str ], message : typing .Any
605
+ ) -> typing .Tuple [
606
+ typing .Dict [int , typing .List [str ]],
607
+ typing .Dict [str , typing .Any ],
608
+ typing .Dict [str , int ],
609
+ ]:
602
610
"""
603
611
For a list of channel names, GET
604
612
@@ -611,19 +619,21 @@ def _map_channel_keys_to_connection(self, channel_names, message):
611
619
"""
612
620
613
621
# Connection dict keyed by index to list of redis keys mapped on that index
614
- connection_to_channel_keys = collections .defaultdict (list )
622
+ connection_to_channel_keys : typing .Dict [int , typing .List [str ]] = (
623
+ collections .defaultdict (list )
624
+ )
615
625
# Message dict maps redis key to the message that needs to be send on that key
616
- channel_key_to_message = dict ()
626
+ channel_key_to_message : typing . Dict [ str , typing . Any ] = dict ()
617
627
# Channel key mapped to its capacity
618
- channel_key_to_capacity = dict ()
628
+ channel_key_to_capacity : typing . Dict [ str , int ] = dict ()
619
629
620
630
# For each channel
621
631
for channel in channel_names :
622
632
channel_non_local_name = channel
623
633
if "!" in channel :
624
634
channel_non_local_name = self .non_local_name (channel )
625
635
# Get its redis key
626
- channel_key = self .prefix + channel_non_local_name
636
+ channel_key : str = self .prefix + channel_non_local_name
627
637
# Have we come across the same redis key?
628
638
if channel_key not in channel_key_to_message :
629
639
# If not, fill the corresponding dicts
@@ -654,13 +664,15 @@ def _group_key(self, group: str) -> bytes:
654
664
"""
655
665
return f"{ self .prefix } :group:{ group } " .encode ("utf8" )
656
666
657
- def serialize (self , message ) -> bytes :
667
+ ### Serialization ###
668
+
669
+ def serialize (self , message : typing .Any ) -> bytes :
658
670
"""
659
671
Serializes message to a byte string.
660
672
"""
661
673
return self ._serializer .serialize (message )
662
674
663
- def deserialize (self , message : bytes ):
675
+ def deserialize (self , message : bytes ) -> typing . Any :
664
676
"""
665
677
Deserializes from a byte string.
666
678
"""
@@ -671,7 +683,7 @@ def deserialize(self, message: bytes):
671
683
def consistent_hash (self , value : typing .Union [str , "Buffer" ]) -> int :
672
684
return _consistent_hash (value , self .ring_size )
673
685
674
- def __str__ (self ):
686
+ def __str__ (self ) -> str :
675
687
return f"{ self .__class__ .__name__ } (hosts={ self .hosts } )"
676
688
677
689
### Connection handling ###
0 commit comments