Skip to content

Commit a5a1b26

Browse files
committed
TEST
1 parent 10b4a42 commit a5a1b26

File tree

9 files changed

+153
-128
lines changed

9 files changed

+153
-128
lines changed

channels_redis/core.py

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ChannelLock:
3737
to mitigate multi-event loop problems.
3838
"""
3939

40-
def __init__(self):
40+
def __init__(self) -> None:
4141
self.locks: collections.defaultdict[str, asyncio.Lock] = (
4242
collections.defaultdict(asyncio.Lock)
4343
)
@@ -58,7 +58,7 @@ def locked(self, channel: str) -> bool:
5858
"""
5959
return self.locks[channel].locked()
6060

61-
def release(self, channel: str):
61+
def release(self, channel: str) -> None:
6262
"""
6363
Release the lock for the given channel.
6464
"""
@@ -69,8 +69,8 @@ def release(self, channel: str):
6969
del self.wait_counts[channel]
7070

7171

72-
class BoundedQueue(asyncio.Queue):
73-
def put_nowait(self, item):
72+
class BoundedQueue(asyncio.Queue[typing.Any]):
73+
def put_nowait(self, item: typing.Any) -> None:
7474
if self.full():
7575
# see: https://github.com/django/channels_redis/issues/212
7676
# if we actually get into this code block, it likely means that
@@ -83,7 +83,7 @@ def put_nowait(self, item):
8383

8484

8585
class RedisLoopLayer:
86-
def __init__(self, channel_layer: "RedisChannelLayer"):
86+
def __init__(self, channel_layer: "RedisChannelLayer") -> None:
8787
self._lock = asyncio.Lock()
8888
self.channel_layer = channel_layer
8989
self._connections: typing.Dict[int, "Redis"] = {}
@@ -95,7 +95,7 @@ def get_connection(self, index: int) -> "Redis":
9595

9696
return self._connections[index]
9797

98-
async def flush(self):
98+
async def flush(self) -> None:
9999
async with self._lock:
100100
for index in list(self._connections):
101101
connection = self._connections.pop(index)
@@ -116,15 +116,15 @@ class RedisChannelLayer(BaseChannelLayer):
116116
def __init__(
117117
self,
118118
hosts=None,
119-
prefix="asgi",
119+
prefix: str = "asgi",
120120
expiry=60,
121-
group_expiry=86400,
121+
group_expiry: int = 86400,
122122
capacity=100,
123123
channel_capacity=None,
124124
symmetric_encryption_keys=None,
125125
random_prefix_length=12,
126126
serializer_format="msgpack",
127-
):
127+
) -> None:
128128
# Store basic information
129129
self.expiry = expiry
130130
self.group_expiry = group_expiry
@@ -161,7 +161,7 @@ def __init__(
161161
collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
162162
)
163163
# Detached channel cleanup tasks
164-
self.receive_cleaners: typing.List[asyncio.Task] = []
164+
self.receive_cleaners: typing.List[asyncio.Task[typing.Any]] = []
165165
# Per-channel cleanup locks to prevent a receive starting and moving
166166
# a message back into the main queue before its cleanup has completed
167167
self.receive_clean_locks = ChannelLock()
@@ -173,7 +173,7 @@ def create_pool(self, index: int) -> "ConnectionPool":
173173

174174
extensions = ["groups", "flush"]
175175

176-
async def send(self, channel: str, message):
176+
async def send(self, channel: str, message: typing.Any) -> None:
177177
"""
178178
Send a message onto a (general or specific) channel.
179179
"""
@@ -221,7 +221,7 @@ def _backup_channel_name(self, channel: str) -> str:
221221

222222
async def _brpop_with_clean(
223223
self, index: int, channel: str, timeout: typing.Union[int, float, bytes, str]
224-
):
224+
) -> typing.Any:
225225
"""
226226
Perform a Redis BRPOP and manage the backup processing queue.
227227
In case of cancellation, make sure the message is not lost.
@@ -252,15 +252,15 @@ async def _brpop_with_clean(
252252

253253
return member
254254

255-
async def _clean_receive_backup(self, index: int, channel: str):
255+
async def _clean_receive_backup(self, index: int, channel: str) -> None:
256256
"""
257257
Pop the oldest message off the channel backup queue.
258258
The result isn't interesting as it was already processed.
259259
"""
260260
connection = self.connection(index)
261261
await connection.zpopmin(self._backup_channel_name(channel))
262262

263-
async def receive(self, channel: str):
263+
async def receive(self, channel: str) -> typing.Any:
264264
"""
265265
Receive the first message that arrives on the channel.
266266
If more than one coroutine waits on the same channel, the first waiter
@@ -271,9 +271,9 @@ async def receive(self, channel: str):
271271
assert self.valid_channel_name(channel)
272272
if "!" in channel:
273273
real_channel = self.non_local_name(channel)
274-
assert real_channel.endswith(
275-
self.client_prefix + "!"
276-
), "Wrong client prefix"
274+
assert real_channel.endswith(self.client_prefix + "!"), (
275+
"Wrong client prefix"
276+
)
277277
# Enter receiving section
278278
loop = asyncio.get_running_loop()
279279
self.receive_count += 1
@@ -292,11 +292,11 @@ async def receive(self, channel: str):
292292
# Wait for our message to appear
293293
message = None
294294
while self.receive_buffer[channel].empty():
295-
tasks = [
295+
_tasks = [
296296
self.receive_lock.acquire(),
297297
self.receive_buffer[channel].get(),
298298
]
299-
tasks = [asyncio.ensure_future(task) for task in tasks]
299+
tasks = [asyncio.ensure_future(task) for task in _tasks]
300300
try:
301301
done, pending = await asyncio.wait(
302302
tasks, return_when=asyncio.FIRST_COMPLETED
@@ -384,7 +384,9 @@ async def receive(self, channel: str):
384384
# Do a plain direct receive
385385
return (await self.receive_single(channel))[1]
386386

387-
async def receive_single(self, channel: str) -> typing.Tuple:
387+
async def receive_single(
388+
self, channel: str
389+
) -> typing.Tuple[typing.Any, typing.Any]:
388390
"""
389391
Receives a single message off of the channel and returns it.
390392
"""
@@ -420,7 +422,7 @@ async def receive_single(self, channel: str) -> typing.Tuple:
420422
)
421423
self.receive_cleaners.append(cleaner)
422424

423-
def _cleanup_done(cleaner: asyncio.Task):
425+
def _cleanup_done(cleaner: asyncio.Task[typing.Any]) -> None:
424426
self.receive_cleaners.remove(cleaner)
425427
self.receive_clean_locks.release(channel_key)
426428

@@ -448,7 +450,7 @@ async def new_channel(self, prefix: str = "specific") -> str:
448450

449451
### Flush extension ###
450452

451-
async def flush(self):
453+
async def flush(self) -> None:
452454
"""
453455
Deletes all messages and groups on all shards.
454456
"""
@@ -470,7 +472,7 @@ async def flush(self):
470472
# Now clear the pools as well
471473
await self.close_pools()
472474

473-
async def close_pools(self):
475+
async def close_pools(self) -> None:
474476
"""
475477
Close all connections in the event loop pools.
476478
"""
@@ -480,7 +482,7 @@ async def close_pools(self):
480482
for layer in self._layers.values():
481483
await layer.flush()
482484

483-
async def wait_received(self):
485+
async def wait_received(self) -> None:
484486
"""
485487
Wait for all channel cleanup functions to finish.
486488
"""
@@ -489,13 +491,13 @@ async def wait_received(self):
489491

490492
### Groups extension ###
491493

492-
async def group_add(self, group: str, channel: str):
494+
async def group_add(self, group: str, channel: str) -> None:
493495
"""
494496
Adds the channel name to a group.
495497
"""
496498
# Check the inputs
497-
assert self.valid_group_name(group), True
498-
assert self.valid_channel_name(channel), True
499+
assert self.valid_group_name(group), "Group name not valid"
500+
assert self.valid_channel_name(channel), "Channel name not valid"
499501
# Get a connection to the right shard
500502
group_key = self._group_key(group)
501503
connection = self.connection(self.consistent_hash(group))
@@ -505,7 +507,7 @@ async def group_add(self, group: str, channel: str):
505507
# it at this point is guaranteed to expire before that
506508
await connection.expire(group_key, self.group_expiry)
507509

508-
async def group_discard(self, group: str, channel: str):
510+
async def group_discard(self, group: str, channel: str) -> None:
509511
"""
510512
Removes the channel from the named group if it is in the group;
511513
does nothing otherwise (does not error)
@@ -516,7 +518,7 @@ async def group_discard(self, group: str, channel: str):
516518
connection = self.connection(self.consistent_hash(group))
517519
await connection.zrem(key, channel)
518520

519-
async def group_send(self, group: str, message):
521+
async def group_send(self, group: str, message: typing.Any) -> None:
520522
"""
521523
Sends a message to the entire group.
522524
"""
@@ -540,9 +542,9 @@ async def group_send(self, group: str, message):
540542
for connection_index, channel_redis_keys in connection_to_channel_keys.items():
541543
# Discard old messages based on expiry
542544
pipe = connection.pipeline()
543-
for key in channel_redis_keys:
545+
for _key in channel_redis_keys:
544546
pipe.zremrangebyscore(
545-
key, min=0, max=int(time.time()) - int(self.expiry)
547+
_key, min=0, max=int(time.time()) - int(self.expiry)
546548
)
547549
await pipe.execute()
548550

@@ -585,7 +587,7 @@ async def group_send(self, group: str, message):
585587
channels_over_capacity = await connection.eval(
586588
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
587589
)
588-
_channels_over_capacity = -1
590+
_channels_over_capacity = -1.0
589591
try:
590592
_channels_over_capacity = float(channels_over_capacity)
591593
except Exception:
@@ -598,7 +600,13 @@ async def group_send(self, group: str, message):
598600
group,
599601
)
600602

601-
def _map_channel_keys_to_connection(self, channel_names, message):
603+
def _map_channel_keys_to_connection(
604+
self, channel_names: typing.Iterable[str], message: typing.Any
605+
) -> typing.Tuple[
606+
typing.Dict[int, typing.List[str]],
607+
typing.Dict[str, typing.Any],
608+
typing.Dict[str, int],
609+
]:
602610
"""
603611
For a list of channel names, GET
604612
@@ -611,19 +619,21 @@ def _map_channel_keys_to_connection(self, channel_names, message):
611619
"""
612620

613621
# Connection dict keyed by index to list of redis keys mapped on that index
614-
connection_to_channel_keys = collections.defaultdict(list)
622+
connection_to_channel_keys: typing.Dict[int, typing.List[str]] = (
623+
collections.defaultdict(list)
624+
)
615625
# Message dict maps redis key to the message that needs to be send on that key
616-
channel_key_to_message = dict()
626+
channel_key_to_message: typing.Dict[str, typing.Any] = dict()
617627
# Channel key mapped to its capacity
618-
channel_key_to_capacity = dict()
628+
channel_key_to_capacity: typing.Dict[str, int] = dict()
619629

620630
# For each channel
621631
for channel in channel_names:
622632
channel_non_local_name = channel
623633
if "!" in channel:
624634
channel_non_local_name = self.non_local_name(channel)
625635
# Get its redis key
626-
channel_key = self.prefix + channel_non_local_name
636+
channel_key: str = self.prefix + channel_non_local_name
627637
# Have we come across the same redis key?
628638
if channel_key not in channel_key_to_message:
629639
# If not, fill the corresponding dicts
@@ -654,13 +664,15 @@ def _group_key(self, group: str) -> bytes:
654664
"""
655665
return f"{self.prefix}:group:{group}".encode("utf8")
656666

657-
def serialize(self, message) -> bytes:
667+
### Serialization ###
668+
669+
def serialize(self, message: typing.Any) -> bytes:
658670
"""
659671
Serializes message to a byte string.
660672
"""
661673
return self._serializer.serialize(message)
662674

663-
def deserialize(self, message: bytes):
675+
def deserialize(self, message: bytes) -> typing.Any:
664676
"""
665677
Deserializes from a byte string.
666678
"""
@@ -671,7 +683,7 @@ def deserialize(self, message: bytes):
671683
def consistent_hash(self, value: typing.Union[str, "Buffer"]) -> int:
672684
return _consistent_hash(value, self.ring_size)
673685

674-
def __str__(self):
686+
def __str__(self) -> str:
675687
return f"{self.__class__.__name__}(hosts={self.hosts})"
676688

677689
### Connection handling ###

0 commit comments

Comments (0)