Skip to content

Commit 4da1969

Browse files
committed
FEAT: type hint
1 parent 662b90d commit 4da1969

File tree

7 files changed

+171
-106
lines changed

7 files changed

+171
-106
lines changed

channels_redis/core.py

Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
create_pool,
2020
decode_hosts,
2121
)
22+
from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Optional
23+
24+
if TYPE_CHECKING:
25+
from redis.asyncio.connection import ConnectionPool
26+
from redis.asyncio.client import Redis
27+
from .core import RedisChannelLayer
28+
from typing_extensions import Buffer
2229

2330
logger = logging.getLogger(__name__)
2431

@@ -32,23 +39,27 @@ class ChannelLock:
3239
"""
3340

3441
def __init__(self):
35-
self.locks = collections.defaultdict(asyncio.Lock)
36-
self.wait_counts = collections.defaultdict(int)
42+
self.locks: "collections.defaultdict[str, asyncio.Lock]" = (
43+
collections.defaultdict(asyncio.Lock)
44+
)
45+
self.wait_counts: "collections.defaultdict[str, int]" = collections.defaultdict(
46+
int
47+
)
3748

38-
async def acquire(self, channel):
49+
async def acquire(self, channel: str) -> bool:
3950
"""
4051
Acquire the lock for the given channel.
4152
"""
4253
self.wait_counts[channel] += 1
4354
return await self.locks[channel].acquire()
4455

45-
def locked(self, channel):
56+
def locked(self, channel: str) -> bool:
4657
"""
4758
Return ``True`` if the lock for the given channel is acquired.
4859
"""
4960
return self.locks[channel].locked()
5061

51-
def release(self, channel):
62+
def release(self, channel: str):
5263
"""
5364
Release the lock for the given channel.
5465
"""
@@ -73,12 +84,12 @@ def put_nowait(self, item):
7384

7485

7586
class RedisLoopLayer:
76-
def __init__(self, channel_layer):
87+
def __init__(self, channel_layer: "RedisChannelLayer"):
7788
self._lock = asyncio.Lock()
7889
self.channel_layer = channel_layer
79-
self._connections = {}
90+
self._connections: "Dict[int, Redis]" = {}
8091

81-
def get_connection(self, index):
92+
def get_connection(self, index: int) -> "Redis":
8293
if index not in self._connections:
8394
pool = self.channel_layer.create_pool(index)
8495
self._connections[index] = aioredis.Redis(connection_pool=pool)
@@ -134,7 +145,7 @@ def __init__(
134145
symmetric_encryption_keys=symmetric_encryption_keys,
135146
)
136147
# Cached redis connection pools and the event loop they are from
137-
self._layers = {}
148+
self._layers: "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
138149
# Normal channels choose a host index by cycling through the available hosts
139150
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
140151
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -143,33 +154,33 @@ def __init__(
143154
# Number of coroutines trying to receive right now
144155
self.receive_count = 0
145156
# The receive lock
146-
self.receive_lock = None
157+
self.receive_lock: "Optional[asyncio.Lock]" = None
147158
# Event loop they are trying to receive on
148-
self.receive_event_loop = None
159+
self.receive_event_loop: "Optional[asyncio.AbstractEventLoop]" = None
149160
# Buffered messages by process-local channel name
150-
self.receive_buffer = collections.defaultdict(
151-
functools.partial(BoundedQueue, self.capacity)
161+
self.receive_buffer: "collections.defaultdict[str, BoundedQueue]" = (
162+
collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
152163
)
153164
# Detached channel cleanup tasks
154-
self.receive_cleaners = []
165+
self.receive_cleaners: "List[asyncio.Task]" = []
155166
# Per-channel cleanup locks to prevent a receive starting and moving
156167
# a message back into the main queue before its cleanup has completed
157168
self.receive_clean_locks = ChannelLock()
158169

159-
def create_pool(self, index):
170+
def create_pool(self, index: int) -> "ConnectionPool":
160171
return create_pool(self.hosts[index])
161172

162173
### Channel layer API ###
163174

164175
extensions = ["groups", "flush"]
165176

166-
async def send(self, channel, message):
177+
async def send(self, channel: str, message):
167178
"""
168179
Send a message onto a (general or specific) channel.
169180
"""
170181
# Typecheck
171182
assert isinstance(message, dict), "message is not a dict"
172-
assert self.valid_channel_name(channel), "Channel name not valid"
183+
assert self.require_valid_channel_name(channel), "Channel name not valid"
173184
# Make sure the message does not contain reserved keys
174185
assert "__asgi_channel__" not in message
175186
# If it's a process-local channel, strip off local part and stick full name in message
@@ -203,13 +214,15 @@ async def send(self, channel, message):
203214
await connection.zadd(channel_key, {self.serialize(message): time.time()})
204215
await connection.expire(channel_key, int(self.expiry))
205216

206-
def _backup_channel_name(self, channel):
217+
def _backup_channel_name(self, channel: str) -> str:
207218
"""
208219
Construct the key used as a backup queue for the given channel.
209220
"""
210221
return channel + "$inflight"
211222

212-
async def _brpop_with_clean(self, index, channel, timeout):
223+
async def _brpop_with_clean(
224+
self, index: int, channel: str, timeout: "Union[int, float, bytes, str]"
225+
):
213226
"""
214227
Perform a Redis BRPOP and manage the backup processing queue.
215228
In case of cancellation, make sure the message is not lost.
@@ -240,23 +253,23 @@ async def _brpop_with_clean(self, index, channel, timeout):
240253

241254
return member
242255

243-
async def _clean_receive_backup(self, index, channel):
256+
async def _clean_receive_backup(self, index: int, channel: str):
244257
"""
245258
Pop the oldest message off the channel backup queue.
246259
The result isn't interesting as it was already processed.
247260
"""
248261
connection = self.connection(index)
249262
await connection.zpopmin(self._backup_channel_name(channel))
250263

251-
async def receive(self, channel):
264+
async def receive(self, channel: str):
252265
"""
253266
Receive the first message that arrives on the channel.
254267
If more than one coroutine waits on the same channel, the first waiter
255268
will be given the message when it arrives.
256269
"""
257270
# Make sure the channel name is valid then get the non-local part
258271
# and thus its index
259-
assert self.valid_channel_name(channel)
272+
assert self.require_valid_channel_name(channel)
260273
if "!" in channel:
261274
real_channel = self.non_local_name(channel)
262275
assert real_channel.endswith(
@@ -372,12 +385,14 @@ async def receive(self, channel):
372385
# Do a plain direct receive
373386
return (await self.receive_single(channel))[1]
374387

375-
async def receive_single(self, channel):
388+
async def receive_single(self, channel: str) -> "Tuple":
376389
"""
377390
Receives a single message off of the channel and returns it.
378391
"""
379392
# Check channel name
380-
assert self.valid_channel_name(channel, receive=True), "Channel name invalid"
393+
assert self.require_valid_channel_name(
394+
channel, receive=True
395+
), "Channel name invalid"
381396
# Work out the connection to use
382397
if "!" in channel:
383398
assert channel.endswith("!")
@@ -408,7 +423,7 @@ async def receive_single(self, channel):
408423
)
409424
self.receive_cleaners.append(cleaner)
410425

411-
def _cleanup_done(cleaner):
426+
def _cleanup_done(cleaner: "asyncio.Task"):
412427
self.receive_cleaners.remove(cleaner)
413428
self.receive_clean_locks.release(channel_key)
414429

@@ -427,7 +442,7 @@ def _cleanup_done(cleaner):
427442
del message["__asgi_channel__"]
428443
return channel, message
429444

430-
async def new_channel(self, prefix="specific"):
445+
async def new_channel(self, prefix: str = "specific") -> str:
431446
"""
432447
Returns a new channel name that can be used by something in our
433448
process as a specific channel.
@@ -477,13 +492,13 @@ async def wait_received(self):
477492

478493
### Groups extension ###
479494

480-
async def group_add(self, group, channel):
495+
async def group_add(self, group: str, channel: str):
481496
"""
482497
Adds the channel name to a group.
483498
"""
484499
# Check the inputs
485-
assert self.valid_group_name(group), "Group name not valid"
486-
assert self.valid_channel_name(channel), "Channel name not valid"
500+
assert self.require_valid_group_name(group), True
501+
assert self.require_valid_channel_name(channel), True
487502
# Get a connection to the right shard
488503
group_key = self._group_key(group)
489504
connection = self.connection(self.consistent_hash(group))
@@ -493,22 +508,22 @@ async def group_add(self, group, channel):
493508
# it at this point is guaranteed to expire before that
494509
await connection.expire(group_key, self.group_expiry)
495510

496-
async def group_discard(self, group, channel):
511+
async def group_discard(self, group: str, channel: str):
497512
"""
498513
Removes the channel from the named group if it is in the group;
499514
does nothing otherwise (does not error)
500515
"""
501-
assert self.valid_group_name(group), "Group name not valid"
502-
assert self.valid_channel_name(channel), "Channel name not valid"
516+
assert self.require_valid_group_name(group), "Group name not valid"
517+
assert self.require_valid_channel_name(channel), "Channel name not valid"
503518
key = self._group_key(group)
504519
connection = self.connection(self.consistent_hash(group))
505520
await connection.zrem(key, channel)
506521

507-
async def group_send(self, group, message):
522+
async def group_send(self, group: str, message):
508523
"""
509524
Sends a message to the entire group.
510525
"""
511-
assert self.valid_group_name(group), "Group name not valid"
526+
assert self.require_valid_group_name(group), "Group name not valid"
512527
# Retrieve list of all channel names
513528
key = self._group_key(group)
514529
connection = self.connection(self.consistent_hash(group))
@@ -573,7 +588,12 @@ async def group_send(self, group, message):
573588
channels_over_capacity = await connection.eval(
574589
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
575590
)
576-
if channels_over_capacity > 0:
591+
_channels_over_capacity = -1
592+
try:
593+
_channels_over_capacity = float(channels_over_capacity)
594+
except Exception:
595+
pass
596+
if _channels_over_capacity > 0:
577597
logger.info(
578598
"%s of %s channels over capacity in group %s",
579599
channels_over_capacity,
@@ -631,37 +651,35 @@ def _map_channel_keys_to_connection(self, channel_names, message):
631651
channel_key_to_capacity,
632652
)
633653

634-
def _group_key(self, group):
654+
def _group_key(self, group: str) -> bytes:
635655
"""
636656
Common function to make the storage key for the group.
637657
"""
638658
return f"{self.prefix}:group:{group}".encode("utf8")
639659

640-
### Serialization ###
641-
642-
def serialize(self, message):
660+
def serialize(self, message) -> bytes:
643661
"""
644662
Serializes message to a byte string.
645663
"""
646664
return self._serializer.serialize(message)
647665

648-
def deserialize(self, message):
666+
def deserialize(self, message: bytes):
649667
"""
650668
Deserializes from a byte string.
651669
"""
652670
return self._serializer.deserialize(message)
653671

654672
### Internal functions ###
655673

656-
def consistent_hash(self, value):
674+
def consistent_hash(self, value: "Union[str, Buffer]") -> int:
657675
return _consistent_hash(value, self.ring_size)
658676

659677
def __str__(self):
660678
return f"{self.__class__.__name__}(hosts={self.hosts})"
661679

662680
### Connection handling ###
663681

664-
def connection(self, index):
682+
def connection(self, index: int) -> "Redis":
665683
"""
666684
Returns the correct connection for the index given.
667685
Lazily instantiates pools.

0 commit comments

Comments
 (0)