diff --git a/channels_redis/core.py b/channels_redis/core.py
index 9721d7b..2dbcbf2 100644
--- a/channels_redis/core.py
+++ b/channels_redis/core.py
@@ -4,12 +4,13 @@
 import itertools
 import logging
 import time
+import typing
 import uuid

 from redis import asyncio as aioredis

-from channels.exceptions import ChannelFull
-from channels.layers import BaseChannelLayer
+from channels.exceptions import ChannelFull  # type: ignore[import-untyped]
+from channels.layers import BaseChannelLayer  # type: ignore[import-untyped]

 from .serializers import registry
 from .utils import (
@@ -20,6 +21,11 @@
     decode_hosts,
 )

+if typing.TYPE_CHECKING:
+    from redis.asyncio.client import Redis
+    from redis.asyncio.connection import ConnectionPool
+    from typing_extensions import Buffer
+
 logger = logging.getLogger(__name__)


@@ -31,24 +37,26 @@ class ChannelLock:
     to mitigate multi-event loop problems.
     """

-    def __init__(self):
-        self.locks = collections.defaultdict(asyncio.Lock)
-        self.wait_counts = collections.defaultdict(int)
+    def __init__(self) -> None:
+        self.locks: typing.DefaultDict[str, asyncio.Lock] = collections.defaultdict(
+            asyncio.Lock
+        )
+        self.wait_counts: typing.DefaultDict[str, int] = collections.defaultdict(int)

-    async def acquire(self, channel):
+    async def acquire(self, channel: str) -> bool:
         """
         Acquire the lock for the given channel.
         """
         self.wait_counts[channel] += 1
         return await self.locks[channel].acquire()

-    def locked(self, channel):
+    def locked(self, channel: str) -> bool:
         """
         Return ``True`` if the lock for the given channel is acquired.
         """
         return self.locks[channel].locked()

-    def release(self, channel):
+    def release(self, channel: str) -> None:
         """
         Release the lock for the given channel.
         """
@@ -60,7 +68,7 @@ def release(self, channel):


 class BoundedQueue(asyncio.Queue):
-    def put_nowait(self, item):
+    def put_nowait(self, item: typing.Any) -> None:
         if self.full():
             # see: https://github.com/django/channels_redis/issues/212
             # if we actually get into this code block, it likely means that
@@ -73,19 +81,19 @@ def put_nowait(self, item):


 class RedisLoopLayer:
-    def __init__(self, channel_layer):
+    def __init__(self, channel_layer: "RedisChannelLayer") -> None:
         self._lock = asyncio.Lock()
         self.channel_layer = channel_layer
-        self._connections = {}
+        self._connections: typing.Dict[int, "Redis"] = {}

-    def get_connection(self, index):
+    def get_connection(self, index: int) -> "Redis":
         if index not in self._connections:
             pool = self.channel_layer.create_pool(index)
             self._connections[index] = aioredis.Redis(connection_pool=pool)

         return self._connections[index]

-    async def flush(self):
+    async def flush(self) -> None:
         async with self._lock:
             for index in list(self._connections):
                 connection = self._connections.pop(index)
@@ -106,15 +114,15 @@ class RedisChannelLayer(BaseChannelLayer):
     def __init__(
         self,
         hosts=None,
-        prefix="asgi",
+        prefix: str = "asgi",
         expiry=60,
-        group_expiry=86400,
+        group_expiry: int = 86400,
         capacity=100,
         channel_capacity=None,
         symmetric_encryption_keys=None,
         random_prefix_length=12,
         serializer_format="msgpack",
-    ):
+    ) -> None:
         # Store basic information
         self.expiry = expiry
         self.group_expiry = group_expiry
@@ -134,7 +142,7 @@ def __init__(
             symmetric_encryption_keys=symmetric_encryption_keys,
         )
         # Cached redis connection pools and the event loop they are from
-        self._layers = {}
+        self._layers: typing.Dict[asyncio.AbstractEventLoop, "RedisLoopLayer"] = {}
         # Normal channels choose a host index by cycling through the available hosts
         self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
         self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -143,27 +151,27 @@ def __init__(
         # Number of coroutines trying to receive right now
         self.receive_count = 0
         # The receive lock
-        self.receive_lock = None
+        self.receive_lock: typing.Optional[asyncio.Lock] = None
         # Event loop they are trying to receive on
-        self.receive_event_loop = None
+        self.receive_event_loop: typing.Optional[asyncio.AbstractEventLoop] = None
         # Buffered messages by process-local channel name
-        self.receive_buffer = collections.defaultdict(
-            functools.partial(BoundedQueue, self.capacity)
+        self.receive_buffer: typing.DefaultDict[str, BoundedQueue] = (
+            collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
         )
         # Detached channel cleanup tasks
-        self.receive_cleaners = []
+        self.receive_cleaners: typing.List["asyncio.Task[typing.Any]"] = []
         # Per-channel cleanup locks to prevent a receive starting and moving
         # a message back into the main queue before its cleanup has completed
         self.receive_clean_locks = ChannelLock()

-    def create_pool(self, index):
+    def create_pool(self, index: int) -> "ConnectionPool":
         return create_pool(self.hosts[index])

     ### Channel layer API ###

     extensions = ["groups", "flush"]

-    async def send(self, channel, message):
+    async def send(self, channel: str, message: typing.Any) -> None:
         """
         Send a message onto a (general or specific) channel.
         """
@@ -203,13 +211,15 @@ async def send(self, channel, message):
         await connection.zadd(channel_key, {self.serialize(message): time.time()})
         await connection.expire(channel_key, int(self.expiry))

-    def _backup_channel_name(self, channel):
+    def _backup_channel_name(self, channel: str) -> str:
         """
         Construct the key used as a backup queue for the given channel.
         """
         return channel + "$inflight"

-    async def _brpop_with_clean(self, index, channel, timeout):
+    async def _brpop_with_clean(
+        self, index: int, channel: str, timeout: typing.Union[int, float, bytes, str]
+    ) -> typing.Any:
         """
         Perform a Redis BRPOP and manage the backup processing queue.
         In case of cancellation, make sure the message is not lost.
@@ -228,7 +238,7 @@ async def _brpop_with_clean(self, index, channel, timeout):
         connection = self.connection(index)
         # Cancellation here doesn't matter, we're not doing anything destructive
         # and the script executes atomically...
-        await connection.eval(cleanup_script, 0, channel, backup_queue)
+        await connection.eval(cleanup_script, 0, channel, backup_queue)  # type: ignore[misc]
         # ...and it doesn't matter here either, the message will be safe in the backup.
         result = await connection.bzpopmin(channel, timeout=timeout)
@@ -240,7 +250,7 @@

         return member

-    async def _clean_receive_backup(self, index, channel):
+    async def _clean_receive_backup(self, index: int, channel: str) -> None:
         """
         Pop the oldest message off the channel backup queue.
         The result isn't interesting as it was already processed.
@@ -248,7 +258,7 @@
         connection = self.connection(index)
         await connection.zpopmin(self._backup_channel_name(channel))

-    async def receive(self, channel):
+    async def receive(self, channel: str) -> typing.Any:
         """
         Receive the first message that arrives on the channel.
         If more than one coroutine waits on the same channel, the first waiter
@@ -280,11 +290,11 @@ async def receive(self, channel):
                 # Wait for our message to appear
                 message = None
                 while self.receive_buffer[channel].empty():
-                    tasks = [
-                        self.receive_lock.acquire(),
+                    _tasks = [
+                        self.receive_lock.acquire(),  # type: ignore[union-attr]
                         self.receive_buffer[channel].get(),
                     ]
-                    tasks = [asyncio.ensure_future(task) for task in tasks]
+                    tasks = [asyncio.ensure_future(task) for task in _tasks]
                     try:
                         done, pending = await asyncio.wait(
                             tasks, return_when=asyncio.FIRST_COMPLETED
@@ -300,7 +310,7 @@
                             if not task.cancel():
                                 assert task.done()
                                 if task.result() is True:
-                                    self.receive_lock.release()
+                                    self.receive_lock.release()  # type: ignore[union-attr]

                         raise
@@ -323,7 +333,7 @@
                     if message or exception:
                         if token:
                             # We will not be receving as we already have the message.
-                            self.receive_lock.release()
+                            self.receive_lock.release()  # type: ignore[union-attr]

                         if exception:
                             raise exception
@@ -350,7 +360,7 @@
                             del self.receive_buffer[channel]
                             raise
                         finally:
-                            self.receive_lock.release()
+                            self.receive_lock.release()  # type: ignore[union-attr]

                 # We know there's a message available, because there
                 # couldn't have been any interruption between empty() and here
@@ -365,14 +375,16 @@
                 self.receive_count -= 1
                 # If we were the last out, drop the receive lock
                 if self.receive_count == 0:
-                    assert not self.receive_lock.locked()
+                    assert not self.receive_lock.locked()  # type: ignore[union-attr]
                     self.receive_lock = None
                     self.receive_event_loop = None
         else:
             # Do a plain direct receive
             return (await self.receive_single(channel))[1]

-    async def receive_single(self, channel):
+    async def receive_single(
+        self, channel: str
+    ) -> typing.Tuple[typing.Any, typing.Any]:
         """
         Receives a single message off of the channel and returns it.
         """
@@ -410,7 +422,7 @@
         )
         self.receive_cleaners.append(cleaner)

-        def _cleanup_done(cleaner):
+        def _cleanup_done(cleaner: "asyncio.Task") -> None:
             self.receive_cleaners.remove(cleaner)
             self.receive_clean_locks.release(channel_key)
@@ -429,7 +441,7 @@ def _cleanup_done(cleaner):
             del message["__asgi_channel__"]
         return channel, message

-    async def new_channel(self, prefix="specific"):
+    async def new_channel(self, prefix: str = "specific") -> str:
         """
         Returns a new channel name that can be used by something in our
         process as a specific channel.
@@ -438,7 +450,7 @@ async def new_channel(self, prefix="specific"):

     ### Flush extension ###

-    async def flush(self):
+    async def flush(self) -> None:
         """
         Deletes all messages and groups on all shards.
         """
@@ -456,11 +468,11 @@ async def flush(self):
         # Go through each connection and remove all with prefix
         for i in range(self.ring_size):
             connection = self.connection(i)
-            await connection.eval(delete_prefix, 0, self.prefix + "*")
+            await connection.eval(delete_prefix, 0, self.prefix + "*")  # type: ignore[union-attr,misc]

         # Now clear the pools as well
         await self.close_pools()

-    async def close_pools(self):
+    async def close_pools(self) -> None:
         """
         Close all connections in the event loop pools.
         """
@@ -470,7 +482,7 @@ async def close_pools(self):
         for layer in self._layers.values():
             await layer.flush()

-    async def wait_received(self):
+    async def wait_received(self) -> None:
         """
         Wait for all channel cleanup functions to finish.
         """
@@ -479,7 +491,7 @@

     ### Groups extension ###

-    async def group_add(self, group, channel):
+    async def group_add(self, group: str, channel: str) -> None:
         """
         Adds the channel name to a group.
         """
@@ -495,7 +507,7 @@
         # it at this point is guaranteed to expire before that
         await connection.expire(group_key, self.group_expiry)

-    async def group_discard(self, group, channel):
+    async def group_discard(self, group: str, channel: str) -> None:
         """
         Removes the channel from the named group if it is in the group;
         does nothing otherwise (does not error)
@@ -506,7 +518,7 @@
         connection = self.connection(self.consistent_hash(group))
         await connection.zrem(key, channel)

-    async def group_send(self, group, message):
+    async def group_send(self, group: str, message: typing.Any) -> None:
         """
         Sends a message to the entire group.
         """
@@ -530,9 +542,9 @@ async def group_send(self, group, message):
         for connection_index, channel_redis_keys in connection_to_channel_keys.items():
             # Discard old messages based on expiry
             pipe = connection.pipeline()
-            for key in channel_redis_keys:
+            for _key in channel_redis_keys:
                 pipe.zremrangebyscore(
-                    key, min=0, max=int(time.time()) - int(self.expiry)
+                    _key, min=0, max=int(time.time()) - int(self.expiry)
                 )
             await pipe.execute()
@@ -572,10 +584,15 @@ async def group_send(self, group, message):
             # channel_keys does not contain a single redis key more than once
             connection = self.connection(connection_index)

-            channels_over_capacity = await connection.eval(
+            channels_over_capacity = await connection.eval(  # type: ignore[misc]
                 group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
             )
-            if channels_over_capacity > 0:
+            _channels_over_capacity = -1.0
+            try:
+                _channels_over_capacity = float(channels_over_capacity)
+            except Exception:
+                pass
+            if _channels_over_capacity > 0:
                 logger.info(
                     "%s of %s channels over capacity in group %s",
                     channels_over_capacity,
@@ -583,7 +600,13 @@ async def group_send(self, group, message):
                     group,
                 )

-    def _map_channel_keys_to_connection(self, channel_names, message):
+    def _map_channel_keys_to_connection(
+        self, channel_names: typing.Iterable[str], message: typing.Any
+    ) -> typing.Tuple[
+        typing.Dict[int, typing.List[str]],
+        typing.Dict[str, typing.Any],
+        typing.Dict[str, int],
+    ]:
         """
         For a list of channel names, GET
@@ -596,11 +619,13 @@ def _map_channel_keys_to_connection(self, channel_names, message):
         """

         # Connection dict keyed by index to list of redis keys mapped on that index
-        connection_to_channel_keys = collections.defaultdict(list)
+        connection_to_channel_keys: typing.Dict[int, typing.List[str]] = (
+            collections.defaultdict(list)
+        )
         # Message dict maps redis key to the message that needs to be send on that key
-        channel_key_to_message = dict()
+        channel_key_to_message: typing.Dict[str, typing.Any] = dict()
         # Channel key mapped to its capacity
-        channel_key_to_capacity = dict()
+        channel_key_to_capacity: typing.Dict[str, int] = dict()

         # For each channel
         for channel in channel_names:
@@ -608,7 +633,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
             if "!" in channel:
                 channel_non_local_name = self.non_local_name(channel)
                 # Get its redis key
-                channel_key = self.prefix + channel_non_local_name
+                channel_key: str = self.prefix + channel_non_local_name
                 # Have we come across the same redis key?
                 if channel_key not in channel_key_to_message:
                     # If not, fill the corresponding dicts
@@ -633,7 +658,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
             channel_key_to_capacity,
         )

-    def _group_key(self, group):
+    def _group_key(self, group: str) -> bytes:
         """
         Common function to make the storage key for the group.
         """
@@ -641,13 +666,13 @@

     ### Serialization ###

-    def serialize(self, message):
+    def serialize(self, message: typing.Any) -> bytes:
         """
         Serializes message to a byte string.
         """
         return self._serializer.serialize(message)

-    def deserialize(self, message):
+    def deserialize(self, message: bytes) -> typing.Any:
         """
         Deserializes from a byte string.
         """
@@ -655,15 +680,15 @@

     ### Internal functions ###

-    def consistent_hash(self, value):
+    def consistent_hash(self, value: typing.Union[str, "Buffer"]) -> int:
         return _consistent_hash(value, self.ring_size)

-    def __str__(self):
+    def __str__(self) -> str:
         return f"{self.__class__.__name__}(hosts={self.hosts})"

     ### Connection handling ###

-    def connection(self, index):
+    def connection(self, index: int) -> "Redis":
         """
         Returns the correct connection for the index given.
         Lazily instantiates pools.
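For orientation, a rough sketch of how the annotated core layer is driven; the comments echo the new signatures above. The Redis URL and message payload are illustrative assumptions (not part of the patch), and a local Redis must be running:

    import asyncio

    from channels_redis.core import RedisChannelLayer


    async def main() -> None:
        layer = RedisChannelLayer(hosts=["redis://localhost:6379"])
        channel = await layer.new_channel()            # new_channel() -> str
        await layer.send(channel, {"type": "hello"})   # send(channel: str, message: typing.Any) -> None
        message = await layer.receive(channel)         # receive(channel: str) -> typing.Any
        assert message == {"type": "hello"}
        await layer.flush()                            # flush() -> None


    asyncio.run(main())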
""" return self._serializer.deserialize(message) - def _get_layer(self): + def _get_layer(self) -> "RedisPubSubLoopLayer": loop = asyncio.get_running_loop() try: layer = self._layers[loop] except KeyError: + kwargs = deepcopy(self._kwargs) + kwargs["channel_layer"] = self layer = RedisPubSubLoopLayer( *self._args, - **self._kwargs, - channel_layer=self, + **kwargs, ) self._layers[loop] = layer _wrap_close(self, loop) @@ -91,13 +101,13 @@ class RedisPubSubLoopLayer: def __init__( self, - hosts=None, - prefix="asgi", + hosts: typing.Union[typing.Iterable[typing.Any], str, bytes, None] = None, + prefix: str = "asgi", on_disconnect=None, on_reconnect=None, - channel_layer=None, + channel_layer: typing.Optional[RedisPubSubChannelLayer] = None, **kwargs, - ): + ) -> None: self.prefix = prefix self.on_disconnect = on_disconnect @@ -106,24 +116,26 @@ def __init__( # Each consumer gets its own *specific* channel, created with the `new_channel()` method. # This dict maps `channel_name` to a queue of messages for that channel. - self.channels = {} + self.channels: typing.Dict[typing.Any, asyncio.Queue[typing.Any]] = {} # A channel can subscribe to zero or more groups. # This dict maps `group_name` to set of channel names who are subscribed to that group. - self.groups = {} + self.groups: typing.Dict[str, typing.Set[typing.Any]] = {} # For each host, we create a `RedisSingleShardConnection` to manage the connection to that host. self._shards = [ RedisSingleShardConnection(host, self) for host in decode_hosts(hosts) ] - def _get_shard(self, channel_or_group_name): + def _get_shard( + self, channel_or_group_name: typing.Union[str, "Buffer"] + ) -> "RedisSingleShardConnection": """ Return the shard that is used exclusively for this channel or group. """ return self._shards[_consistent_hash(channel_or_group_name, len(self._shards))] - def _get_group_channel_name(self, group): + def _get_group_channel_name(self, group: str) -> str: """ Return the channel name used by a group. Includes '__group__' in the returned @@ -135,7 +147,7 @@ def _get_group_channel_name(self, group): """ return f"{self.prefix}__group__{group}" - async def _subscribe_to_channel(self, channel): + async def _subscribe_to_channel(self, channel: typing.Union[str, bytes]) -> None: self.channels[channel] = asyncio.Queue() shard = self._get_shard(channel) await shard.subscribe(channel) @@ -146,14 +158,16 @@ async def _subscribe_to_channel(self, channel): # Channel layer API ################################################################################ - async def send(self, channel, message): + async def send( + self, channel: typing.Union[str, bytes], message: typing.Any + ) -> None: """ Send a message onto a (general or specific) channel. """ shard = self._get_shard(channel) - await shard.publish(channel, self.channel_layer.serialize(message)) + await shard.publish(channel, self.channel_layer.serialize(message)) # type: ignore[union-attr] - async def new_channel(self, prefix="specific."): + async def new_channel(self, prefix: str = "specific.") -> str: """ Returns a new channel name that can be used by a consumer in our process as a specific channel. @@ -162,7 +176,7 @@ async def new_channel(self, prefix="specific."): await self._subscribe_to_channel(channel) return channel - async def receive(self, channel): + async def receive(self, channel: typing.Union[str, bytes]) -> typing.Any: """ Receive the first message that arrives on the channel. 
If more than one coroutine waits on the same channel, a random one @@ -193,13 +207,13 @@ async def receive(self, channel): # We don't re-raise here because we want the CancelledError to be the one re-raised. raise - return self.channel_layer.deserialize(message) + return self.channel_layer.deserialize(message) # type: ignore[union-attr] ################################################################################ # Groups extension ################################################################################ - async def group_add(self, group, channel): + async def group_add(self, group: str, channel: typing.Union[str, bytes]) -> None: """ Adds the channel name to a group. """ @@ -218,7 +232,9 @@ async def group_add(self, group, channel): shard = self._get_shard(group_channel) await shard.subscribe(group_channel) - async def group_discard(self, group, channel): + async def group_discard( + self, group: str, channel: typing.Union[str, bytes] + ) -> None: """ Removes the channel from a group if it is in the group; does nothing otherwise (does not error) @@ -234,19 +250,19 @@ async def group_discard(self, group, channel): shard = self._get_shard(group_channel) await shard.unsubscribe(group_channel) - async def group_send(self, group, message): + async def group_send(self, group: str, message: typing.Any) -> None: """ Send the message to all subscribers of the group. """ group_channel = self._get_group_channel_name(group) shard = self._get_shard(group_channel) - await shard.publish(group_channel, self.channel_layer.serialize(message)) + await shard.publish(group_channel, self.channel_layer.serialize(message)) # type: ignore[union-attr] ################################################################################ # Flush extension ################################################################################ - async def flush(self): + async def flush(self) -> None: """ Flush the layer, making it like new. It can continue to be used as if it was just created. 
This also closes connections, serving as a clean-up @@ -259,37 +275,43 @@ async def flush(self): class RedisSingleShardConnection: - def __init__(self, host, channel_layer): + def __init__( + self, host: typing.Dict[str, typing.Any], channel_layer: RedisPubSubLoopLayer + ) -> None: self.host = host self.channel_layer = channel_layer - self._subscribed_to = set() + self._subscribed_to: typing.Set[typing.Union[str, bytes]] = set() self._lock = asyncio.Lock() - self._redis = None - self._pubsub = None - self._receive_task = None + self._redis: typing.Optional["Redis"] = None + self._pubsub: typing.Optional["PubSub"] = None + self._receive_task: typing.Optional["asyncio.Task[typing.Any]"] = None - async def publish(self, channel, message): + async def publish( + self, + channel: typing.Union[str, bytes], + message: typing.Union[str, bytes, int, float], + ) -> None: async with self._lock: self._ensure_redis() - await self._redis.publish(channel, message) + await self._redis.publish(channel, message) # type: ignore[union-attr] - async def subscribe(self, channel): + async def subscribe(self, channel: typing.Union[str, bytes]) -> None: async with self._lock: if channel not in self._subscribed_to: self._ensure_redis() self._ensure_receiver() - await self._pubsub.subscribe(channel) + await self._pubsub.subscribe(channel) # type: ignore[union-attr] self._subscribed_to.add(channel) - async def unsubscribe(self, channel): + async def unsubscribe(self, channel: typing.Union[str, bytes]) -> None: async with self._lock: if channel in self._subscribed_to: self._ensure_redis() self._ensure_receiver() - await self._pubsub.unsubscribe(channel) + await self._pubsub.unsubscribe(channel) # type: ignore[union-attr] self._subscribed_to.remove(channel) - async def flush(self): + async def flush(self) -> None: async with self._lock: if self._receive_task is not None: self._receive_task.cancel() @@ -307,7 +329,7 @@ async def flush(self): self._pubsub = None self._subscribed_to = set() - async def _do_receiving(self): + async def _do_receiving(self) -> None: while True: try: if self._pubsub and self._pubsub.subscribed: @@ -327,7 +349,9 @@ async def _do_receiving(self): logger.exception("Unexpected exception in receive task") await asyncio.sleep(1) - def _receive_message(self, message): + def _receive_message( + self, message: typing.Optional[typing.Dict[typing.Any, typing.Any]] + ) -> None: if message is not None: name = message["channel"] data = message["data"] @@ -340,12 +364,12 @@ def _receive_message(self, message): if channel_name in self.channel_layer.channels: self.channel_layer.channels[channel_name].put_nowait(data) - def _ensure_redis(self): + def _ensure_redis(self) -> None: if self._redis is None: pool = create_pool(self.host) self._redis = aioredis.Redis(connection_pool=pool) self._pubsub = self._redis.pubsub() - def _ensure_receiver(self): + def _ensure_receiver(self) -> None: if self._receive_task is None: self._receive_task = asyncio.ensure_future(self._do_receiving()) diff --git a/channels_redis/py.typed b/channels_redis/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/channels_redis/serializers.py b/channels_redis/serializers.py index b981797..e7805bb 100644 --- a/channels_redis/serializers.py +++ b/channels_redis/serializers.py @@ -3,11 +3,19 @@ import hashlib import json import random +import typing + +if typing.TYPE_CHECKING: + from typing_extensions import Buffer try: from cryptography.fernet import Fernet, MultiFernet + + _MultiFernet: typing.Optional[typing.Type[MultiFernet]] 
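The `_get_layer()` change above deep-copies the stored kwargs before injecting `channel_layer`, so creating layers on new event loops no longer mutates the shared `_kwargs` dict. A minimal usage sketch of the pub/sub layer (the group name and URL are made-up examples; calls are proxied through `__getattr__`/`_async_proxy` to the per-loop layer):

    import asyncio

    from channels_redis.pubsub import RedisPubSubChannelLayer


    async def main() -> None:
        layer = RedisPubSubChannelLayer(hosts=["redis://localhost:6379"])
        channel = await layer.new_channel()
        await layer.group_add("chat", channel)
        await layer.group_send("chat", {"type": "chat.message", "text": "hi"})
        print(await layer.receive(channel))   # receive(...) -> typing.Any
        await layer.flush()


    asyncio.run(main())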
diff --git a/channels_redis/py.typed b/channels_redis/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/channels_redis/serializers.py b/channels_redis/serializers.py
index b981797..e7805bb 100644
--- a/channels_redis/serializers.py
+++ b/channels_redis/serializers.py
@@ -3,11 +3,19 @@
 import hashlib
 import json
 import random
+import typing
+
+if typing.TYPE_CHECKING:
+    from typing_extensions import Buffer

 try:
     from cryptography.fernet import Fernet, MultiFernet
+
+    _MultiFernet: typing.Optional[typing.Type[MultiFernet]] = MultiFernet
+    _Fernet: typing.Optional[typing.Type[Fernet]] = Fernet
 except ImportError:
-    MultiFernet = Fernet = None
+    _MultiFernet = None
+    _Fernet = None


 class SerializerDoesNotExist(KeyError):
@@ -17,36 +25,44 @@ class SerializerDoesNotExist(KeyError):
 class BaseMessageSerializer(abc.ABC):
     def __init__(
         self,
-        symmetric_encryption_keys=None,
-        random_prefix_length=0,
-        expiry=None,
+        symmetric_encryption_keys: typing.Optional[
+            typing.Iterable[typing.Union[str, "Buffer"]]
+        ] = None,
+        random_prefix_length: int = 0,
+        expiry: typing.Optional[int] = None,
     ):
         self.random_prefix_length = random_prefix_length
         self.expiry = expiry
+        self.crypter: typing.Optional["MultiFernet"] = None
         # Set up any encryption objects
         self._setup_encryption(symmetric_encryption_keys)

-    def _setup_encryption(self, symmetric_encryption_keys):
+    def _setup_encryption(
+        self,
+        symmetric_encryption_keys: typing.Optional[
+            typing.Union[typing.Iterable[typing.Union[str, "Buffer"]], str, bytes]
+        ],
+    ) -> None:
         # See if we can do encryption if they asked
         if symmetric_encryption_keys:
             if isinstance(symmetric_encryption_keys, (str, bytes)):
                 raise ValueError(
                     "symmetric_encryption_keys must be a list of possible keys"
                 )
-            if MultiFernet is None:
+            if _MultiFernet is None:
                 raise ValueError(
                     "Cannot run with encryption without 'cryptography' installed."
                 )
             sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
-            self.crypter = MultiFernet(sub_fernets)
+            self.crypter = _MultiFernet(sub_fernets)
         else:
             self.crypter = None

-    def make_fernet(self, key):
+    def make_fernet(self, key: typing.Union[str, "Buffer"]) -> "Fernet":
         """
         Given a single encryption key, returns a Fernet instance using it.
         """
-        if Fernet is None:
+        if _Fernet is None:
             raise ValueError(
                 "Cannot run with encryption without 'cryptography' installed."
             )
@@ -54,35 +70,35 @@ def make_fernet(self, key):
         if isinstance(key, str):
             key = key.encode("utf-8")
         formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest())
-        return Fernet(formatted_key)
+        return _Fernet(formatted_key)

     @abc.abstractmethod
-    def as_bytes(self, message, *args, **kwargs):
+    def as_bytes(self, message: typing.Any, *args, **kwargs) -> bytes:
         raise NotImplementedError

     @abc.abstractmethod
-    def from_bytes(self, message, *args, **kwargs):
+    def from_bytes(self, message: bytes, *args, **kwargs) -> typing.Any:
         raise NotImplementedError

-    def serialize(self, message):
+    def serialize(self, message: typing.Any) -> bytes:
         """
         Serializes message to a byte string.
         """
-        message = self.as_bytes(message)
+        msg = self.as_bytes(message)
         if self.crypter:
-            message = self.crypter.encrypt(message)
+            msg = self.crypter.encrypt(msg)
         if self.random_prefix_length > 0:
             # provide random prefix
-            message = (
+            msg = (
                 random.getrandbits(8 * self.random_prefix_length).to_bytes(
                     self.random_prefix_length, "big"
                 )
-                + message
+                + msg
             )
-        return message
+        return msg

-    def deserialize(self, message):
+    def deserialize(self, message: bytes) -> typing.Any:
         """
         Deserializes from a byte string.
         """
@@ -97,10 +113,10 @@


 class MissingSerializer(BaseMessageSerializer):
-    exception = None
+    exception: typing.Optional[Exception] = None

-    def __init__(self, *args, **kwargs):
-        raise self.exception
+    def __init__(self, *args, **kwargs) -> None:
+        raise self.exception if self.exception else NotImplementedError()


 class JSONSerializer(BaseMessageSerializer):
@@ -108,26 +124,32 @@
     # thus we must force bytes conversion
     # we use UTF-8 since it is the recommended encoding for interoperability
     # see https://docs.python.org/3/library/json.html#character-encodings
-    def as_bytes(self, message, *args, **kwargs):
-        message = json.dumps(message, *args, **kwargs)
-        return message.encode("utf-8")
+    def as_bytes(self, message: typing.Any, *args, **kwargs) -> bytes:
+        msg = json.dumps(message, *args, **kwargs)
+        return msg.encode("utf-8")

-    from_bytes = staticmethod(json.loads)
+    from_bytes = staticmethod(json.loads)  # type: ignore[assignment]


 # code ready for a future in which msgpack may become an optional dependency
+MsgPackSerializer: typing.Union[
+    typing.Type[BaseMessageSerializer], typing.Type[MissingSerializer]
+]
 try:
-    import msgpack
+    import msgpack  # type: ignore[import-untyped]
 except ImportError as exc:

-    class MsgPackSerializer(MissingSerializer):
+    class _MsgPackSerializer(MissingSerializer):
         exception = exc

+    MsgPackSerializer = _MsgPackSerializer
 else:

-    class MsgPackSerializer(BaseMessageSerializer):
-        as_bytes = staticmethod(msgpack.packb)
-        from_bytes = staticmethod(msgpack.unpackb)
+    class __MsgPackSerializer(BaseMessageSerializer):
+        as_bytes = staticmethod(msgpack.packb)  # type: ignore[assignment]
+        from_bytes = staticmethod(msgpack.unpackb)  # type: ignore[assignment]
+
+    MsgPackSerializer = __MsgPackSerializer


 class SerializersRegistry:
@@ -135,10 +157,12 @@
     """
     Serializers registry inspired by that of ``django.core.serializers``.
     """

-    def __init__(self):
-        self._registry = {}
+    def __init__(self) -> None:
+        self._registry: typing.Dict[typing.Any, typing.Type[BaseMessageSerializer]] = {}

-    def register_serializer(self, format, serializer_class):
+    def register_serializer(
+        self, format: typing.Any, serializer_class: typing.Type[BaseMessageSerializer]
+    ) -> None:
         """
         Register a new serializer for given format
         """
@@ -155,7 +179,9 @@ def register_serializer(self, format, serializer_class):

         self._registry[format] = serializer_class

-    def get_serializer(self, format, *args, **kwargs):
+    def get_serializer(
+        self, format: typing.Any, *args, **kwargs
+    ) -> BaseMessageSerializer:
         try:
             serializer_class = self._registry[format]
         except KeyError:
diff --git a/channels_redis/utils.py b/channels_redis/utils.py
index 6f15050..89265f8 100644
--- a/channels_redis/utils.py
+++ b/channels_redis/utils.py
@@ -1,10 +1,22 @@
 import binascii
 import types
+import typing

 from redis import asyncio as aioredis
+from redis.asyncio import sentinel as redis_sentinel

+if typing.TYPE_CHECKING:
+    from asyncio import AbstractEventLoop

-def _consistent_hash(value, ring_size):
+    from redis.asyncio.client import Redis
+    from redis.asyncio.connection import ConnectionPool
+    from typing_extensions import Buffer
+
+    from .core import RedisChannelLayer
+    from .pubsub import RedisPubSubChannelLayer
+
+
+def _consistent_hash(value: typing.Union[str, "Buffer"], ring_size: int) -> int:
     """
     Maps the value to a node value between 0 and 4095
     using CRC, then down to one of the ring nodes.
@@ -20,10 +32,13 @@ def _consistent_hash(value, ring_size):
     return int(bigval / ring_divisor)


-def _wrap_close(proxy, loop):
+def _wrap_close(
+    proxy: typing.Union["RedisPubSubChannelLayer", "RedisChannelLayer"],
+    loop: "AbstractEventLoop",
+) -> None:
     original_impl = loop.close

-    def _wrapper(self, *args, **kwargs):
+    def _wrapper(self, *args, **kwargs) -> typing.Any:
         if loop in proxy._layers:
             layer = proxy._layers[loop]
             del proxy._layers[loop]
@@ -32,10 +47,10 @@ def _wrapper(self, *args, **kwargs):
         self.close = original_impl
         return self.close(*args, **kwargs)

-    loop.close = types.MethodType(_wrapper, loop)
+    loop.close = types.MethodType(_wrapper, loop)  # type: ignore[method-assign]


-async def _close_redis(connection):
+async def _close_redis(connection: "Redis") -> None:
     """
     Handle compatibility with redis-py 4.x and 5.x close methods
     """
@@ -45,7 +60,9 @@ async def _close_redis(connection):
         await connection.close(close_connection_pool=True)


-def decode_hosts(hosts):
+def decode_hosts(
+    hosts: typing.Union[typing.Iterable[typing.Any], str, bytes, None],
+) -> typing.List[typing.Dict[typing.Any, typing.Any]]:
     """
     Takes the value of the "hosts" argument and returns
     a list of kwargs to use for the Redis connection constructor.
@@ -60,7 +77,7 @@ def decode_hosts(hosts):
         )

     # Decode each hosts entry into a kwargs dict
-    result = []
+    result: typing.List[typing.Dict[typing.Any, typing.Any]] = []
     for entry in hosts:
         if isinstance(entry, dict):
             result.append(entry)
@@ -71,7 +88,7 @@ def decode_hosts(hosts):

     return result

-def create_pool(host):
+def create_pool(host: typing.Dict[str, typing.Any]) -> "ConnectionPool":
     """
     Takes the value of the "host" argument and returns a suited connection pool
     to the corresponding redis instance.
@@ -86,10 +103,10 @@ def create_pool(host):
     if master_name is not None:
         sentinels = host.pop("sentinels")
         sentinel_kwargs = host.pop("sentinel_kwargs", None)
-        return aioredis.sentinel.SentinelConnectionPool(
+        return redis_sentinel.SentinelConnectionPool(
             master_name,
-            aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
-            **host
+            redis_sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
+            **host,
         )

     return aioredis.ConnectionPool(**host)
diff --git a/setup.cfg b/setup.cfg
index 3888deb..b5a4092 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -12,3 +12,10 @@ max-line-length = 119
 [isort]
 profile = black
 known_first_party = channels, asgiref, channels_redis, daphne
+
+[mypy]
+warn_unused_ignores = True
+strict = True
+
+[mypy-tests.*]
+ignore_errors = True
diff --git a/tox.ini b/tox.ini
index ffdfc99..c3632a3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -10,13 +10,15 @@ usedevelop = true
 extras = tests
 commands =
     pytest -v {posargs}
+    mypy --config-file=tox.ini channels_redis
 deps =
     ch4: channels>=4.0,<5
     chmain: https://github.com/django/channels/archive/main.tar.gz
     redis5: redis>=5.0,<6
     redis6: redis>=6.0,<7
     redismain: https://github.com/redis/redis-py/archive/master.tar.gz
-
+    mypy>=1.9.0
+
 [testenv:qa]
 skip_install=true
 deps =
@@ -27,3 +29,7 @@ commands =
     flake8 channels_redis tests
     black --check channels_redis tests
     isort --check-only --diff channels_redis tests
+
+
+[mypy]
+disallow_untyped_defs = False
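For reference, the ring arithmetic behind `_consistent_hash`, whose new signature appears in the utils.py hunks above. This is a standalone re-derivation based on the function's docstring and the `return int(bigval / ring_divisor)` context line, so treat it as an approximation rather than the patched module itself:

    import binascii


    def consistent_hash(value: str, ring_size: int) -> int:
        # CRC32 folded onto 0..4095, then scaled down to one of the ring nodes.
        if ring_size == 1:
            return 0
        bigval = binascii.crc32(value.encode("utf8")) & 0xFFF
        ring_divisor = 4096 / float(ring_size)
        return int(bigval / ring_divisor)


    assert 0 <= consistent_hash("asgi__group__chat", 4) < 4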