
Commit e2416af

jberci authored and andrewgodwin committed
Handle receive cancellation correctly
The one remaining loophole is Redis `BLPOP` cancellation, which was also the original problem when changing from per-request connections to connection pools. Once a `BLPOP` command is started, Python will switch to some other task while waiting for the server to respond. If that other task decides to cancel the `receive`, then one of two things will happen:

- the connection is dropped soon enough for the server to notice, so that it won't send the response at all, or
- the response already reached Python's process, but the task processing it is flagged for cancellation.

In the latter scenario, the message will be lost.

To solve this, `BLPOP` could be shielded from cancellation, but this would lead to problematic timeout handling, where a read command to Redis could easily outlive its event loop, leading to all sorts of problems for shutdown code or environments where the global channels layer instance is used from multiple event loops.

Redis documentation suggests using `BRPOPLPUSH` instead, where messages are atomically returned and copied to an auxiliary list. This makes the Python implementation easier inasmuch as cancellation during the `BRPOPLPUSH` no longer needs to be handled; but the message must still be removed from the auxiliary list once processed - if and only if it is processed.

One option for implementing this is to make the auxiliary removal completely uninterruptible from the point of view of the code path executing it. In a usual single-threaded asyncio program, the `BRPOPLPUSH` would be the last interruptible operation before returning a valid message; the code following this command is essentially atomic with respect to any other task attached to the event loop. Any other task might then reasonably expect that either:

- the receive succeeded completely, will return a valid message and has cleaned the auxiliary list, or
- `BRPOPLPUSH` was interrupted, there is no message to speak of and a backup of it sits in the auxiliary list.

Making the removal "atomic" in this way violates the second expectation and also a fairly basic principle of single-threaded asyncio code. Because it is still an interruptible operation, other tasks will run during this "atomic" piece of code, meaning it could be cancelled. If it is cancelled, that cancellation must become visible outside; otherwise we have effectively "run" two independent code paths at the same time in a single thread, and also made life difficult for anyone wanting to make sure tasks complete in a timely manner - the receive cancellation would essentially be swallowed and the task would need to be cancelled again at its next interruption point. If the cancellation must proceed, however, we could be left in the following situation: the message has been unpacked, but the cancellation occurred too late to stop the removal operation, so we are left holding a message that we somehow have to return, because there is no backup of it in Redis. This could be solved, e.g. by subclassing the cancellation exception and having it carry the message as a payload, but that would be messy.

A second option is to create a detached removal task. This still gives us a properly atomic code path from when the message is unpacked to when the function returns, and the removal will happen at some point "soon", because there are no delays in it.
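To make the second scenario above concrete, here is a minimal, self-contained asyncio sketch (an editorial illustration, not channels_redis code; `fake_server_list`, `blpop_reply` and the other names are invented) of how a reply that has already left the server is discarded when the task waiting for it is cancelled:

```python
import asyncio


async def main():
    loop = asyncio.get_event_loop()
    fake_server_list = ["message-1"]    # what "Redis" holds for this channel
    blpop_reply = loop.create_future()  # stands in for the pending BLPOP reply

    async def receive():
        # Stands in for awaiting the BLPOP response on the socket.
        return await blpop_reply

    receiver = asyncio.ensure_future(receive())
    await asyncio.sleep(0)  # let the receiver start waiting

    # The "server" pops the element and sends the reply...
    blpop_reply.set_result(fake_server_list.pop())
    # ...but the receive is cancelled before its task gets to run again.
    receiver.cancel()
    try:
        await receiver
    except asyncio.CancelledError:
        print("receive cancelled")

    # The element is gone from the "server" and was never returned to anyone.
    print("left on server:", fake_server_list)  # -> []


asyncio.get_event_loop().run_until_complete(main())
```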
Notably, ordering is not a problem: if two receives in close succession both fire off removals for their own messages, the end result will still be that two messages are deleted from a list into which two messages were put earlier by Redis. On the other hand, it is a problem if a receive commences after a previous receive, but before that receive's cleanup: the message backup would be moved back into the regular message queue, and the second receive would get it again even though the earlier receive processed it successfully.

The solution to this is rather heavy-handed but easily reasoned about: the stretch of time between starting a removal task and finishing it can be protected by a per-channel lock. Acquiring the lock is interruptible but doesn't change the cancellation semantics of the design, since cancelling lock acquisition has the same effect as cancelling `BRPOPLPUSH` before the command reaches Redis - no message is received and Redis state is not corrupted.

*NOTE*: It is useful to observe that `BRPOPLPUSH` is also the only Redis command that requires this level of special handling. The others used by `channels_redis` are either read-only, such as `LLEN`, or are destructive without expecting a return value, such as `ZREM`. `BRPOPLPUSH` is both destructive (an element is removed server-side from the list) and carries an important return value.
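The two-list discipline the change relies on can also be sketched outside the channel layer with plain Redis commands. This is an editorial illustration that assumes an aioredis 1.x style client and a locally running Redis; the key names are invented:

```python
import asyncio

import aioredis  # assumes the aioredis 1.x API in use at the time


async def demo():
    redis = await aioredis.create_redis("redis://localhost")
    channel, backup = "demo-channel", "demo-channel$inflight"

    await redis.lpush(channel, "hello")

    # Atomic on the server: the element leaves `channel` and lands in `backup`
    # in one operation, so there is never a moment with no copy anywhere.
    message = await redis.brpoplpush(channel, backup, timeout=5)
    print("processing", message)

    # Only after processing succeeds is the backup copy discarded; had we been
    # cancelled before this point, the message would still sit in `backup`.
    await redis.brpop(backup, timeout=5)

    redis.close()
    await redis.wait_closed()


asyncio.get_event_loop().run_until_complete(demo())
```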
1 parent 8bc0433 commit e2416af

File tree: 1 file changed, +142 −18 lines


channels_redis/core.py

Lines changed: 142 additions & 18 deletions
@@ -90,6 +90,42 @@ async def close(self):
             await conn.wait_closed()
 
 
+class ChannelLock:
+    """
+    Helper class for per-channel locking.
+
+    Once a lock is released and has no waiters, it will also be deleted,
+    to mitigate multi-event loop problems.
+    """
+
+    def __init__(self):
+        self.locks = collections.defaultdict(asyncio.Lock)
+        self.wait_counts = collections.defaultdict(int)
+
+    async def acquire(self, channel):
+        """
+        Acquire the lock for the given channel.
+        """
+        self.wait_counts[channel] += 1
+        return await self.locks[channel].acquire()
+
+    def locked(self, channel):
+        """
+        Return ``True`` if the lock for the given channel is acquired.
+        """
+        return self.locks[channel].locked()
+
+    def release(self, channel):
+        """
+        Release the lock for the given channel.
+        """
+        self.locks[channel].release()
+        self.wait_counts[channel] -= 1
+        if self.wait_counts[channel] < 1:
+            del self.locks[channel]
+            del self.wait_counts[channel]
+
+
 class UnsupportedRedis(Exception):
     pass
 

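For orientation, a minimal usage sketch of the helper added above (hypothetical wrapper names; the real call sites are in `receive_single` further down): the lock is taken before the blocking read and handed over to the detached cleaner, which releases it when it finishes.

```python
import asyncio

clean_locks = ChannelLock()  # the helper class added in the hunk above


async def locked_receive(channel_key, brpop, cleanup):
    # `brpop` and `cleanup` are stand-ins for the layer's own
    # _brpop_with_clean and _clean_receive_backup coroutines.
    await clean_locks.acquire(channel_key)
    try:
        content = await brpop(channel_key)
        cleaner = asyncio.ensure_future(cleanup(channel_key))
        # The lock is released only once the cleaner has finished, so a
        # follow-up receive cannot requeue the backed-up message too early.
        cleaner.add_done_callback(lambda _: clean_locks.release(channel_key))
    except Exception:
        clean_locks.release(channel_key)
        raise
    return content
```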
@@ -103,7 +139,7 @@ class RedisChannelLayer(BaseChannelLayer):
     encryption are provided.
     """
 
-    blpop_timeout = 5
+    brpop_timeout = 5
 
     def __init__(
         self,
@@ -143,6 +179,11 @@ def __init__(
         self.receive_event_loop = None
         # Buffered messages by process-local channel name
         self.receive_buffer = collections.defaultdict(asyncio.Queue)
+        # Detached channel cleanup tasks
+        self.receive_cleaners = []
+        # Per-channel cleanup locks to prevent a receive starting and moving
+        # a message back into the main queue before its cleanup has completed
+        self.receive_clean_locks = ChannelLock()
 
     def decode_hosts(self, hosts):
         """
@@ -213,9 +254,50 @@ async def send(self, channel, message):
             if await connection.llen(channel_key) >= self.get_capacity(channel):
                 raise ChannelFull()
             # Push onto the list then set it to expire in case it's not consumed
-            await connection.rpush(channel_key, self.serialize(message))
+            await connection.lpush(channel_key, self.serialize(message))
             await connection.expire(channel_key, int(self.expiry))
 
+    def _backup_channel_name(self, channel):
+        """
+        Construct the key used as a backup queue for the given channel.
+        """
+        return channel + "$inflight"
+
+    async def _brpop_with_clean(self, index, channel, timeout):
+        """
+        Perform a Redis BRPOP and manage the backup processing queue.
+        In case of cancellation, make sure the message is not lost.
+        """
+        # The script will pop messages from the processing queue and push them in front
+        # of the main message queue in the proper order; BRPOP must *not* be called
+        # because that would deadlock the server
+        cleanup_script = """
+            local backed_up = redis.call('LRANGE', ARGV[2], 0, -1)
+            for i = #backed_up, 1, -1 do
+                redis.call('LPUSH', ARGV[1], backed_up[i])
+            end
+            redis.call('DEL', ARGV[2])
+        """
+        backup_queue = self._backup_channel_name(channel)
+        async with self.connection(index) as connection:
+            # Cancellation here doesn't matter, we're not doing anything destructive
+            # and the script executes atomically...
+            await connection.eval(
+                cleanup_script,
+                keys=[],
+                args=[channel, backup_queue]
+            )
+            # ...and it doesn't matter here either, the message will be safe in the backup.
+            return await connection.brpoplpush(channel, backup_queue, timeout=timeout)
+
+    async def _clean_receive_backup(self, index, channel):
+        """
+        Pop the oldest message off the channel backup queue.
+        The result isn't interesting as it was already processed.
+        """
+        async with self.connection(index) as connection:
+            await connection.brpop(self._backup_channel_name(channel))
+
     async def receive(self, channel):
         """
         Receive the first message that arrives on the channel.
@@ -289,12 +371,15 @@ async def receive(self, channel):
 
                 # We hold the receive lock, receive and then release it.
                 try:
+                    # There is no interruption point from when the message is
+                    # unpacked in receive_single to when we get back here, so
+                    # the following lines are essentially atomic.
                     message_channel, message = await self.receive_single(real_channel)
                     if type(message_channel) is list:
                         for chan in message_channel:
-                            await self.receive_buffer[chan].put(message)
+                            self.receive_buffer[chan].put_nowait(message)
                     else:
-                        await self.receive_buffer[message_channel].put(message)
+                        self.receive_buffer[message_channel].put_nowait(message)
                     message = None
                 finally:
                     self.receive_lock.release()
@@ -314,6 +399,7 @@ async def receive(self, channel):
                 if self.receive_count == 0:
                     assert not self.receive_lock.locked()
                     self.receive_lock = None
+                    self.receive_event_loop = None
         else:
             # Do a plain direct receive
             return (await self.receive_single(channel))[1]
@@ -330,21 +416,44 @@ async def receive_single(self, channel):
             index = self.consistent_hash(channel)
         else:
             index = next(self._receive_index_generator)
-        # Get that connection and receive off of it
-        async with self.connection(index) as connection:
-            channel_key = self.prefix + channel
-            content = None
 
+        channel_key = self.prefix + channel
+        content = None
+        await self.receive_clean_locks.acquire(channel_key)
+        try:
             while content is None:
-                content = await connection.blpop(channel_key, timeout=self.blpop_timeout)
-            # Message decode
-            message = self.deserialize(content[1])
-            # TODO: message expiry?
-            # If there is a full channel name stored in the message, unpack it.
-            if "__asgi_channel__" in message:
-                channel = message["__asgi_channel__"]
-                del message["__asgi_channel__"]
-            return channel, message
+                # Nothing is lost here by cancellations, messages will still
+                # be in the backup queue.
+                content = await self._brpop_with_clean(index, channel_key, timeout=self.brpop_timeout)
+
+            # Fire off a task to clean the message from its backup queue.
+            # Per-channel locking isn't needed, because the backup is a queue
+            # and additionally, we don't care about the order; all processed
+            # messages need to be removed, no matter if the current one is
+            # removed after the next one.
+            # NOTE: Duplicate messages will be received eventually if any
+            # of these cleaners are cancelled.
+            cleaner = asyncio.ensure_future(self._clean_receive_backup(index, channel_key))
+            self.receive_cleaners.append(cleaner)
+
+            def _cleanup_done(cleaner):
+                self.receive_cleaners.remove(cleaner)
+                self.receive_clean_locks.release(channel_key)
+
+            cleaner.add_done_callback(_cleanup_done)
+
+        except Exception as exc:
+            self.receive_clean_locks.release(channel_key)
+            raise
+
+        # Message decode
+        message = self.deserialize(content)
+        # TODO: message expiry?
+        # If there is a full channel name stored in the message, unpack it.
+        if "__asgi_channel__" in message:
+            channel = message["__asgi_channel__"]
+            del message["__asgi_channel__"]
+        return channel, message
 
     async def new_channel(self, prefix="specific"):
         """
@@ -364,6 +473,10 @@ async def flush(self):
         """
         Deletes all messages and groups on all shards.
         """
+        # Make sure all channel cleaners have finished before removing
+        # keys from under their feet.
+        await self.wait_received()
+
         # Lua deletion script
         delete_prefix = """
             local keys = redis.call('keys', ARGV[1])
@@ -386,9 +499,20 @@ async def close_pools(self):
         """
         Close all connections in the event loop pools.
         """
+        # Flush all cleaners, in case somebody just wanted to close the
+        # pools without flushing first.
+        await self.wait_received()
+
         for pool in self.pools:
             await pool.close()
 
+    async def wait_received(self):
+        """
+        Wait for all channel cleanup functions to finish.
+        """
+        if self.receive_cleaners:
+            await asyncio.wait(self.receive_cleaners[:])
+
     ### Groups extension ###
 
     async def group_add(self, group, channel):
@@ -451,7 +575,7 @@ async def group_send(self, group, message):
         group_send_lua = """
             for i=1,#KEYS do
                 if redis.call('LLEN', KEYS[i]) < tonumber(ARGV[i + #KEYS]) then
-                    redis.call('RPUSH', KEYS[i], ARGV[i])
+                    redis.call('LPUSH', KEYS[i], ARGV[i])
                     redis.call('EXPIRE', KEYS[i], %d)
                 end
             end
