
Commit 084ddc4

Added RedisPubSubChannelLayer. (#251)
1 parent 3656d87 commit 084ddc4

File tree

4 files changed: +462 -1 lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@ name: Tests
 on:
   push:
     branches:
-      - master
+      - main
   pull_request:

 jobs:

README.rst

Lines changed: 11 additions & 0 deletions

@@ -31,6 +31,17 @@ Set up the channel layer in your Django settings file like so::
         },
     }

+Or, you can use the alternate implementation which uses Redis Pub/Sub::
+
+    CHANNEL_LAYERS = {
+        "default": {
+            "BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
+            "CONFIG": {
+                "hosts": [("localhost", 6379)],
+            },
+        },
+    }
+
 Possible options for ``CONFIG`` are listed below.

 ``hosts``
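
For context, here is an illustrative sketch (not part of this diff) of how a consumer talks to either backend through the standard Channels APIs; it assumes the ``channels`` package's ``AsyncJsonWebsocketConsumer`` and a hypothetical "chat" group::

    from channels.generic.websocket import AsyncJsonWebsocketConsumer

    class ChatConsumer(AsyncJsonWebsocketConsumer):
        async def connect(self):
            # Subscribes this consumer's specific channel to the group.
            await self.channel_layer.group_add("chat", self.channel_name)
            await self.accept()

        async def disconnect(self, close_code):
            await self.channel_layer.group_discard("chat", self.channel_name)

        async def chat_message(self, event):
            # Handles messages sent with {"type": "chat.message", ...}.
            await self.send_json(event)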

channels_redis/pubsub.py

Lines changed: 347 additions & 0 deletions
@@ -0,0 +1,347 @@
import asyncio
import logging
import uuid

import aioredis
import msgpack

logger = logging.getLogger(__name__)


class RedisPubSubChannelLayer:
    """
    Channel Layer that uses Redis's pub/sub functionality.
    """

    def __init__(self, hosts=None, prefix="asgi", **kwargs):
        if hosts is None:
            hosts = [("localhost", 6379)]
        assert (
            isinstance(hosts, list) and len(hosts) > 0
        ), "`hosts` must be a list with at least one Redis server"

        self.prefix = prefix

        # Each consumer gets its own *specific* channel, created with the `new_channel()` method.
        # This dict maps `channel_name` to a queue of messages for that channel.
        self.channels = {}

        # A channel can subscribe to zero or more groups.
        # This dict maps `group_name` to the set of channel names that are subscribed to that group.
        self.groups = {}

        # For each host, we create a `RedisSingleShardConnection` to manage the connection to that host.
        self._shards = [RedisSingleShardConnection(host, self) for host in hosts]
    def _get_shard(self, channel_or_group_name):
        """
        Return the shard that is used exclusively for this channel or group.
        """
        if len(self._shards) == 1:
            # Avoid the overhead of hashing and modulo when it is unnecessary.
            return self._shards[0]
        shard_index = abs(hash(channel_or_group_name)) % len(self._shards)
        return self._shards[shard_index]
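    # Illustration (added note, not in the original commit): with e.g. three
    # shards, `abs(hash("asgi__group__chat")) % 3` yields the same index on
    # every call within a process, so all publishes and subscribes for a given
    # name target one Redis host. Note that Python salts `hash()` for `str`
    # per process, so multi-process deployments must pin PYTHONHASHSEED for
    # all processes to agree on the channel-to-shard mapping.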
    def _get_group_channel_name(self, group):
        """
        Return the channel name used by a group.
        Includes '__group__' in the returned string so that these names are
        distinguished from those returned by `new_channel()`.
        Technically collisions are possible, but it takes what I believe is
        intentional abuse in order to have colliding names.
        """
        return f"{self.prefix}__group__{group}"

    extensions = ["groups", "flush"]

    ################################################################################
    # Channel layer API
    ################################################################################
    async def send(self, channel, message):
        """
        Send a message onto a (general or specific) channel.
        """
        shard = self._get_shard(channel)
        await shard.publish(channel, message)

    async def new_channel(self, prefix="specific."):
        """
        Returns a new channel name that can be used by a consumer in our
        process as a specific channel.
        """
        channel = f"{self.prefix}{prefix}{uuid.uuid4().hex}"
        self.channels[channel] = asyncio.Queue()
        shard = self._get_shard(channel)
        await shard.subscribe(channel)
        return channel
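    # Illustrative round-trip (added note, not in the original commit),
    # assuming a reachable Redis server:
    #
    #   layer = RedisPubSubChannelLayer()
    #   channel = await layer.new_channel()  # e.g. "asgispecific.0f5b..."
    #   await layer.send(channel, {"type": "hello"})
    #   message = await layer.receive(channel)  # round-trips through msgpack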
    async def receive(self, channel):
        """
        Receive the first message that arrives on the channel.
        If more than one coroutine waits on the same channel, a random one
        of the waiting coroutines will get the result.
        """
        if channel not in self.channels:
            raise RuntimeError(
                'You should only call receive() on channels that you "own" and that were created with `new_channel()`.'
            )

        q = self.channels[channel]

        try:
            message = await q.get()
        except asyncio.CancelledError:
            # We assume here that the reason we are cancelled is because the consumer
            # is exiting, therefore we need to clean up by unsubscribing below. Indeed,
            # currently the way that Django Channels works, this is a safe assumption.
            # In the future, Django Channels could change to call a *new* method that
            # would serve as the antithesis of `new_channel()`; this new method might
            # be named `delete_channel()`. If that were the case, we would do the
            # following cleanup from that new `delete_channel()` method, but, since
            # that's not how Django Channels works (yet), we do the cleanup below:
            if channel in self.channels:
                del self.channels[channel]
                try:
                    shard = self._get_shard(channel)
                    await shard.unsubscribe(channel)
                except BaseException:
                    logger.exception("Unexpected exception while cleaning up channel:")
                    # We don't re-raise here because we want the CancelledError to be the one re-raised.
            raise

        return msgpack.unpackb(message)
    ################################################################################
    # Groups extension
    ################################################################################

    async def group_add(self, group, channel):
        """
        Adds the channel name to a group.
        """
        if channel not in self.channels:
            raise RuntimeError(
                "You can only call group_add() on channels that exist in-process.\n"
                "Consumers are encouraged to use the common pattern:\n"
                f"    self.channel_layer.group_add({repr(group)}, self.channel_name)"
            )
        group_channel = self._get_group_channel_name(group)
        if group_channel not in self.groups:
            self.groups[group_channel] = set()
        group_channels = self.groups[group_channel]
        if channel not in group_channels:
            group_channels.add(channel)
        shard = self._get_shard(group_channel)
        await shard.subscribe(group_channel)

    async def group_discard(self, group, channel):
        """
        Removes the channel from a group.
        """
        group_channel = self._get_group_channel_name(group)
        assert group_channel in self.groups
        group_channels = self.groups[group_channel]
        assert channel in group_channels
        group_channels.remove(channel)
        if len(group_channels) == 0:
            del self.groups[group_channel]
            shard = self._get_shard(group_channel)
            await shard.unsubscribe(group_channel)

    async def group_send(self, group, message):
        """
        Send the message to all subscribers of the group.
        """
        group_channel = self._get_group_channel_name(group)
        shard = self._get_shard(group_channel)
        await shard.publish(group_channel, message)
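    # Fan-out illustration (added note, not in the original commit): each
    # channel that joined the group gets its own copy of the message:
    #
    #   await layer.group_add("chat", ch_a)
    #   await layer.group_add("chat", ch_b)
    #   await layer.group_send("chat", {"type": "chat.message"})
    #   # ...now both receive(ch_a) and receive(ch_b) yield the message.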
    ################################################################################
    # Flush extension
    ################################################################################

    async def flush(self):
        """
        Flush the layer, making it like new. It can continue to be used as if it
        was just created. This also closes connections, serving as a clean-up
        method; connections will be re-opened if you continue using this layer.
        """
        self.channels = {}
        self.groups = {}
        for shard in self._shards:
            await shard.flush()


def on_close_noop(sender, exc=None):
    """
    If you don't pass an `on_close` function to the `Receiver`, then it
    defaults to one that closes the Receiver whenever the last subscriber
    unsubscribes. That is not what we want; instead, we want the Receiver
    to continue even if no one is subscribed, because soon someone *will*
    subscribe and we want things to continue from there. Passing this
    empty function solves it.
    """
    pass
class RedisSingleShardConnection:
    def __init__(self, host, channel_layer):
        self.host = host
        self.channel_layer = channel_layer
        self._subscribed_to = set()
        self._lock = None
        self._pub_conn = None
        self._sub_conn = None
        self._receiver = None
        self._receive_task = None
        self._keepalive_task = None

    async def publish(self, channel, message):
        conn = await self._get_pub_conn()
        await conn.publish(channel, msgpack.packb(message))

    async def subscribe(self, channel):
        if channel not in self._subscribed_to:
            self._subscribed_to.add(channel)
            conn = await self._get_sub_conn()
            await conn.subscribe(self._receiver.channel(channel))

    async def unsubscribe(self, channel):
        if channel in self._subscribed_to:
            self._subscribed_to.remove(channel)
            conn = await self._get_sub_conn()
            await conn.unsubscribe(channel)

    async def flush(self):
        for task in [self._keepalive_task, self._receive_task]:
            if task is not None:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        self._keepalive_task = None
        self._receive_task = None
        self._receiver = None
        if self._sub_conn is not None:
            self._sub_conn.close()
            await self._sub_conn.wait_closed()
            self._sub_conn = None
        if self._pub_conn is not None:
            self._pub_conn.close()
            await self._pub_conn.wait_closed()
            self._pub_conn = None
        self._subscribed_to = set()
    async def _get_pub_conn(self):
        """
        Return the connection to this shard that is used for *publishing* messages.

        If the connection is dead, automatically reconnect.
        """
        if self._lock is None:
            self._lock = asyncio.Lock()
        async with self._lock:
            if self._pub_conn is not None and self._pub_conn.closed:
                self._pub_conn = None
            while self._pub_conn is None:
                try:
                    self._pub_conn = await aioredis.create_redis(self.host)
                except BaseException:
                    logger.warning(
                        f"Failed to connect to Redis publish host: {self.host}; will try again in 1 second..."
                    )
                    await asyncio.sleep(1)
            return self._pub_conn
    async def _get_sub_conn(self):
        """
        Return the connection to this shard that is used for *subscribing* to channels.

        If the connection is dead, automatically reconnect and resubscribe to all our channels!
        """
        if self._keepalive_task is None:
            self._keepalive_task = asyncio.ensure_future(self._do_keepalive())
        if self._lock is None:
            self._lock = asyncio.Lock()
        async with self._lock:
            if self._sub_conn is not None and self._sub_conn.closed:
                self._sub_conn = None
            if self._sub_conn is None:
                if self._receive_task is not None:
                    self._receive_task.cancel()
                    try:
                        await self._receive_task
                    except asyncio.CancelledError:
                        # This is the normal case, that `asyncio.CancelledError` is thrown. All good.
                        pass
                    except BaseException:
                        logger.exception(
                            "Unexpected exception while canceling the receiver task:"
                        )
                        # Don't re-raise here. We don't actually care why `_receive_task` didn't exit cleanly.
                    self._receive_task = None
                while self._sub_conn is None:
                    try:
                        self._sub_conn = await aioredis.create_redis(self.host)
                    except BaseException:
                        logger.warning(
                            f"Failed to connect to Redis subscribe host: {self.host}; will try again in 1 second..."
                        )
                        await asyncio.sleep(1)
                self._receiver = aioredis.pubsub.Receiver(on_close=on_close_noop)
                self._receive_task = asyncio.ensure_future(self._do_receiving())
                if len(self._subscribed_to) > 0:
                    # Do our best to recover by resubscribing to the channels that we were previously subscribed to.
                    resubscribe_to = [
                        self._receiver.channel(name) for name in self._subscribed_to
                    ]
                    await self._sub_conn.subscribe(*resubscribe_to)
            return self._sub_conn
    async def _do_receiving(self):
        async for ch, message in self._receiver.iter():
            name = ch.name
            if isinstance(name, bytes):
                # Reversing what happens here:
                # https://github.com/aio-libs/aioredis-py/blob/8a207609b7f8a33e74c7c8130d97186e78cc0052/aioredis/util.py#L17
                name = name.decode()
            if name in self.channel_layer.channels:
                self.channel_layer.channels[name].put_nowait(message)
            elif name in self.channel_layer.groups:
                for channel_name in self.channel_layer.groups[name]:
                    if channel_name in self.channel_layer.channels:
                        self.channel_layer.channels[channel_name].put_nowait(message)
    async def _do_keepalive(self):
        """
        This task's simple job is just to call `self._get_sub_conn()` periodically.

        Why? Well, calling `self._get_sub_conn()` has the nice side-effect that if
        that connection has died (because Redis was restarted, or there was a networking
        hiccup, for example), then calling `self._get_sub_conn()` will reconnect and
        restore our old subscriptions. Thus, we want to do this on a predictable schedule.
        This is kind of a sub-optimal way to achieve this, but I can't find a way in aioredis
        to get a notification when the connection dies. I find this (sub-optimal) method
        of checking the connection state works fine for my app; if Redis restarts, we reconnect
        and resubscribe *quickly enough*; I mean, Redis restarting is already bad because it
        will cause messages to get lost, and this periodic check at least minimizes the
        damage *enough*.

        Note you wouldn't need this if you were *sure* that there would be a lot of subscribe/
        unsubscribe events on your site, because such events each call `self._get_sub_conn()`.
        Thus, on a site with heavy traffic this task may not be necessary, but also maybe it is.
        Why? Well, on a heavy-traffic site you probably have more than one Django server replica,
        so it might be the case that one of your replicas is under-utilized and this periodic
        connection check will be beneficial in the same way as it is for a low-traffic site.
        """
        while True:
            await asyncio.sleep(1)
            try:
                await self._get_sub_conn()
            except Exception:
                logger.exception("Unexpected exception in keepalive task:")
