channels_redis/core.py (102 changes: 60 additions & 42 deletions)
@@ -19,6 +19,13 @@
create_pool,
decode_hosts,
)
from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Optional

if TYPE_CHECKING:
from redis.asyncio.connection import ConnectionPool
from redis.asyncio.client import Redis
Comment from @iwakitakuma33 (author, May 29, 2025):
@amirreza8002 Shouldn't I also use Redis and ConnectionPool?

Reply from @amirreza8002:
Those are just classes. What I mean is: if you see something like Redis().some_method(), there is a chance the type hint on some_method is wrong, and we shouldn't just put the same thing in this codebase.
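
For reference, a minimal sketch of the TYPE_CHECKING pattern this thread discusses; describe() is a hypothetical function invented for illustration, and only the redis import path comes from the diff itself:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers; never executed at runtime,
    # so redis internals remain a type-time-only dependency.
    from redis.asyncio.client import Redis

def describe(conn: "Redis") -> str:
    # The quoted annotation is a forward reference, so this module imports
    # cleanly even though Redis is not defined at runtime.
    return f"wrapping {conn!r}"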

from .core import RedisChannelLayer

Review comment:
Why is this imported? I don't see any use for it.

from typing_extensions import Buffer

logger = logging.getLogger(__name__)

@@ -32,23 +39,27 @@ class ChannelLock:
"""

def __init__(self):
self.locks = collections.defaultdict(asyncio.Lock)
self.wait_counts = collections.defaultdict(int)
self.locks: "collections.defaultdict[str, asyncio.Lock]" = (
collections.defaultdict(asyncio.Lock)
)
self.wait_counts: "collections.defaultdict[str, int]" = collections.defaultdict(
int
)

async def acquire(self, channel):
async def acquire(self, channel: str) -> bool:
"""
Acquire the lock for the given channel.
"""
self.wait_counts[channel] += 1
return await self.locks[channel].acquire()

def locked(self, channel):
def locked(self, channel: str) -> bool:
"""
Return ``True`` if the lock for the given channel is acquired.
"""
return self.locks[channel].locked()

def release(self, channel):
def release(self, channel: str):
"""
Release the lock for the given channel.
"""
Expand All @@ -73,12 +84,12 @@ def put_nowait(self, item):


class RedisLoopLayer:
def __init__(self, channel_layer):
def __init__(self, channel_layer: "RedisChannelLayer"):
self._lock = asyncio.Lock()
self.channel_layer = channel_layer
self._connections = {}
self._connections: "Dict[int, Redis]" = {}

Review comment:
I would just import typing.Dict and go with that instead of strings. Same for the other places where this happens.

Follow-up review comment:
I don't mean to do self._connections: typing.Dict[int, "Redis"], but rather:

from typing import Dict
self._connections: Dict[int, "Redis"]
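
To make the suggestion concrete, a hypothetical side-by-side of the two annotation styles (the Example class and attribute names are invented):

from typing import TYPE_CHECKING, Dict

if TYPE_CHECKING:
    from redis.asyncio.client import Redis

class Example:
    def __init__(self) -> None:
        # Whole annotation quoted: valid, but it hides Dict inside a string.
        self.by_string: "Dict[int, Redis]" = {}
        # Suggested style: a real Dict, quoting only the forward reference.
        self.by_import: Dict[int, "Redis"] = {}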


def get_connection(self, index):
def get_connection(self, index: int) -> "Redis":
if index not in self._connections:
pool = self.channel_layer.create_pool(index)
self._connections[index] = aioredis.Redis(connection_pool=pool)
@@ -134,7 +145,7 @@ def __init__(
symmetric_encryption_keys=symmetric_encryption_keys,
)
# Cached redis connection pools and the event loop they are from
self._layers = {}
self._layers: "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
# Normal channels choose a host index by cycling through the available hosts
self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -143,33 +154,33 @@
# Number of coroutines trying to receive right now
self.receive_count = 0
# The receive lock
self.receive_lock = None
self.receive_lock: "Optional[asyncio.Lock]" = None

Review comment:
typing.Optional would be better. Same for the other places where this happens.

# Event loop they are trying to receive on
self.receive_event_loop = None
self.receive_event_loop: "Optional[asyncio.AbstractEventLoop]" = None
# Buffered messages by process-local channel name
self.receive_buffer = collections.defaultdict(
functools.partial(BoundedQueue, self.capacity)
self.receive_buffer: "collections.defaultdict[str, BoundedQueue]" = (
collections.defaultdict(functools.partial(BoundedQueue, self.capacity))
)
# Detached channel cleanup tasks
self.receive_cleaners = []
self.receive_cleaners: "List[asyncio.Task]" = []

Review comment:
typing.List

# Per-channel cleanup locks to prevent a receive starting and moving
# a message back into the main queue before its cleanup has completed
self.receive_clean_locks = ChannelLock()

def create_pool(self, index):
def create_pool(self, index: int) -> "ConnectionPool":
return create_pool(self.hosts[index])

### Channel layer API ###

extensions = ["groups", "flush"]

async def send(self, channel, message):
async def send(self, channel: str, message):
"""
Send a message onto a (general or specific) channel.
"""
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
assert self.require_valid_channel_name(channel), "Channel name not valid"
# Make sure the message does not contain reserved keys
assert "__asgi_channel__" not in message
# If it's a process-local channel, strip off local part and stick full name in message
@@ -203,13 +214,15 @@ async def send(self, channel, message):
await connection.zadd(channel_key, {self.serialize(message): time.time()})
await connection.expire(channel_key, int(self.expiry))

def _backup_channel_name(self, channel):
def _backup_channel_name(self, channel: str) -> str:
"""
Construct the key used as a backup queue for the given channel.
"""
return channel + "$inflight"

async def _brpop_with_clean(self, index, channel, timeout):
async def _brpop_with_clean(
self, index: int, channel: str, timeout: "Union[int, float, bytes, str]"
):
"""
Perform a Redis BRPOP and manage the backup processing queue.
In case of cancellation, make sure the message is not lost.
@@ -240,23 +253,23 @@ async def _brpop_with_clean(self, index, channel, timeout):

return member

async def _clean_receive_backup(self, index, channel):
async def _clean_receive_backup(self, index: int, channel: str):
"""
Pop the oldest message off the channel backup queue.
The result isn't interesting as it was already processed.
"""
connection = self.connection(index)
await connection.zpopmin(self._backup_channel_name(channel))

async def receive(self, channel):
async def receive(self, channel: str):
"""
Receive the first message that arrives on the channel.
If more than one coroutine waits on the same channel, the first waiter
will be given the message when it arrives.
"""
# Make sure the channel name is valid then get the non-local part
# and thus its index
assert self.valid_channel_name(channel)
assert self.require_valid_channel_name(channel)
if "!" in channel:
real_channel = self.non_local_name(channel)
assert real_channel.endswith(
@@ -372,12 +385,14 @@ async def receive(self, channel):
# Do a plain direct receive
return (await self.receive_single(channel))[1]

async def receive_single(self, channel):
async def receive_single(self, channel: str) -> "Tuple":
"""
Receives a single message off of the channel and returns it.
"""
# Check channel name
assert self.valid_channel_name(channel, receive=True), "Channel name invalid"
assert self.require_valid_channel_name(
channel, receive=True
), "Channel name invalid"
# Work out the connection to use
if "!" in channel:
assert channel.endswith("!")
@@ -408,7 +423,7 @@ def receive_single(self, channel):
)
self.receive_cleaners.append(cleaner)

def _cleanup_done(cleaner):
def _cleanup_done(cleaner: "asyncio.Task"):
self.receive_cleaners.remove(cleaner)
self.receive_clean_locks.release(channel_key)

@@ -427,7 +442,7 @@ def _cleanup_done(cleaner):
del message["__asgi_channel__"]
return channel, message

async def new_channel(self, prefix="specific"):
async def new_channel(self, prefix: str = "specific") -> str:
"""
Returns a new channel name that can be used by something in our
process as a specific channel.
@@ -477,13 +492,13 @@ async def wait_received(self):

### Groups extension ###

async def group_add(self, group, channel):
async def group_add(self, group: str, channel: str):
"""
Adds the channel name to a group.
"""
# Check the inputs
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
assert self.require_valid_group_name(group), "Group name not valid"
assert self.require_valid_channel_name(channel), "Channel name not valid"

Review comment:
I think this name change should be in another commit. Same for the other places where this happens.

# Get a connection to the right shard
group_key = self._group_key(group)
connection = self.connection(self.consistent_hash(group))
@@ -493,22 +508,22 @@ async def group_add(self, group, channel):
# it at this point is guaranteed to expire before that
await connection.expire(group_key, self.group_expiry)

async def group_discard(self, group, channel):
async def group_discard(self, group: str, channel: str):
"""
Removes the channel from the named group if it is in the group;
does nothing otherwise (does not error)
"""
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
assert self.require_valid_group_name(group), "Group name not valid"
assert self.require_valid_channel_name(channel), "Channel name not valid"
key = self._group_key(group)
connection = self.connection(self.consistent_hash(group))
await connection.zrem(key, channel)

async def group_send(self, group, message):
async def group_send(self, group: str, message):
"""
Sends a message to the entire group.
"""
assert self.valid_group_name(group), "Group name not valid"
assert self.require_valid_group_name(group), "Group name not valid"
# Retrieve list of all channel names
key = self._group_key(group)
connection = self.connection(self.consistent_hash(group))
@@ -573,7 +588,12 @@ async def group_send(self, group, message):
channels_over_capacity = await connection.eval(
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
)
if channels_over_capacity > 0:
_channels_over_capacity = -1
try:
_channels_over_capacity = float(channels_over_capacity)
except Exception:
pass
if _channels_over_capacity > 0:

Review comment:
I'm not sure what's happening here, but I think it should be in another commit.

Reply from @iwakitakuma33 (author):
connection.eval is inferred to return Awaitable[str] | str. By converting it to float, it can be compared with 0.

logger.info(
"%s of %s channels over capacity in group %s",
channels_over_capacity,
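
A rough sketch of the coercion discussed in the thread above (the helper name is invented; it assumes the Lua script returns a numeric string or number):

from typing import Union

def over_capacity_count(raw: Union[str, bytes, int, float]) -> float:
    # Coerce redis-py's loosely typed eval() result so it can be compared
    # with 0; -1.0 mirrors the "could not parse" sentinel in the diff above.
    try:
        return float(raw)
    except (TypeError, ValueError):
        return -1.0

assert over_capacity_count("3") == 3.0
assert over_capacity_count(b"not a number") == -1.0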
@@ -631,37 +651,35 @@ def _map_channel_keys_to_connection(self, channel_names, message):
channel_key_to_capacity,
)

def _group_key(self, group):
def _group_key(self, group: str) -> bytes:
"""
Common function to make the storage key for the group.
"""
return f"{self.prefix}:group:{group}".encode("utf8")

### Serialization ###

def serialize(self, message):
def serialize(self, message) -> bytes:
"""
Serializes message to a byte string.
"""
return self._serializer.serialize(message)

def deserialize(self, message):
def deserialize(self, message: bytes):
"""
Deserializes from a byte string.
"""
return self._serializer.deserialize(message)

### Internal functions ###

def consistent_hash(self, value):
def consistent_hash(self, value: "Union[str, Buffer]") -> int:
return _consistent_hash(value, self.ring_size)

def __str__(self):
return f"{self.__class__.__name__}(hosts={self.hosts})"

### Connection handling ###

def connection(self, index):
def connection(self, index: int) -> "Redis":
"""
Returns the correct connection for the index given.
Lazily instantiates pools.