Skip to content

Commit 4c88088

Browse files
astutejoecarltongibson
authored andcommitted
Used a sorted set to guarantee message expiration.
1 parent 2ae5ad3 commit 4c88088

File tree

2 files changed

+205
-13
lines changed

2 files changed

+205
-13
lines changed

channels_redis/core.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import hashlib
66
import itertools
77
import logging
8+
import random
89
import sys
910
import time
1011
import types
@@ -300,12 +301,18 @@ async def send(self, channel, message):
300301
else:
301302
index = next(self._send_index_generator)
302303
async with self.connection(index) as connection:
304+
# Discard old messages based on expiry
305+
await connection.zremrangebyscore(
306+
channel_key, min=0, max=int(time.time()) - int(self.expiry)
307+
)
308+
303309
# Check the length of the list before send
304310
# This can allow the list to leak slightly over capacity, but that's fine.
305-
if await connection.llen(channel_key) >= self.get_capacity(channel):
311+
if await connection.zcount(channel_key) >= self.get_capacity(channel):
306312
raise ChannelFull()
313+
307314
# Push onto the list then set it to expire in case it's not consumed
308-
await connection.lpush(channel_key, self.serialize(message))
315+
await connection.zadd(channel_key, time.time(), self.serialize(message))
309316
await connection.expire(channel_key, int(self.expiry))
310317

311318
def _backup_channel_name(self, channel):
@@ -323,9 +330,9 @@ async def _brpop_with_clean(self, index, channel, timeout):
323330
# of the main message queue in the proper order; BRPOP must *not* be called
324331
# because that would deadlock the server
325332
cleanup_script = """
326-
local backed_up = redis.call('LRANGE', ARGV[2], 0, -1)
333+
local backed_up = redis.call('ZRANGE', ARGV[2], 0, -1)
327334
for i = #backed_up, 1, -1 do
328-
redis.call('LPUSH', ARGV[1], backed_up[i])
335+
redis.call('ZADD', ARGV[1], backed_up[i])
329336
end
330337
redis.call('DEL', ARGV[2])
331338
"""
@@ -335,15 +342,23 @@ async def _brpop_with_clean(self, index, channel, timeout):
335342
# and the script executes atomically...
336343
await connection.eval(cleanup_script, keys=[], args=[channel, backup_queue])
337344
# ...and it doesn't matter here either, the message will be safe in the backup.
338-
return await connection.brpoplpush(channel, backup_queue, timeout=timeout)
345+
result = await connection.bzpopmin(channel, timeout=timeout)
346+
347+
if result is not None:
348+
_, member, timestamp = result
349+
await connection.zadd(backup_queue, float(timestamp), member)
350+
else:
351+
member = None
352+
353+
return member
339354

340355
async def _clean_receive_backup(self, index, channel):
341356
"""
342357
Pop the oldest message off the channel backup queue.
343358
The result isn't interesting as it was already processed.
344359
"""
345360
async with self.connection(index) as connection:
346-
await connection.brpop(self._backup_channel_name(channel))
361+
await connection.zpopmin(self._backup_channel_name(channel))
347362

348363
async def receive(self, channel):
349364
"""
@@ -626,25 +641,30 @@ async def group_send(self, group, message):
626641
) = self._map_channel_keys_to_connection(channel_names, message)
627642

628643
for connection_index, channel_redis_keys in connection_to_channel_keys.items():
644+
# Discard old messages based on expiry
645+
for key in channel_redis_keys:
646+
await connection.zremrangebyscore(
647+
key, min=0, max=int(time.time()) - int(self.expiry)
648+
)
629649

630650
# Create a LUA script specific for this connection.
631651
# Make sure to use the message specific to this channel, it is
632652
# stored in channel_to_message dict and contains the
633653
# __asgi_channel__ key.
634654

635-
group_send_lua = (
636-
""" local over_capacity = 0
655+
group_send_lua = """ local over_capacity = 0
637656
for i=1,#KEYS do
638-
if redis.call('LLEN', KEYS[i]) < tonumber(ARGV[i + #KEYS]) then
639-
redis.call('LPUSH', KEYS[i], ARGV[i])
657+
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
658+
redis.call('ZADD', KEYS[i], %f, ARGV[i])
640659
redis.call('EXPIRE', KEYS[i], %d)
641660
else
642661
over_capacity = over_capacity + 1
643662
end
644663
end
645664
return over_capacity
646-
"""
647-
% self.expiry
665+
""" % (
666+
time.time(),
667+
self.expiry,
648668
)
649669

650670
# We need to filter the messages to keep those related to the connection
@@ -769,12 +789,18 @@ def serialize(self, message):
769789
value = msgpack.packb(message, use_bin_type=True)
770790
if self.crypter:
771791
value = self.crypter.encrypt(value)
772-
return value
792+
793+
# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
794+
random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
795+
return random_prefix + value
773796

774797
def deserialize(self, message):
775798
"""
776799
Deserializes from a byte string.
777800
"""
801+
# Removes the random prefix
802+
message = message[12:]
803+
778804
if self.crypter:
779805
message = self.crypter.decrypt(message, self.expiry + 10)
780806
return msgpack.unpackb(message, raw=False)

tests/test_core.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,36 @@
2424
]
2525

2626

27+
async def send_three_messages_with_delay(channel_name, channel_layer, delay):
28+
await channel_layer.send(channel_name, {"type": "test.message", "text": "First!"})
29+
30+
await asyncio.sleep(delay)
31+
32+
await channel_layer.send(channel_name, {"type": "test.message", "text": "Second!"})
33+
34+
await asyncio.sleep(delay)
35+
36+
await channel_layer.send(channel_name, {"type": "test.message", "text": "Third!"})
37+
38+
39+
async def group_send_three_messages_with_delay(group_name, channel_layer, delay):
40+
await channel_layer.group_send(
41+
group_name, {"type": "test.message", "text": "First!"}
42+
)
43+
44+
await asyncio.sleep(delay)
45+
46+
await channel_layer.group_send(
47+
group_name, {"type": "test.message", "text": "Second!"}
48+
)
49+
50+
await asyncio.sleep(delay)
51+
52+
await channel_layer.group_send(
53+
group_name, {"type": "test.message", "text": "Third!"}
54+
)
55+
56+
2757
@pytest.fixture()
2858
@async_generator
2959
async def channel_layer():
@@ -445,3 +475,139 @@ async def test_random_reset__client_prefix(channel_layer):
445475
random.seed(1)
446476
channel_layer_2 = RedisChannelLayer()
447477
assert channel_layer_1.client_prefix != channel_layer_2.client_prefix
478+
479+
480+
@pytest.mark.asyncio
481+
async def test_message_expiry__earliest_message_expires(channel_layer):
482+
expiry = 3
483+
delay = 2
484+
channel_layer = RedisChannelLayer(expiry=expiry)
485+
channel_name = await channel_layer.new_channel()
486+
487+
task = asyncio.ensure_future(
488+
send_three_messages_with_delay(channel_name, channel_layer, delay)
489+
)
490+
await asyncio.wait_for(task, None)
491+
492+
# the first message should have expired, we should only see the second message and the third
493+
message = await channel_layer.receive(channel_name)
494+
assert message["type"] == "test.message"
495+
assert message["text"] == "Second!"
496+
497+
message = await channel_layer.receive(channel_name)
498+
assert message["type"] == "test.message"
499+
assert message["text"] == "Third!"
500+
501+
# Make sure there's no third message even out of order
502+
with pytest.raises(asyncio.TimeoutError):
503+
async with async_timeout.timeout(1):
504+
await channel_layer.receive(channel_name)
505+
506+
507+
@pytest.mark.asyncio
508+
async def test_message_expiry__all_messages_under_expiration_time(channel_layer):
509+
expiry = 3
510+
delay = 1
511+
channel_layer = RedisChannelLayer(expiry=expiry)
512+
channel_name = await channel_layer.new_channel()
513+
514+
task = asyncio.ensure_future(
515+
send_three_messages_with_delay(channel_name, channel_layer, delay)
516+
)
517+
await asyncio.wait_for(task, None)
518+
519+
# expiry = 3, total delay under 3, all messages there
520+
message = await channel_layer.receive(channel_name)
521+
assert message["type"] == "test.message"
522+
assert message["text"] == "First!"
523+
524+
message = await channel_layer.receive(channel_name)
525+
assert message["type"] == "test.message"
526+
assert message["text"] == "Second!"
527+
528+
message = await channel_layer.receive(channel_name)
529+
assert message["type"] == "test.message"
530+
assert message["text"] == "Third!"
531+
532+
533+
@pytest.mark.asyncio
534+
async def test_message_expiry__group_send(channel_layer):
535+
expiry = 3
536+
delay = 2
537+
channel_layer = RedisChannelLayer(expiry=expiry)
538+
channel_name = await channel_layer.new_channel()
539+
540+
await channel_layer.group_add("test-group", channel_name)
541+
542+
task = asyncio.ensure_future(
543+
group_send_three_messages_with_delay("test-group", channel_layer, delay)
544+
)
545+
await asyncio.wait_for(task, None)
546+
547+
# the first message should have expired, we should only see the second message and the third
548+
message = await channel_layer.receive(channel_name)
549+
assert message["type"] == "test.message"
550+
assert message["text"] == "Second!"
551+
552+
message = await channel_layer.receive(channel_name)
553+
assert message["type"] == "test.message"
554+
assert message["text"] == "Third!"
555+
556+
# Make sure there's no third message even out of order
557+
with pytest.raises(asyncio.TimeoutError):
558+
async with async_timeout.timeout(1):
559+
await channel_layer.receive(channel_name)
560+
561+
562+
@pytest.mark.asyncio
563+
async def test_message_expiry__group_send__one_channel_expires_message(channel_layer):
564+
expiry = 3
565+
delay = 1
566+
567+
channel_layer = RedisChannelLayer(expiry=expiry)
568+
channel_1 = await channel_layer.new_channel()
569+
channel_2 = await channel_layer.new_channel(prefix="channel_2")
570+
571+
await channel_layer.group_add("test-group", channel_1)
572+
await channel_layer.group_add("test-group", channel_2)
573+
574+
# Let's give channel_1 one additional message and then sleep
575+
await channel_layer.send(channel_1, {"type": "test.message", "text": "Zero!"})
576+
await asyncio.sleep(2)
577+
578+
task = asyncio.ensure_future(
579+
group_send_three_messages_with_delay("test-group", channel_layer, delay)
580+
)
581+
await asyncio.wait_for(task, None)
582+
583+
# message Zero! was sent about 2 + 1 + 1 seconds ago and it should have expired
584+
message = await channel_layer.receive(channel_1)
585+
assert message["type"] == "test.message"
586+
assert message["text"] == "First!"
587+
588+
message = await channel_layer.receive(channel_1)
589+
assert message["type"] == "test.message"
590+
assert message["text"] == "Second!"
591+
592+
message = await channel_layer.receive(channel_1)
593+
assert message["type"] == "test.message"
594+
assert message["text"] == "Third!"
595+
596+
# Make sure there's no fourth message even out of order
597+
with pytest.raises(asyncio.TimeoutError):
598+
async with async_timeout.timeout(1):
599+
await channel_layer.receive(channel_1)
600+
601+
# channel_2 should receive all three messages from group_send
602+
message = await channel_layer.receive(channel_2)
603+
assert message["type"] == "test.message"
604+
assert message["text"] == "First!"
605+
606+
# the first message should have expired, we should only see the second message and the third
607+
message = await channel_layer.receive(channel_2)
608+
assert message["type"] == "test.message"
609+
assert message["text"] == "Second!"
610+
611+
message = await channel_layer.receive(channel_2)
612+
assert message["type"] == "test.message"
613+
assert message["text"] == "Third!"

0 commit comments

Comments
 (0)