Skip to content

Commit c5e95d6

Browse files
committed
FEAT: type hint
1 parent 662b90d commit c5e95d6

File tree

7 files changed

+169
-106
lines changed

7 files changed

+169
-106
lines changed

channels_redis/core.py

Lines changed: 59 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
create_pool,
2020
decode_hosts,
2121
)
22+
from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Optional
23+
24+
if TYPE_CHECKING:
25+
from redis.asyncio.connection import ConnectionPool
26+
from redis.asyncio.client import Redis
27+
from .core import RedisChannelLayer
28+
from _typeshed import ReadableBuffer
29+
from redis.typing import TimeoutSecT
2230

2331
logger = logging.getLogger(__name__)
2432

@@ -32,23 +40,27 @@ class ChannelLock:
3240
"""
3341

3442
def __init__(self):
35-
self.locks = collections.defaultdict(asyncio.Lock)
36-
self.wait_counts = collections.defaultdict(int)
43+
self.locks: "collections.defaultdict[str, asyncio.Lock]" = (
44+
collections.defaultdict(asyncio.Lock)
45+
)
46+
self.wait_counts: "collections.defaultdict[str, int]" = collections.defaultdict(
47+
int
48+
)
3749

38-
async def acquire(self, channel):
50+
async def acquire(self, channel: str) -> bool:
3951
"""
4052
Acquire the lock for the given channel.
4153
"""
4254
self.wait_counts[channel] += 1
4355
return await self.locks[channel].acquire()
4456

45-
def locked(self, channel):
57+
def locked(self, channel: str) -> bool:
4658
"""
4759
Return ``True`` if the lock for the given channel is acquired.
4860
"""
4961
return self.locks[channel].locked()
5062

51-
def release(self, channel):
63+
def release(self, channel: str):
5264
"""
5365
Release the lock for the given channel.
5466
"""
@@ -73,12 +85,12 @@ def put_nowait(self, item):
7385

7486

7587
class RedisLoopLayer:
76-
def __init__(self, channel_layer):
88+
def __init__(self, channel_layer: "RedisChannelLayer"):
7789
self._lock = asyncio.Lock()
7890
self.channel_layer = channel_layer
79-
self._connections = {}
91+
self._connections: "Dict[int, Redis]" = {}
8092

81-
def get_connection(self, index):
93+
def get_connection(self, index: int) -> "Redis":
8294
if index not in self._connections:
8395
pool = self.channel_layer.create_pool(index)
8496
self._connections[index] = aioredis.Redis(connection_pool=pool)
@@ -134,7 +146,7 @@ def __init__(
134146
symmetric_encryption_keys=symmetric_encryption_keys,
135147
)
136148
# Cached redis connection pools and the event loop they are from
137-
self._layers = {}
149+
self._layers: "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
138150
# Normal channels choose a host index by cycling through the available hosts
139151
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
140152
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -143,33 +155,33 @@ def __init__(
143155
# Number of coroutines trying to receive right now
144156
self.receive_count = 0
145157
# The receive lock
146-
self.receive_lock = None
158+
self.receive_lock: "Optional[asyncio.Lock]" = None
147159
# Event loop they are trying to receive on
148-
self.receive_event_loop = None
160+
self.receive_event_loop: "Optional[asyncio.AbstractEventLoop]" = None
149161
# Buffered messages by process-local channel name
150-
self.receive_buffer = collections.defaultdict(
151-
functools.partial(BoundedQueue, self.capacity)
162+
self.receive_buffer: "collections.defaultdict[str, BoundedQueue]" = (
163+
collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
152164
)
153165
# Detached channel cleanup tasks
154-
self.receive_cleaners = []
166+
self.receive_cleaners: "List[asyncio.Task]" = []
155167
# Per-channel cleanup locks to prevent a receive starting and moving
156168
# a message back into the main queue before its cleanup has completed
157169
self.receive_clean_locks = ChannelLock()
158170

159-
def create_pool(self, index):
171+
def create_pool(self, index: int) -> "ConnectionPool":
160172
return create_pool(self.hosts[index])
161173

162174
### Channel layer API ###
163175

164176
extensions = ["groups", "flush"]
165177

166-
async def send(self, channel, message):
178+
async def send(self, channel: str, message):
167179
"""
168180
Send a message onto a (general or specific) channel.
169181
"""
170182
# Typecheck
171183
assert isinstance(message, dict), "message is not a dict"
172-
assert self.valid_channel_name(channel), "Channel name not valid"
184+
assert self.require_valid_channel_name(channel), "Channel name not valid"
173185
# Make sure the message does not contain reserved keys
174186
assert "__asgi_channel__" not in message
175187
# If it's a process-local channel, strip off local part and stick full name in message
@@ -203,13 +215,13 @@ async def send(self, channel, message):
203215
await connection.zadd(channel_key, {self.serialize(message): time.time()})
204216
await connection.expire(channel_key, int(self.expiry))
205217

206-
def _backup_channel_name(self, channel):
218+
def _backup_channel_name(self, channel: str) -> str:
207219
"""
208220
Construct the key used as a backup queue for the given channel.
209221
"""
210222
return channel + "$inflight"
211223

212-
async def _brpop_with_clean(self, index, channel, timeout):
224+
async def _brpop_with_clean(self, index: int, channel: str, timeout: "TimeoutSecT"):
213225
"""
214226
Perform a Redis BRPOP and manage the backup processing queue.
215227
In case of cancellation, make sure the message is not lost.
@@ -240,23 +252,23 @@ async def _brpop_with_clean(self, index, channel, timeout):
240252

241253
return member
242254

243-
async def _clean_receive_backup(self, index, channel):
255+
async def _clean_receive_backup(self, index: int, channel: str):
244256
"""
245257
Pop the oldest message off the channel backup queue.
246258
The result isn't interesting as it was already processed.
247259
"""
248260
connection = self.connection(index)
249261
await connection.zpopmin(self._backup_channel_name(channel))
250262

251-
async def receive(self, channel):
263+
async def receive(self, channel: str):
252264
"""
253265
Receive the first message that arrives on the channel.
254266
If more than one coroutine waits on the same channel, the first waiter
255267
will be given the message when it arrives.
256268
"""
257269
# Make sure the channel name is valid then get the non-local part
258270
# and thus its index
259-
assert self.valid_channel_name(channel)
271+
assert self.require_valid_channel_name(channel)
260272
if "!" in channel:
261273
real_channel = self.non_local_name(channel)
262274
assert real_channel.endswith(
@@ -372,12 +384,14 @@ async def receive(self, channel):
372384
# Do a plain direct receive
373385
return (await self.receive_single(channel))[1]
374386

375-
async def receive_single(self, channel):
387+
async def receive_single(self, channel: str) -> "Tuple":
376388
"""
377389
Receives a single message off of the channel and returns it.
378390
"""
379391
# Check channel name
380-
assert self.valid_channel_name(channel, receive=True), "Channel name invalid"
392+
assert self.require_valid_channel_name(
393+
channel, receive=True
394+
), "Channel name invalid"
381395
# Work out the connection to use
382396
if "!" in channel:
383397
assert channel.endswith("!")
@@ -408,7 +422,7 @@ async def receive_single(self, channel):
408422
)
409423
self.receive_cleaners.append(cleaner)
410424

411-
def _cleanup_done(cleaner):
425+
def _cleanup_done(cleaner: "asyncio.Task"):
412426
self.receive_cleaners.remove(cleaner)
413427
self.receive_clean_locks.release(channel_key)
414428

@@ -427,7 +441,7 @@ def _cleanup_done(cleaner):
427441
del message["__asgi_channel__"]
428442
return channel, message
429443

430-
async def new_channel(self, prefix="specific"):
444+
async def new_channel(self, prefix: str = "specific") -> str:
431445
"""
432446
Returns a new channel name that can be used by something in our
433447
process as a specific channel.
@@ -477,13 +491,13 @@ async def wait_received(self):
477491

478492
### Groups extension ###
479493

480-
async def group_add(self, group, channel):
494+
async def group_add(self, group: str, channel: str):
481495
"""
482496
Adds the channel name to a group.
483497
"""
484498
# Check the inputs
485-
assert self.valid_group_name(group), "Group name not valid"
486-
assert self.valid_channel_name(channel), "Channel name not valid"
499+
assert self.require_valid_group_name(group), True
500+
assert self.require_valid_channel_name(channel), True
487501
# Get a connection to the right shard
488502
group_key = self._group_key(group)
489503
connection = self.connection(self.consistent_hash(group))
@@ -493,22 +507,22 @@ async def group_add(self, group, channel):
493507
# it at this point is guaranteed to expire before that
494508
await connection.expire(group_key, self.group_expiry)
495509

496-
async def group_discard(self, group, channel):
510+
async def group_discard(self, group: str, channel: str):
497511
"""
498512
Removes the channel from the named group if it is in the group;
499513
does nothing otherwise (does not error)
500514
"""
501-
assert self.valid_group_name(group), "Group name not valid"
502-
assert self.valid_channel_name(channel), "Channel name not valid"
515+
assert self.require_valid_group_name(group), "Group name not valid"
516+
assert self.require_valid_channel_name(channel), "Channel name not valid"
503517
key = self._group_key(group)
504518
connection = self.connection(self.consistent_hash(group))
505519
await connection.zrem(key, channel)
506520

507-
async def group_send(self, group, message):
521+
async def group_send(self, group: str, message):
508522
"""
509523
Sends a message to the entire group.
510524
"""
511-
assert self.valid_group_name(group), "Group name not valid"
525+
assert self.require_valid_group_name(group), "Group name not valid"
512526
# Retrieve list of all channel names
513527
key = self._group_key(group)
514528
connection = self.connection(self.consistent_hash(group))
@@ -573,7 +587,12 @@ async def group_send(self, group, message):
573587
channels_over_capacity = await connection.eval(
574588
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
575589
)
576-
if channels_over_capacity > 0:
590+
_channels_over_capacity = -1
591+
try:
592+
_channels_over_capacity = float(channels_over_capacity)
593+
except Exception:
594+
pass
595+
if _channels_over_capacity > 0:
577596
logger.info(
578597
"%s of %s channels over capacity in group %s",
579598
channels_over_capacity,
@@ -631,37 +650,35 @@ def _map_channel_keys_to_connection(self, channel_names, message):
631650
channel_key_to_capacity,
632651
)
633652

634-
def _group_key(self, group):
653+
def _group_key(self, group: str) -> bytes:
635654
"""
636655
Common function to make the storage key for the group.
637656
"""
638657
return f"{self.prefix}:group:{group}".encode("utf8")
639658

640-
### Serialization ###
641-
642-
def serialize(self, message):
659+
def serialize(self, message) -> bytes:
643660
"""
644661
Serializes message to a byte string.
645662
"""
646663
return self._serializer.serialize(message)
647664

648-
def deserialize(self, message):
665+
def deserialize(self, message: bytes):
649666
"""
650667
Deserializes from a byte string.
651668
"""
652669
return self._serializer.deserialize(message)
653670

654671
### Internal functions ###
655672

656-
def consistent_hash(self, value):
673+
def consistent_hash(self, value: "Union[str, ReadableBuffer]") -> int:
657674
return _consistent_hash(value, self.ring_size)
658675

659676
def __str__(self):
660677
return f"{self.__class__.__name__}(hosts={self.hosts})"
661678

662679
### Connection handling ###
663680

664-
def connection(self, index):
681+
def connection(self, index: int) -> "Redis":
665682
"""
666683
Returns the correct connection for the index given.
667684
Lazily instantiates pools.

0 commit comments

Comments
 (0)