Skip to content

Commit 4e883c5

Browse files
authored
Merge pull request #257 from zumalabs/pubsub-sentinel
Add sentinel support to pubsub layer
2 parents 084ddc4 + 0caaacd commit 4e883c5

File tree

2 files changed

+140
-3
lines changed

2 files changed

+140
-3
lines changed

channels_redis/pubsub.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,12 @@ def on_close_noop(sender, exc=None):
190190

191191
class RedisSingleShardConnection:
192192
def __init__(self, host, channel_layer):
193-
self.host = host
193+
self.host = host.copy() if type(host) is dict else {"address": host}
194+
self.master_name = self.host.pop("master_name", None)
194195
self.channel_layer = channel_layer
195196
self._subscribed_to = set()
196197
self._lock = None
198+
self._redis = None
197199
self._pub_conn = None
198200
self._sub_conn = None
199201
self._receiver = None
@@ -230,10 +232,12 @@ async def flush(self):
230232
if self._sub_conn is not None:
231233
self._sub_conn.close()
232234
await self._sub_conn.wait_closed()
235+
self._put_redis_conn(self._sub_conn)
233236
self._sub_conn = None
234237
if self._pub_conn is not None:
235238
self._pub_conn.close()
236239
await self._pub_conn.wait_closed()
240+
self._put_redis_conn(self._pub_conn)
237241
self._pub_conn = None
238242
self._subscribed_to = set()
239243

@@ -247,11 +251,13 @@ async def _get_pub_conn(self):
247251
self._lock = asyncio.Lock()
248252
async with self._lock:
249253
if self._pub_conn is not None and self._pub_conn.closed:
254+
self._put_redis_conn(self._pub_conn)
250255
self._pub_conn = None
251256
while self._pub_conn is None:
252257
try:
253-
self._pub_conn = await aioredis.create_redis(self.host)
258+
self._pub_conn = await self._get_redis_conn()
254259
except BaseException:
260+
self._put_redis_conn(self._pub_conn)
255261
logger.warning(
256262
f"Failed to connect to Redis publish host: {self.host}; will try again in 1 second..."
257263
)
@@ -270,6 +276,7 @@ async def _get_sub_conn(self):
270276
self._lock = asyncio.Lock()
271277
async with self._lock:
272278
if self._sub_conn is not None and self._sub_conn.closed:
279+
self._put_redis_conn(self._sub_conn)
273280
self._sub_conn = None
274281
if self._sub_conn is None:
275282
if self._receive_task is not None:
@@ -287,8 +294,9 @@ async def _get_sub_conn(self):
287294
self._receive_task = None
288295
while self._sub_conn is None:
289296
try:
290-
self._sub_conn = await aioredis.create_redis(self.host)
297+
self._sub_conn = await self._get_redis_conn()
291298
except BaseException:
299+
self._put_redis_conn(self._sub_conn)
292300
logger.warning(
293301
f"Failed to connect to Redis subscribe host: {self.host}; will try again in 1 second..."
294302
)
@@ -317,6 +325,31 @@ async def _do_receiving(self):
317325
if channel_name in self.channel_layer.channels:
318326
self.channel_layer.channels[channel_name].put_nowait(message)
319327

328+
async def _ensure_redis(self):
329+
if self._redis is None:
330+
if self.master_name is None:
331+
self._redis = await aioredis.create_redis_pool(**self.host)
332+
else:
333+
# aioredis default timeout is way too low
334+
self._redis = await aioredis.sentinel.create_sentinel(
335+
timeout=2, **self.host
336+
)
337+
338+
def _get_aioredis_pool(self):
339+
if self.master_name is None:
340+
return self._redis._pool_or_conn
341+
else:
342+
return self._redis.master_for(self.master_name)._pool_or_conn
343+
344+
async def _get_redis_conn(self):
345+
await self._ensure_redis()
346+
conn = await self._get_aioredis_pool().acquire()
347+
return aioredis.Redis(conn)
348+
349+
def _put_redis_conn(self, conn):
350+
if conn:
351+
self._get_aioredis_pool().release(conn._pool_or_conn)
352+
320353
async def _do_keepalive(self):
321354
"""
322355
This task's simple job is just to call `self._get_sub_conn()` periodically.

tests/test_pubsub_sentinel.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import asyncio
2+
import random
3+
4+
import async_timeout
5+
import pytest
6+
from async_generator import async_generator, yield_
7+
8+
from channels_redis.pubsub import RedisPubSubChannelLayer
9+
10+
SENTINEL_MASTER = "sentinel"
11+
TEST_HOSTS = [{"sentinels": [("localhost", 26379)], "master_name": SENTINEL_MASTER}]
12+
13+
14+
@pytest.fixture()
15+
@async_generator
16+
async def channel_layer():
17+
"""
18+
Channel layer fixture that flushes automatically.
19+
"""
20+
channel_layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
21+
await yield_(channel_layer)
22+
await channel_layer.flush()
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_send_receive(channel_layer):
27+
"""
28+
Makes sure we can send a message to a normal channel then receive it.
29+
"""
30+
channel = await channel_layer.new_channel()
31+
await channel_layer.send(channel, {"type": "test.message", "text": "Ahoy-hoy!"})
32+
message = await channel_layer.receive(channel)
33+
assert message["type"] == "test.message"
34+
assert message["text"] == "Ahoy-hoy!"
35+
36+
37+
@pytest.mark.asyncio
38+
async def test_multi_send_receive(channel_layer):
39+
"""
40+
Tests overlapping sends and receives, and ordering.
41+
"""
42+
channel = await channel_layer.new_channel()
43+
await channel_layer.send(channel, {"type": "message.1"})
44+
await channel_layer.send(channel, {"type": "message.2"})
45+
await channel_layer.send(channel, {"type": "message.3"})
46+
assert (await channel_layer.receive(channel))["type"] == "message.1"
47+
assert (await channel_layer.receive(channel))["type"] == "message.2"
48+
assert (await channel_layer.receive(channel))["type"] == "message.3"
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_groups_basic(channel_layer):
53+
"""
54+
Tests basic group operation.
55+
"""
56+
channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
57+
channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
58+
channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
59+
await channel_layer.group_add("test-group", channel_name1)
60+
await channel_layer.group_add("test-group", channel_name2)
61+
await channel_layer.group_add("test-group", channel_name3)
62+
await channel_layer.group_discard("test-group", channel_name2)
63+
await channel_layer.group_send("test-group", {"type": "message.1"})
64+
# Make sure we get the message on the two channels that were in
65+
async with async_timeout.timeout(1):
66+
assert (await channel_layer.receive(channel_name1))["type"] == "message.1"
67+
assert (await channel_layer.receive(channel_name3))["type"] == "message.1"
68+
# Make sure the removed channel did not get the message
69+
with pytest.raises(asyncio.TimeoutError):
70+
async with async_timeout.timeout(1):
71+
await channel_layer.receive(channel_name2)
72+
73+
74+
@pytest.mark.asyncio
75+
async def test_groups_same_prefix(channel_layer):
76+
"""
77+
Tests group_send with multiple channels with same channel prefix
78+
"""
79+
channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan")
80+
channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan")
81+
channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan")
82+
await channel_layer.group_add("test-group", channel_name1)
83+
await channel_layer.group_add("test-group", channel_name2)
84+
await channel_layer.group_add("test-group", channel_name3)
85+
await channel_layer.group_send("test-group", {"type": "message.1"})
86+
87+
# Make sure we get the message on the channels that were in
88+
async with async_timeout.timeout(1):
89+
assert (await channel_layer.receive(channel_name1))["type"] == "message.1"
90+
assert (await channel_layer.receive(channel_name2))["type"] == "message.1"
91+
assert (await channel_layer.receive(channel_name3))["type"] == "message.1"
92+
93+
94+
@pytest.mark.asyncio
95+
async def test_random_reset__channel_name(channel_layer):
96+
"""
97+
Makes sure resetting random seed does not make us reuse channel names.
98+
"""
99+
random.seed(1)
100+
channel_name_1 = await channel_layer.new_channel()
101+
random.seed(1)
102+
channel_name_2 = await channel_layer.new_channel()
103+
104+
assert channel_name_1 != channel_name_2

0 commit comments

Comments
 (0)