diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9cfe431..5adcb50 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,6 +42,7 @@ jobs: env: VALKEY_SENTINEL_QUORUM: "1" VALKEY_SENTINEL_PASSWORD: channels_valkey + VALKEY_PRIMARY_SET: "mymaster" steps: - uses: actions/checkout@v4 diff --git a/channels_valkey/core.py b/channels_valkey/core.py index cacf8c5..b3c2925 100644 --- a/channels_valkey/core.py +++ b/channels_valkey/core.py @@ -166,22 +166,33 @@ async def send(self, channel, message): """ Send a message onto a (general or specific) channel. """ + await self.send_bulk(channel, (message,)) + + async def send_bulk(self, channel, messages): # Typecheck - assert isinstance(message, dict), "message is not a dict" assert self.valid_channel_name(channel), "Channel name not valid" - # Make sure the message does not contain reserved keys - assert "__asgi_channel__" not in message # If it's a process-local channel, strip off local part and stick full name in message channel_non_local_name = channel - if "!" in channel: - message = dict(message.items()) - message["__asgi_channel__"] = channel + process_local = "!" in channel + if process_local: channel_non_local_name = self.non_local_name(channel) + + now = time.time() + mapping = {} + for message in messages: + assert isinstance(message, dict), "message is not a dict" + assert "__asgi_channel__" not in message + if process_local: + message = dict(message.items()) + message["__asgi_channel__"] = channel + + mapping[self.serialize(message)] = now + # Write out message into expiring key (avoids big items in list) channel_key = self.prefix + channel_non_local_name # Pick a connection to the right server - consistent for specific # channels, random for general channels - if "!" 
in channel: + if process_local: index = self.consistent_hash(channel) else: index = next(self._send_index_generator) @@ -193,13 +204,12 @@ async def send(self, channel, message): # Check the length of the list before send # This can allow the list to leak slightly over capacity, but that's fine. - if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity( - channel - ): + current_length = await connection.zcount(channel_key, "-inf", "+inf") + if current_length + len(messages) > self.get_capacity(channel): raise ChannelFull() # Push onto the list then set it to expire in case it's not consumed - await connection.zadd(channel_key, {self.serialize(message): time.time()}) + await connection.zadd(channel_key, mapping) await connection.expire(channel_key, int(self.expiry)) def _backup_channel_name(self, channel): @@ -503,10 +513,7 @@ async def group_discard(self, group, channel): connection = self.connection(self.consistent_hash(group)) await connection.zrem(key, channel) - async def group_send(self, group, message): - """ - Sends a message to the entire group. - """ + async def _get_group_connection_and_channels(self, group): assert self.valid_group_name(group), "Group name not valid" # Retrieve list of all channel names key = self._group_key(group) @@ -518,11 +525,31 @@ async def group_send(self, group, message): channel_names = [x.decode("utf8") for x in await connection.zrange(key, 0, -1)] + return connection, channel_names + + async def _exec_group_lua_scripts( + self, conn_idx, group, channel_valkey_keys, channel_names, script, args + ): + connection = self.connection(conn_idx) + channels_over_capacity = await connection.eval( + script, len(channel_valkey_keys), *channel_valkey_keys, *args + ) + if channels_over_capacity > 0: + logger.info( + f"{channels_over_capacity} of {len(channel_names)} channels over capacity in group {group}" + ) + + async def group_send(self, group, message): + """ + Sends messages to the entire group. 
+ """ + connection, channel_names = await self._get_group_connection_and_channels(group) + ( connection_to_channel_keys, channel_keys_to_message, channel_keys_to_capacity, - ) = self._map_channel_keys_to_connection(channel_names, message) + ) = self._map_channel_keys_to_connection(channel_names, (message,)) for connection_index, channel_valkey_keys in connection_to_channel_keys.items(): # Discard old messages based on expiry @@ -555,7 +582,7 @@ async def group_send(self, group, message): # We need to filter the messages to keep those related to the connection args = [ - channel_keys_to_message[channel_key] + channel_keys_to_message[channel_key][0] for channel_key in channel_valkey_keys ] @@ -567,20 +594,84 @@ async def group_send(self, group, message): args += [time.time(), self.expiry] - # channel_keys does not contain a single valkey key more than once - connection = self.connection(connection_index) - channels_over_capacity = await connection.eval( - group_send_lua, len(channel_valkey_keys), *channel_valkey_keys, *args + await self._exec_group_lua_scripts( + connection_index, + group, + channel_valkey_keys, + channel_names, + group_send_lua, + args, ) - if channels_over_capacity > 0: - logger.info( - "%s of %s channels over capacity in group %s", - channels_over_capacity, - len(channel_names), - group, + + async def group_send_bulk(self, group, messages): + """ + Send multiple messages in bulk to the entire group. + The `messages` argument should be an iterable of dicts. 
+ """ + connection, channel_names = await self._get_group_connection_and_channels(group) + + ( + connection_to_channel_keys, + channel_keys_to_message, + channel_keys_to_capacity, + ) = self._map_channel_keys_to_connection(channel_names, messages) + + for connection_index, channel_valkey_keys in connection_to_channel_keys.items(): + pipe = connection.pipeline() + for key in channel_valkey_keys: + pipe.zremrangebyscore( + key, min=0, max=int(time.time()) - int(self.expiry) ) + await pipe.execute() + + # Create a LUA script specific for this connection. + # Make sure to use the message list specific to this channel, it is + # stored in channel_to_message dict and each message contains the + # __asgi_channel__ key. + + group_send_lua = """ + local over_capacity = 0 + local num_messages = tonumber(ARGV[#ARGV - 2]) + local current_time = ARGV[#ARGV - 1] + local expiry = ARGV[#ARGV] + for i=1,#KEYS do + if server.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then + local messages = {} + for j=num_messages * (i - 1) + 1, num_messages * i do + table.insert(messages, current_time) + table.insert(messages, ARGV[j]) + end + server.call('ZADD', KEYS[i], unpack(messages)) + server.call('EXPIRE', KEYS[i], expiry) + else + over_capacity = over_capacity + 1 + end + end + return over_capacity + """ - def _map_channel_keys_to_connection(self, channel_names, message): + args = [] + + for channel_key in channel_valkey_keys: + args += channel_keys_to_message[channel_key] + + args += [ + channel_keys_to_capacity[channel_key] + for channel_key in channel_valkey_keys + ] + + args += [len(messages), time.time(), self.expiry] + + await self._exec_group_lua_scripts( + connection_index, + group, + channel_valkey_keys, + channel_names, + group_send_lua, + args, + ) + + def _map_channel_keys_to_connection(self, channel_names, messages): """ For a list of channel names, GET @@ -595,7 +686,7 @@ def _map_channel_keys_to_connection(self, channel_names, message): 
# Connection dict keyed by index to list of valkey keys mapped on that index connection_to_channel_keys = collections.defaultdict(list) # Message dict maps valkey key to the message that needs to be sent on that key - channel_key_to_message = dict() + channel_key_to_message = collections.defaultdict(list) # Channel key mapped to its capacity channel_key_to_capacity = dict() @@ -609,20 +700,23 @@ def _map_channel_keys_to_connection(self, channel_names, message): # Have we come across the same valkey key? if channel_key not in channel_key_to_message: # If not, fill the corresponding dicts - message = dict(message.items()) - message["__asgi_channel__"] = [channel] - channel_key_to_message[channel_key] = message + for message in messages: + message = dict(message.items()) + message["__asgi_channel__"] = [channel] + channel_key_to_message[channel_key].append(message) channel_key_to_capacity[channel_key] = self.get_capacity(channel) idx = self.consistent_hash(channel_non_local_name) connection_to_channel_keys[idx].append(channel_key) else: # Yes, Append the channel in message dict - channel_key_to_message[channel_key]["__asgi_channel__"].append(channel) + for message in channel_key_to_message[channel_key]: + message["__asgi_channel__"].append(channel) # Now that we know what message needs to be sent on a valkey key we serialize it for key, value in channel_key_to_message.items(): # Serialize the message stored for each valkey key - channel_key_to_message[key] = self.serialize(value) + for idx, message in enumerate(value): + channel_key_to_message[key][idx] = self.serialize(message) return ( connection_to_channel_keys, diff --git a/tests/test_core.py b/tests/test_core.py index f79e83a..112eaf7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,4 +1,5 @@ import asyncio +import collections import random import async_timeout @@ -125,6 +126,25 @@ async def listen2(): async_to_sync(channel_layer.flush)() +@pytest.mark.asyncio +async def 
test_send_multiple(channel_layer): + messages = [ + {"type": "test.message.1"}, + {"type": "test.message.2"}, + {"type": "test.message.3"}, + ] + + await channel_layer.send_bulk("test-channel-1", messages) + + expected = {"test.message.1", "test.message.2", "test.message.3"} + received = set() + for _ in range(3): + msg = await channel_layer.receive("test-channel-1") + received.add(msg["type"]) + + assert received == expected + + @pytest.mark.asyncio async def test_send_capacity(channel_layer): """ @@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer): await channel_layer.flush() +@pytest.mark.asyncio +async def test_group_multiple(channel_layer): + """ + Tests basic group operation + """ + channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1") + channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2") + channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3") + await channel_layer.group_add("test-group", channel_name1) + await channel_layer.group_add("test-group", channel_name2) + await channel_layer.group_add("test-group", channel_name3) + + messages = [ + {"type": "message.1"}, + {"type": "message.2"}, + {"type": "message.3"}, + ] + + expected = {msg["type"] for msg in messages} + + await channel_layer.group_send_bulk("test-group", messages) + + received = collections.defaultdict(set) + + for channel_name in (channel_name1, channel_name2, channel_name3): + async with async_timeout.timeout(1): + for _ in range(len(messages)): + received[channel_name].add( + (await channel_layer.receive(channel_name))["type"] + ) + + assert received[channel_name] == expected + + @pytest.mark.asyncio async def test_groups_channel_full(channel_layer): """ diff --git a/tox.ini b/tox.ini index 59714dc..5510dd6 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,7 @@ extras = cryptography runner = uv-venv-lock-runner commands = - uv run pytest -v {posargs} + uv run pytest -xsvv {posargs} deps = ch30: channels>=3.0,<3.1