19
19
create_pool ,
20
20
decode_hosts ,
21
21
)
22
- from typing import TYPE_CHECKING , Dict , List , Tuple , Union , Optional
23
22
24
- if TYPE_CHECKING :
23
+ import typing
24
+
25
+ if typing .TYPE_CHECKING :
25
26
from redis .asyncio .connection import ConnectionPool
26
27
from redis .asyncio .client import Redis
27
- from .core import RedisChannelLayer
28
28
from typing_extensions import Buffer
29
29
30
30
logger = logging .getLogger (__name__ )
@@ -39,10 +39,10 @@ class ChannelLock:
39
39
"""
40
40
41
41
def __init__ (self ):
42
- self .locks : " collections.defaultdict[str, asyncio.Lock]" = (
42
+ self .locks : collections .defaultdict [str , asyncio .Lock ] = (
43
43
collections .defaultdict (asyncio .Lock )
44
44
)
45
- self .wait_counts : " collections.defaultdict[str, int]" = collections .defaultdict (
45
+ self .wait_counts : collections .defaultdict [str , int ] = collections .defaultdict (
46
46
int
47
47
)
48
48
@@ -87,7 +87,7 @@ class RedisLoopLayer:
87
87
def __init__ (self , channel_layer : "RedisChannelLayer" ):
88
88
self ._lock = asyncio .Lock ()
89
89
self .channel_layer = channel_layer
90
- self ._connections : " Dict[int, Redis]" = {}
90
+ self ._connections : typing . Dict [int , " Redis" ] = {}
91
91
92
92
def get_connection (self , index : int ) -> "Redis" :
93
93
if index not in self ._connections :
@@ -145,7 +145,7 @@ def __init__(
145
145
symmetric_encryption_keys = symmetric_encryption_keys ,
146
146
)
147
147
# Cached redis connection pools and the event loop they are from
148
- self ._layers : " Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
148
+ self ._layers : typing . Dict [asyncio .AbstractEventLoop , " RedisLoopLayer" ] = {}
149
149
# Normal channels choose a host index by cycling through the available hosts
150
150
self ._receive_index_generator = itertools .cycle (range (len (self .hosts )))
151
151
self ._send_index_generator = itertools .cycle (range (len (self .hosts )))
@@ -154,15 +154,15 @@ def __init__(
154
154
# Number of coroutines trying to receive right now
155
155
self .receive_count = 0
156
156
# The receive lock
157
- self .receive_lock : " Optional[asyncio.Lock]" = None
157
+ self .receive_lock : typing . Optional [asyncio .Lock ] = None
158
158
# Event loop they are trying to receive on
159
- self .receive_event_loop : " Optional[asyncio.AbstractEventLoop]" = None
159
+ self .receive_event_loop : typing . Optional [asyncio .AbstractEventLoop ] = None
160
160
# Buffered messages by process-local channel name
161
- self .receive_buffer : " collections.defaultdict[str, BoundedQueue]" = (
161
+ self .receive_buffer : collections .defaultdict [str , BoundedQueue ] = (
162
162
collections .defaultdict (functools .partial (BoundedQueue , self .capacity ))
163
163
)
164
164
# Detached channel cleanup tasks
165
- self .receive_cleaners : " List[asyncio.Task]" = []
165
+ self .receive_cleaners : typing . List [asyncio .Task ] = []
166
166
# Per-channel cleanup locks to prevent a receive starting and moving
167
167
# a message back into the main queue before its cleanup has completed
168
168
self .receive_clean_locks = ChannelLock ()
@@ -180,7 +180,7 @@ async def send(self, channel: str, message):
180
180
"""
181
181
# Typecheck
182
182
assert isinstance (message , dict ), "message is not a dict"
183
- assert self .require_valid_channel_name (channel ), "Channel name not valid"
183
+ assert self .valid_channel_name (channel ), "Channel name not valid"
184
184
# Make sure the message does not contain reserved keys
185
185
assert "__asgi_channel__" not in message
186
186
# If it's a process-local channel, strip off local part and stick full name in message
@@ -221,7 +221,7 @@ def _backup_channel_name(self, channel: str) -> str:
221
221
return channel + "$inflight"
222
222
223
223
async def _brpop_with_clean (
224
- self , index : int , channel : str , timeout : " Union[int, float, bytes, str]"
224
+ self , index : int , channel : str , timeout : typing . Union [int , float , bytes , str ]
225
225
):
226
226
"""
227
227
Perform a Redis BRPOP and manage the backup processing queue.
@@ -269,7 +269,7 @@ async def receive(self, channel: str):
269
269
"""
270
270
# Make sure the channel name is valid then get the non-local part
271
271
# and thus its index
272
- assert self .require_valid_channel_name (channel )
272
+ assert self .valid_channel_name (channel )
273
273
if "!" in channel :
274
274
real_channel = self .non_local_name (channel )
275
275
assert real_channel .endswith (
@@ -385,14 +385,12 @@ async def receive(self, channel: str):
385
385
# Do a plain direct receive
386
386
return (await self .receive_single (channel ))[1 ]
387
387
388
- async def receive_single (self , channel : str ) -> " Tuple" :
388
+ async def receive_single (self , channel : str ) -> typing . Tuple :
389
389
"""
390
390
Receives a single message off of the channel and returns it.
391
391
"""
392
392
# Check channel name
393
- assert self .require_valid_channel_name (
394
- channel , receive = True
395
- ), "Channel name invalid"
393
+ assert self .valid_channel_name (channel , receive = True ), "Channel name invalid"
396
394
# Work out the connection to use
397
395
if "!" in channel :
398
396
assert channel .endswith ("!" )
@@ -423,7 +421,7 @@ async def receive_single(self, channel: str) -> "Tuple":
423
421
)
424
422
self .receive_cleaners .append (cleaner )
425
423
426
- def _cleanup_done (cleaner : " asyncio.Task" ):
424
+ def _cleanup_done (cleaner : asyncio .Task ):
427
425
self .receive_cleaners .remove (cleaner )
428
426
self .receive_clean_locks .release (channel_key )
429
427
@@ -497,8 +495,8 @@ async def group_add(self, group: str, channel: str):
497
495
Adds the channel name to a group.
498
496
"""
499
497
# Check the inputs
500
- assert self .require_valid_group_name (group ), True
501
- assert self .require_valid_channel_name (channel ), True
498
+ assert self .valid_group_name (group ), True
499
+ assert self .valid_channel_name (channel ), True
502
500
# Get a connection to the right shard
503
501
group_key = self ._group_key (group )
504
502
connection = self .connection (self .consistent_hash (group ))
@@ -513,8 +511,8 @@ async def group_discard(self, group: str, channel: str):
513
511
Removes the channel from the named group if it is in the group;
514
512
does nothing otherwise (does not error)
515
513
"""
516
- assert self .require_valid_group_name (group ), "Group name not valid"
517
- assert self .require_valid_channel_name (channel ), "Channel name not valid"
514
+ assert self .valid_group_name (group ), "Group name not valid"
515
+ assert self .valid_channel_name (channel ), "Channel name not valid"
518
516
key = self ._group_key (group )
519
517
connection = self .connection (self .consistent_hash (group ))
520
518
await connection .zrem (key , channel )
@@ -523,7 +521,7 @@ async def group_send(self, group: str, message):
523
521
"""
524
522
Sends a message to the entire group.
525
523
"""
526
- assert self .require_valid_group_name (group ), "Group name not valid"
524
+ assert self .valid_group_name (group ), "Group name not valid"
527
525
# Retrieve list of all channel names
528
526
key = self ._group_key (group )
529
527
connection = self .connection (self .consistent_hash (group ))
@@ -671,7 +669,7 @@ def deserialize(self, message: bytes):
671
669
672
670
### Internal functions ###
673
671
674
- def consistent_hash (self , value : " Union[str, Buffer]" ) -> int :
672
+ def consistent_hash (self , value : typing . Union [str , " Buffer" ] ) -> int :
675
673
return _consistent_hash (value , self .ring_size )
676
674
677
675
def __str__ (self ):
0 commit comments