Skip to content

Commit 1e3e55d

Browse files
committed
TEST
1 parent 10b4a42 commit 1e3e55d

File tree

9 files changed

+183
-153
lines changed

9 files changed

+183
-153
lines changed

channels_redis/core.py

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
from redis import asyncio as aioredis
1111

12-
from channels.exceptions import ChannelFull
13-
from channels.layers import BaseChannelLayer
12+
from channels.exceptions import ChannelFull # type: ignore[import-untyped]
13+
from channels.layers import BaseChannelLayer # type: ignore[import-untyped]
1414

1515
from .serializers import registry
1616
from .utils import (
@@ -37,13 +37,11 @@ class ChannelLock:
3737
to mitigate multi-event loop problems.
3838
"""
3939

40-
def __init__(self):
41-
self.locks: collections.defaultdict[str, asyncio.Lock] = (
42-
collections.defaultdict(asyncio.Lock)
43-
)
44-
self.wait_counts: collections.defaultdict[str, int] = collections.defaultdict(
45-
int
40+
def __init__(self) -> None:
41+
self.locks: typing.DefaultDict[str, asyncio.Lock] = collections.defaultdict(
42+
asyncio.Lock
4643
)
44+
self.wait_counts: typing.DefaultDict[str, int] = collections.defaultdict(int)
4745

4846
async def acquire(self, channel: str) -> bool:
4947
"""
@@ -58,7 +56,7 @@ def locked(self, channel: str) -> bool:
5856
"""
5957
return self.locks[channel].locked()
6058

61-
def release(self, channel: str):
59+
def release(self, channel: str) -> None:
6260
"""
6361
Release the lock for the given channel.
6462
"""
@@ -70,7 +68,7 @@ def release(self, channel: str):
7068

7169

7270
class BoundedQueue(asyncio.Queue):
73-
def put_nowait(self, item):
71+
def put_nowait(self, item: typing.Any) -> None:
7472
if self.full():
7573
# see: https://github.com/django/channels_redis/issues/212
7674
# if we actually get into this code block, it likely means that
@@ -83,7 +81,7 @@ def put_nowait(self, item):
8381

8482

8583
class RedisLoopLayer:
86-
def __init__(self, channel_layer: "RedisChannelLayer"):
84+
def __init__(self, channel_layer: "RedisChannelLayer") -> None:
8785
self._lock = asyncio.Lock()
8886
self.channel_layer = channel_layer
8987
self._connections: typing.Dict[int, "Redis"] = {}
@@ -95,7 +93,7 @@ def get_connection(self, index: int) -> "Redis":
9593

9694
return self._connections[index]
9795

98-
async def flush(self):
96+
async def flush(self) -> None:
9997
async with self._lock:
10098
for index in list(self._connections):
10199
connection = self._connections.pop(index)
@@ -116,15 +114,15 @@ class RedisChannelLayer(BaseChannelLayer):
116114
def __init__(
117115
self,
118116
hosts=None,
119-
prefix="asgi",
117+
prefix: str = "asgi",
120118
expiry=60,
121-
group_expiry=86400,
119+
group_expiry: int = 86400,
122120
capacity=100,
123121
channel_capacity=None,
124122
symmetric_encryption_keys=None,
125123
random_prefix_length=12,
126124
serializer_format="msgpack",
127-
):
125+
) -> None:
128126
# Store basic information
129127
self.expiry = expiry
130128
self.group_expiry = group_expiry
@@ -157,11 +155,11 @@ def __init__(
157155
# Event loop they are trying to receive on
158156
self.receive_event_loop: typing.Optional[asyncio.AbstractEventLoop] = None
159157
# Buffered messages by process-local channel name
160-
self.receive_buffer: collections.defaultdict[str, BoundedQueue] = (
158+
self.receive_buffer: typing.DefaultDict[str, BoundedQueue] = (
161159
collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
162160
)
163161
# Detached channel cleanup tasks
164-
self.receive_cleaners: typing.List[asyncio.Task] = []
162+
self.receive_cleaners: typing.List["asyncio.Task[typing.Any]"] = []
165163
# Per-channel cleanup locks to prevent a receive starting and moving
166164
# a message back into the main queue before its cleanup has completed
167165
self.receive_clean_locks = ChannelLock()
@@ -173,7 +171,7 @@ def create_pool(self, index: int) -> "ConnectionPool":
173171

174172
extensions = ["groups", "flush"]
175173

176-
async def send(self, channel: str, message):
174+
async def send(self, channel: str, message: typing.Any) -> None:
177175
"""
178176
Send a message onto a (general or specific) channel.
179177
"""
@@ -221,7 +219,7 @@ def _backup_channel_name(self, channel: str) -> str:
221219

222220
async def _brpop_with_clean(
223221
self, index: int, channel: str, timeout: typing.Union[int, float, bytes, str]
224-
):
222+
) -> typing.Any:
225223
"""
226224
Perform a Redis BRPOP and manage the backup processing queue.
227225
In case of cancellation, make sure the message is not lost.
@@ -240,7 +238,7 @@ async def _brpop_with_clean(
240238
connection = self.connection(index)
241239
# Cancellation here doesn't matter, we're not doing anything destructive
242240
# and the script executes atomically...
243-
await connection.eval(cleanup_script, 0, channel, backup_queue)
241+
await connection.eval(cleanup_script, 0, channel, backup_queue) # type: ignore[misc]
244242
# ...and it doesn't matter here either, the message will be safe in the backup.
245243
result = await connection.bzpopmin(channel, timeout=timeout)
246244

@@ -252,15 +250,15 @@ async def _brpop_with_clean(
252250

253251
return member
254252

255-
async def _clean_receive_backup(self, index: int, channel: str):
253+
async def _clean_receive_backup(self, index: int, channel: str) -> None:
256254
"""
257255
Pop the oldest message off the channel backup queue.
258256
The result isn't interesting as it was already processed.
259257
"""
260258
connection = self.connection(index)
261259
await connection.zpopmin(self._backup_channel_name(channel))
262260

263-
async def receive(self, channel: str):
261+
async def receive(self, channel: str) -> typing.Any:
264262
"""
265263
Receive the first message that arrives on the channel.
266264
If more than one coroutine waits on the same channel, the first waiter
@@ -292,11 +290,11 @@ async def receive(self, channel: str):
292290
# Wait for our message to appear
293291
message = None
294292
while self.receive_buffer[channel].empty():
295-
tasks = [
296-
self.receive_lock.acquire(),
293+
_tasks = [
294+
self.receive_lock.acquire(), # type: ignore[union-attr]
297295
self.receive_buffer[channel].get(),
298296
]
299-
tasks = [asyncio.ensure_future(task) for task in tasks]
297+
tasks = [asyncio.ensure_future(task) for task in _tasks]
300298
try:
301299
done, pending = await asyncio.wait(
302300
tasks, return_when=asyncio.FIRST_COMPLETED
@@ -312,7 +310,7 @@ async def receive(self, channel: str):
312310
if not task.cancel():
313311
assert task.done()
314312
if task.result() is True:
315-
self.receive_lock.release()
313+
self.receive_lock.release() # type: ignore[union-attr]
316314

317315
raise
318316

@@ -335,7 +333,7 @@ async def receive(self, channel: str):
335333
if message or exception:
336334
if token:
337335
# We will not be receving as we already have the message.
338-
self.receive_lock.release()
336+
self.receive_lock.release() # type: ignore[union-attr]
339337

340338
if exception:
341339
raise exception
@@ -362,7 +360,7 @@ async def receive(self, channel: str):
362360
del self.receive_buffer[channel]
363361
raise
364362
finally:
365-
self.receive_lock.release()
363+
self.receive_lock.release() # type: ignore[union-attr]
366364

367365
# We know there's a message available, because there
368366
# couldn't have been any interruption between empty() and here
@@ -377,14 +375,16 @@ async def receive(self, channel: str):
377375
self.receive_count -= 1
378376
# If we were the last out, drop the receive lock
379377
if self.receive_count == 0:
380-
assert not self.receive_lock.locked()
378+
assert not self.receive_lock.locked() # type: ignore[union-attr]
381379
self.receive_lock = None
382380
self.receive_event_loop = None
383381
else:
384382
# Do a plain direct receive
385383
return (await self.receive_single(channel))[1]
386384

387-
async def receive_single(self, channel: str) -> typing.Tuple:
385+
async def receive_single(
386+
self, channel: str
387+
) -> typing.Tuple[typing.Any, typing.Any]:
388388
"""
389389
Receives a single message off of the channel and returns it.
390390
"""
@@ -420,7 +420,7 @@ async def receive_single(self, channel: str) -> typing.Tuple:
420420
)
421421
self.receive_cleaners.append(cleaner)
422422

423-
def _cleanup_done(cleaner: asyncio.Task):
423+
def _cleanup_done(cleaner: "asyncio.Task") -> None:
424424
self.receive_cleaners.remove(cleaner)
425425
self.receive_clean_locks.release(channel_key)
426426

@@ -448,7 +448,7 @@ async def new_channel(self, prefix: str = "specific") -> str:
448448

449449
### Flush extension ###
450450

451-
async def flush(self):
451+
async def flush(self) -> None:
452452
"""
453453
Deletes all messages and groups on all shards.
454454
"""
@@ -466,11 +466,11 @@ async def flush(self):
466466
# Go through each connection and remove all with prefix
467467
for i in range(self.ring_size):
468468
connection = self.connection(i)
469-
await connection.eval(delete_prefix, 0, self.prefix + "*")
469+
await connection.eval(delete_prefix, 0, self.prefix + "*") # type: ignore[union-attr,misc]
470470
# Now clear the pools as well
471471
await self.close_pools()
472472

473-
async def close_pools(self):
473+
async def close_pools(self) -> None:
474474
"""
475475
Close all connections in the event loop pools.
476476
"""
@@ -480,7 +480,7 @@ async def close_pools(self):
480480
for layer in self._layers.values():
481481
await layer.flush()
482482

483-
async def wait_received(self):
483+
async def wait_received(self) -> None:
484484
"""
485485
Wait for all channel cleanup functions to finish.
486486
"""
@@ -489,13 +489,13 @@ async def wait_received(self):
489489

490490
### Groups extension ###
491491

492-
async def group_add(self, group: str, channel: str):
492+
async def group_add(self, group: str, channel: str) -> None:
493493
"""
494494
Adds the channel name to a group.
495495
"""
496496
# Check the inputs
497-
assert self.valid_group_name(group), True
498-
assert self.valid_channel_name(channel), True
497+
assert self.valid_group_name(group), "Group name not valid"
498+
assert self.valid_channel_name(channel), "Channel name not valid"
499499
# Get a connection to the right shard
500500
group_key = self._group_key(group)
501501
connection = self.connection(self.consistent_hash(group))
@@ -505,7 +505,7 @@ async def group_add(self, group: str, channel: str):
505505
# it at this point is guaranteed to expire before that
506506
await connection.expire(group_key, self.group_expiry)
507507

508-
async def group_discard(self, group: str, channel: str):
508+
async def group_discard(self, group: str, channel: str) -> None:
509509
"""
510510
Removes the channel from the named group if it is in the group;
511511
does nothing otherwise (does not error)
@@ -516,7 +516,7 @@ async def group_discard(self, group: str, channel: str):
516516
connection = self.connection(self.consistent_hash(group))
517517
await connection.zrem(key, channel)
518518

519-
async def group_send(self, group: str, message):
519+
async def group_send(self, group: str, message: typing.Any) -> None:
520520
"""
521521
Sends a message to the entire group.
522522
"""
@@ -540,9 +540,9 @@ async def group_send(self, group: str, message):
540540
for connection_index, channel_redis_keys in connection_to_channel_keys.items():
541541
# Discard old messages based on expiry
542542
pipe = connection.pipeline()
543-
for key in channel_redis_keys:
543+
for _key in channel_redis_keys:
544544
pipe.zremrangebyscore(
545-
key, min=0, max=int(time.time()) - int(self.expiry)
545+
_key, min=0, max=int(time.time()) - int(self.expiry)
546546
)
547547
await pipe.execute()
548548

@@ -582,10 +582,10 @@ async def group_send(self, group: str, message):
582582

583583
# channel_keys does not contain a single redis key more than once
584584
connection = self.connection(connection_index)
585-
channels_over_capacity = await connection.eval(
585+
channels_over_capacity = await connection.eval( # type: ignore[misc]
586586
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
587587
)
588-
_channels_over_capacity = -1
588+
_channels_over_capacity = -1.0
589589
try:
590590
_channels_over_capacity = float(channels_over_capacity)
591591
except Exception:
@@ -598,7 +598,13 @@ async def group_send(self, group: str, message):
598598
group,
599599
)
600600

601-
def _map_channel_keys_to_connection(self, channel_names, message):
601+
def _map_channel_keys_to_connection(
602+
self, channel_names: typing.Iterable[str], message: typing.Any
603+
) -> typing.Tuple[
604+
typing.Dict[int, typing.List[str]],
605+
typing.Dict[str, typing.Any],
606+
typing.Dict[str, int],
607+
]:
602608
"""
603609
For a list of channel names, GET
604610
@@ -611,19 +617,21 @@ def _map_channel_keys_to_connection(self, channel_names, message):
611617
"""
612618

613619
# Connection dict keyed by index to list of redis keys mapped on that index
614-
connection_to_channel_keys = collections.defaultdict(list)
620+
connection_to_channel_keys: typing.Dict[int, typing.List[str]] = (
621+
collections.defaultdict(list)
622+
)
615623
# Message dict maps redis key to the message that needs to be send on that key
616-
channel_key_to_message = dict()
624+
channel_key_to_message: typing.Dict[str, typing.Any] = dict()
617625
# Channel key mapped to its capacity
618-
channel_key_to_capacity = dict()
626+
channel_key_to_capacity: typing.Dict[str, int] = dict()
619627

620628
# For each channel
621629
for channel in channel_names:
622630
channel_non_local_name = channel
623631
if "!" in channel:
624632
channel_non_local_name = self.non_local_name(channel)
625633
# Get its redis key
626-
channel_key = self.prefix + channel_non_local_name
634+
channel_key: str = self.prefix + channel_non_local_name
627635
# Have we come across the same redis key?
628636
if channel_key not in channel_key_to_message:
629637
# If not, fill the corresponding dicts
@@ -654,13 +662,15 @@ def _group_key(self, group: str) -> bytes:
654662
"""
655663
return f"{self.prefix}:group:{group}".encode("utf8")
656664

657-
def serialize(self, message) -> bytes:
665+
### Serialization ###
666+
667+
def serialize(self, message: typing.Any) -> bytes:
658668
"""
659669
Serializes message to a byte string.
660670
"""
661671
return self._serializer.serialize(message)
662672

663-
def deserialize(self, message: bytes):
673+
def deserialize(self, message: bytes) -> typing.Any:
664674
"""
665675
Deserializes from a byte string.
666676
"""
@@ -671,7 +681,7 @@ def deserialize(self, message: bytes):
671681
def consistent_hash(self, value: typing.Union[str, "Buffer"]) -> int:
672682
return _consistent_hash(value, self.ring_size)
673683

674-
def __str__(self):
684+
def __str__(self) -> str:
675685
return f"{self.__class__.__name__}(hosts={self.hosts})"
676686

677687
### Connection handling ###

0 commit comments

Comments
 (0)