1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -42,6 +42,7 @@ jobs:
env:
VALKEY_SENTINEL_QUORUM: "1"
VALKEY_SENTINEL_PASSWORD: channels_valkey
VALKEY_PRIMARY_SET: "mymaster"

steps:
- uses: actions/checkout@v4
162 changes: 128 additions & 34 deletions channels_valkey/core.py
@@ -166,22 +166,33 @@ async def send(self, channel, message):
"""
Send a message onto a (general or specific) channel.
"""
await self.send_bulk(channel, (message,))

async def send_bulk(self, channel, messages):
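"""
Send multiple messages onto a (general or specific) channel.
`messages` should be a sequence of message dicts.
"""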
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
# Make sure the message does not contain reserved keys
assert "__asgi_channel__" not in message
# If it's a process-local channel, strip off local part and stick full name in message
channel_non_local_name = channel
if "!" in channel:
message = dict(message.items())
message["__asgi_channel__"] = channel
process_local = "!" in channel
if process_local:
channel_non_local_name = self.non_local_name(channel)

now = time.time()
mapping = {}
for message in messages:
assert isinstance(message, dict), "message is not a dict"
assert "__asgi_channel__" not in message
if process_local:
message = dict(message.items())
message["__asgi_channel__"] = channel

mapping[self.serialize(message)] = now

# Write out message into expiring key (avoids big items in list)
channel_key = self.prefix + channel_non_local_name
# Pick a connection to the right server - consistent for specific
# channels, random for general channels
if "!" in channel:
if process_local:
index = self.consistent_hash(channel)
else:
index = next(self._send_index_generator)
@@ -193,13 +204,12 @@ async def send(self, channel, message):

# Check the length of the list before send
# This can allow the list to leak slightly over capacity, but that's fine.
if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
channel
):
current_length = await connection.zcount(channel_key, "-inf", "+inf")
if current_length + len(messages) > self.get_capacity(channel):
raise ChannelFull()

# Push onto the list then set it to expire in case it's not consumed
await connection.zadd(channel_key, {self.serialize(message): time.time()})
await connection.zadd(channel_key, mapping)
await connection.expire(channel_key, int(self.expiry))

def _backup_channel_name(self, channel):
@@ -503,10 +513,7 @@ async def group_discard(self, group, channel):
connection = self.connection(self.consistent_hash(group))
await connection.zrem(key, channel)

async def group_send(self, group, message):
"""
Sends a message to the entire group.
"""
async def _get_group_connection_and_channels(self, group):
assert self.valid_group_name(group), "Group name not valid"
# Retrieve list of all channel names
key = self._group_key(group)
@@ -518,11 +525,31 @@ async def group_send(self, group, message):

channel_names = [x.decode("utf8") for x in await connection.zrange(key, 0, -1)]

return connection, channel_names

async def _exec_group_lua_scripts(
self, conn_idx, group, channel_valkey_keys, channel_names, script, args
):
connection = self.connection(conn_idx)
channels_over_capacity = await connection.eval(
script, len(channel_valkey_keys), *channel_valkey_keys, *args
)
if channels_over_capacity > 0:
logger.info(
f"{channels_over_capacity} of {len(channel_names)} channels over capacity in group {group}"
)

async def group_send(self, group, message):
"""
Sends a message to the entire group.
"""
connection, channel_names = await self._get_group_connection_and_channels(group)

(
connection_to_channel_keys,
channel_keys_to_message,
channel_keys_to_capacity,
) = self._map_channel_keys_to_connection(channel_names, message)
) = self._map_channel_keys_to_connection(channel_names, (message,))

for connection_index, channel_valkey_keys in connection_to_channel_keys.items():
# Discard old messages based on expiry
@@ -555,7 +582,7 @@ async def group_send(self, group, message):

# We need to filter the messages to keep those related to the connection
args = [
channel_keys_to_message[channel_key]
channel_keys_to_message[channel_key][0]
for channel_key in channel_valkey_keys
]

@@ -567,20 +594,84 @@ async def group_send(self, group, message):

args += [time.time(), self.expiry]

# channel_keys does not contain a single valkey key more than once
connection = self.connection(connection_index)
channels_over_capacity = await connection.eval(
group_send_lua, len(channel_valkey_keys), *channel_valkey_keys, *args
await self._exec_group_lua_scripts(
connection_index,
group,
channel_valkey_keys,
channel_names,
group_send_lua,
args,
)
if channels_over_capacity > 0:
logger.info(
"%s of %s channels over capacity in group %s",
channels_over_capacity,
len(channel_names),
group,

async def group_send_bulk(self, group, messages):
"""
Send multiple messages in bulk to the entire group.
The `messages` argument should be a sequence of dicts (it is iterated
more than once internally, so pass a list or tuple rather than a generator).
"""
connection, channel_names = await self._get_group_connection_and_channels(group)

(
connection_to_channel_keys,
channel_keys_to_message,
channel_keys_to_capacity,
) = self._map_channel_keys_to_connection(channel_names, messages)

for connection_index, channel_valkey_keys in connection_to_channel_keys.items():
pipe = connection.pipeline()
for key in channel_valkey_keys:
pipe.zremrangebyscore(
key, min=0, max=int(time.time()) - int(self.expiry)
)
await pipe.execute()

# Create a Lua script specific to this connection.
# Make sure to use the message list specific to each channel; it is
# stored in the channel_keys_to_message dict and each message contains
# the __asgi_channel__ key.
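# ARGV layout expected by the script below:
#   ARGV[1 .. #KEYS * num_messages]: serialized messages, num_messages per key,
#       grouped by key in the same order as KEYS
#   ARGV[#KEYS * num_messages + 1 .. #KEYS * num_messages + #KEYS]: per-key capacity
#   ARGV[#ARGV - 2], ARGV[#ARGV - 1], ARGV[#ARGV]: num_messages, current time, expiry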

group_send_lua = """
local over_capacity = 0
local num_messages = tonumber(ARGV[#ARGV - 2])
local current_time = ARGV[#ARGV - 1]
local expiry = ARGV[#ARGV]
for i=1,#KEYS do
if server.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
local messages = {}
for j=num_messages * (i - 1) + 1, num_messages * i do
table.insert(messages, current_time)
table.insert(messages, ARGV[j])
end
server.call('ZADD', KEYS[i], unpack(messages))
server.call('EXPIRE', KEYS[i], expiry)
else
over_capacity = over_capacity + 1
end
end
return over_capacity
"""

def _map_channel_keys_to_connection(self, channel_names, message):
args = []

for channel_key in channel_valkey_keys:
args += channel_keys_to_message[channel_key]

args += [
channel_keys_to_capacity[channel_key]
for channel_key in channel_valkey_keys
]

args += [len(messages), time.time(), self.expiry]

await self._exec_group_lua_scripts(
connection_index,
group,
channel_valkey_keys,
channel_names,
group_send_lua,
args,
)

def _map_channel_keys_to_connection(self, channel_names, messages):
"""
For a list of channel names, GET

@@ -595,7 +686,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Connection dict keyed by index to list of valkey keys mapped on that index
connection_to_channel_keys = collections.defaultdict(list)
# Message dict maps valkey key to the message that needs to be sent on that key
channel_key_to_message = dict()
channel_key_to_message = collections.defaultdict(list)
# Channel key mapped to its capacity
channel_key_to_capacity = dict()

@@ -609,20 +700,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Have we come across the same valkey key?
if channel_key not in channel_key_to_message:
# If not, fill the corresponding dicts
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key] = message
for message in messages:
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key].append(message)
channel_key_to_capacity[channel_key] = self.get_capacity(channel)
idx = self.consistent_hash(channel_non_local_name)
connection_to_channel_keys[idx].append(channel_key)
else:
# Yes: append the channel to each message's __asgi_channel__ list
channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
for message in channel_key_to_message[channel_key]:
message["__asgi_channel__"].append(channel)

# Now that we know which messages need to be sent on each valkey key, we serialize them
for key, value in channel_key_to_message.items():
# Serialize the message stored for each valkey key
channel_key_to_message[key] = self.serialize(value)
for idx, message in enumerate(value):
channel_key_to_message[key][idx] = self.serialize(message)

return (
connection_to_channel_keys,
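For context, a minimal usage sketch of the bulk API this diff adds (the surrounding setup is assumed: channel_layer stands for an already-configured channels_valkey channel layer instance, and the channel, group, and event names are illustrative):

    async def broadcast_batch(channel_layer):
        # Deliver several messages to one channel in a single call;
        # send_bulk does one capacity check and one ZADD for the whole batch.
        await channel_layer.send_bulk(
            "test-channel-1",
            [{"type": "event.1"}, {"type": "event.2"}],
        )

        # Fan the same batch out to every channel in a group; each message is
        # tagged with __asgi_channel__ and written by the bulk Lua script.
        await channel_layer.group_send_bulk(
            "test-group",
            [{"type": "event.1"}, {"type": "event.2"}],
        )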
54 changes: 54 additions & 0 deletions tests/test_core.py
@@ -1,4 +1,5 @@
import asyncio
import collections
import random

import async_timeout
@@ -125,6 +126,25 @@ async def listen2():
async_to_sync(channel_layer.flush)()


@pytest.mark.asyncio
async def test_send_multiple(channel_layer):
messages = [
{"type": "test.message.1"},
{"type": "test.message.2"},
{"type": "test.message.3"},
]

await channel_layer.send_bulk("test-channel-1", messages)

expected = {"test.message.1", "test.message.2", "test.message.3"}
received = set()
for _ in range(3):
msg = await channel_layer.receive("test-channel-1")
received.add(msg["type"])

assert received == expected


@pytest.mark.asyncio
async def test_send_capacity(channel_layer):
"""
@@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer):
await channel_layer.flush()


@pytest.mark.asyncio
async def test_group_multiple(channel_layer):
"""
Tests sending multiple messages to a group in bulk
"""
channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
await channel_layer.group_add("test-group", channel_name1)
await channel_layer.group_add("test-group", channel_name2)
await channel_layer.group_add("test-group", channel_name3)

messages = [
{"type": "message.1"},
{"type": "message.2"},
{"type": "message.3"},
]

expected = {msg["type"] for msg in messages}

await channel_layer.group_send_bulk("test-group", messages)

received = collections.defaultdict(set)

for channel_name in (channel_name1, channel_name2, channel_name3):
async with async_timeout.timeout(1):
for _ in range(len(messages)):
received[channel_name].add(
(await channel_layer.receive(channel_name))["type"]
)

assert received[channel_name] == expected


@pytest.mark.asyncio
async def test_groups_channel_full(channel_layer):
"""
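The tests above drive the layer directly; from synchronous Django code the same bulk group send could be invoked through the usual Channels helpers. A minimal sketch, assuming a standard Channels setup where the configured layer exposes the group_send_bulk method added in this diff (the group name and payloads are illustrative):

    from asgiref.sync import async_to_sync
    from channels.layers import get_channel_layer

    def notify_group(events):
        # events: a list of message dicts, each with a "type" key the consumer routes on
        channel_layer = get_channel_layer()
        # One round trip per shard instead of one group_send call per message
        async_to_sync(channel_layer.group_send_bulk)("test-group", events)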
2 changes: 1 addition & 1 deletion tox.ini
@@ -10,7 +10,7 @@ extras =
cryptography
runner = uv-venv-lock-runner
commands =
uv run pytest -v {posargs}
uv run pytest -xsvv {posargs}

deps =
ch30: channels>=3.0,<3.1