Skip to content

Commit 4cb9b90

Browse files
committed
Add send_multiple method to core layer
1 parent 13cef45 commit 4cb9b90

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

channels_redis/core.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -180,22 +180,39 @@ async def send(self, channel, message):
180180
"""
181181
Send a message onto a (general or specific) channel.
182182
"""
183+
await self.send_multiple(channel, (message,))
184+
185+
async def send_multiple(self, channel, messages):
186+
"""
187+
Send multiple messages at once onto a (general or specific) channel.
188+
"""
183189
# Typecheck
184-
assert isinstance(message, dict), "message is not a dict"
185190
assert self.valid_channel_name(channel), "Channel name not valid"
186-
# Make sure the message does not contain reserved keys
187-
assert "__asgi_channel__" not in message
191+
assert hasattr(messages, "__iter__"), "messages is not an iterable"
192+
188193
# If it's a process-local channel, strip off local part and stick full name in message
189194
channel_non_local_name = channel
190-
if "!" in channel:
191-
message = dict(message.items())
192-
message["__asgi_channel__"] = channel
195+
process_local = "!" in channel
196+
if process_local:
193197
channel_non_local_name = self.non_local_name(channel)
198+
199+
now = time.time()
200+
mapping = {}
201+
for message in messages:
202+
assert isinstance(message, dict), "message is not a dict"
203+
# Make sure the message does not contain reserved keys
204+
assert "__asgi_channel__" not in message
205+
if process_local:
206+
message = dict(message.items())
207+
message["__asgi_channel__"] = channel
208+
209+
mapping[self.serialize(message)] = now
210+
194211
# Write out message into expiring key (avoids big items in list)
195212
channel_key = self.prefix + channel_non_local_name
196213
# Pick a connection to the right server - consistent for specific
197214
# channels, random for general channels
198-
if "!" in channel:
215+
if process_local:
199216
index = self.consistent_hash(channel)
200217
else:
201218
index = next(self._send_index_generator)
@@ -207,13 +224,13 @@ async def send(self, channel, message):
207224

208225
# Check the length of the list before send
209226
# This can allow the list to leak slightly over capacity, but that's fine.
210-
if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
211-
channel
212-
):
227+
current_length = await connection.zcount(channel_key, "-inf", "+inf")
228+
229+
if current_length + len(messages) > self.get_capacity(channel):
213230
raise ChannelFull()
214231

215232
# Push onto the list then set it to expire in case it's not consumed
216-
await connection.zadd(channel_key, {self.serialize(message): time.time()})
233+
await connection.zadd(channel_key, mapping)
217234
await connection.expire(channel_key, int(self.expiry))
218235

219236
def _backup_channel_name(self, channel):

tests/test_core.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,25 @@ async def listen2():
125125
async_to_sync(channel_layer.flush)()
126126

127127

128+
@pytest.mark.asyncio
129+
async def test_send_multiple(channel_layer):
130+
messsages = [
131+
{"type": "test.message.1"},
132+
{"type": "test.message.2"},
133+
{"type": "test.message.3"},
134+
]
135+
136+
await channel_layer.send_multiple("test-channel-1", messsages)
137+
138+
expected = {"test.message.1", "test.message.2", "test.message.3"}
139+
received = set()
140+
for _ in range(3):
141+
msg = await channel_layer.receive("test-channel-1")
142+
received.add(msg["type"])
143+
144+
assert received == expected
145+
146+
128147
@pytest.mark.asyncio
129148
async def test_send_capacity(channel_layer):
130149
"""

0 commit comments

Comments
 (0)