Skip to content

Commit 3449871

Browse files
authored
Merge branch 'main' into pubsub-redis-events
2 parents 9ea85ff + 4e883c5 commit 3449871

File tree

2 files changed

+140
-3
lines changed

2 files changed

+140
-3
lines changed

channels_redis/pubsub.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,12 @@ def on_close_noop(sender, exc=None):
195195

196196
class RedisSingleShardConnection:
197197
def __init__(self, host, channel_layer):
198-
self.host = host
198+
self.host = host.copy() if type(host) is dict else {"address": host}
199+
self.master_name = self.host.pop("master_name", None)
199200
self.channel_layer = channel_layer
200201
self._subscribed_to = set()
201202
self._lock = None
203+
self._redis = None
202204
self._pub_conn = None
203205
self._sub_conn = None
204206
self._receiver = None
@@ -235,10 +237,12 @@ async def flush(self):
235237
if self._sub_conn is not None:
236238
self._sub_conn.close()
237239
await self._sub_conn.wait_closed()
240+
self._put_redis_conn(self._sub_conn)
238241
self._sub_conn = None
239242
if self._pub_conn is not None:
240243
self._pub_conn.close()
241244
await self._pub_conn.wait_closed()
245+
self._put_redis_conn(self._pub_conn)
242246
self._pub_conn = None
243247
self._subscribed_to = set()
244248

@@ -252,11 +256,13 @@ async def _get_pub_conn(self):
252256
self._lock = asyncio.Lock()
253257
async with self._lock:
254258
if self._pub_conn is not None and self._pub_conn.closed:
259+
self._put_redis_conn(self._pub_conn)
255260
self._pub_conn = None
256261
while self._pub_conn is None:
257262
try:
258-
self._pub_conn = await aioredis.create_redis(self.host)
263+
self._pub_conn = await self._get_redis_conn()
259264
except BaseException:
265+
self._put_redis_conn(self._pub_conn)
260266
logger.warning(
261267
f"Failed to connect to Redis publish host: {self.host}; will try again in 1 second..."
262268
)
@@ -275,6 +281,7 @@ async def _get_sub_conn(self):
275281
self._lock = asyncio.Lock()
276282
async with self._lock:
277283
if self._sub_conn is not None and self._sub_conn.closed:
284+
self._put_redis_conn(self._sub_conn)
278285
self._sub_conn = None
279286
self._notify_consumers(self.channel_layer.on_disconnect)
280287
if self._sub_conn is None:
@@ -293,8 +300,9 @@ async def _get_sub_conn(self):
293300
self._receive_task = None
294301
while self._sub_conn is None:
295302
try:
296-
self._sub_conn = await aioredis.create_redis(self.host)
303+
self._sub_conn = await self._get_redis_conn()
297304
except BaseException:
305+
self._put_redis_conn(self._sub_conn)
298306
logger.warning(
299307
f"Failed to connect to Redis subscribe host: {self.host}; will try again in 1 second..."
300308
)
@@ -329,6 +337,31 @@ def _notify_consumers(self, mtype):
329337
for channel in self.channel_layer.channels.values():
330338
channel.put_nowait(msgpack.packb({"type": mtype}))
331339

340+
async def _ensure_redis(self):
    """
    Lazily create the shared aioredis connection object on first use.

    Plain hosts get a ``create_redis_pool``; when a ``master_name`` was
    configured, a sentinel manager is created instead so the master can
    be resolved dynamically.
    """
    if self._redis is not None:
        return
    if self.master_name is None:
        self._redis = await aioredis.create_redis_pool(**self.host)
    else:
        # aioredis default timeout is way too low
        self._redis = await aioredis.sentinel.create_sentinel(timeout=2, **self.host)
349+
350+
def _get_aioredis_pool(self):
    """
    Return the raw aioredis connection pool backing ``self._redis``.

    In sentinel mode the pool for the configured master is looked up;
    otherwise the pool of the plain connection object is used directly.
    """
    if self.master_name is not None:
        return self._redis.master_for(self.master_name)._pool_or_conn
    return self._redis._pool_or_conn
355+
356+
async def _get_redis_conn(self):
    """
    Acquire one dedicated connection from the pool, wrapped in the
    high-level ``aioredis.Redis`` interface.  The caller is responsible
    for handing it back via ``_put_redis_conn``.
    """
    await self._ensure_redis()
    pool = self._get_aioredis_pool()
    raw_conn = await pool.acquire()
    return aioredis.Redis(raw_conn)
360+
361+
def _put_redis_conn(self, conn):
    """
    Release a connection previously obtained from ``_get_redis_conn``
    back into the pool.  ``None`` (or falsy) connections are ignored so
    error paths can call this unconditionally.
    """
    if not conn:
        return
    self._get_aioredis_pool().release(conn._pool_or_conn)
364+
332365
async def _do_keepalive(self):
333366
"""
334367
This task's simple job is just to call `self._get_sub_conn()` periodically.

tests/test_pubsub_sentinel.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import asyncio
2+
import random
3+
4+
import async_timeout
5+
import pytest
6+
from async_generator import async_generator, yield_
7+
8+
from channels_redis.pubsub import RedisPubSubChannelLayer
9+
10+
SENTINEL_MASTER = "sentinel"
11+
TEST_HOSTS = [{"sentinels": [("localhost", 26379)], "master_name": SENTINEL_MASTER}]
12+
13+
14+
@pytest.fixture()
@async_generator
async def channel_layer():
    """
    Channel layer fixture that flushes automatically.
    """
    layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
    await yield_(layer)
    # Always leave Redis clean for the next test.
    await layer.flush()
23+
24+
25+
@pytest.mark.asyncio
async def test_send_receive(channel_layer):
    """
    Makes sure we can send a message to a normal channel then receive it.
    """
    channel = await channel_layer.new_channel()
    payload = {"type": "test.message", "text": "Ahoy-hoy!"}
    await channel_layer.send(channel, payload)
    received = await channel_layer.receive(channel)
    assert received["type"] == "test.message"
    assert received["text"] == "Ahoy-hoy!"
35+
36+
37+
@pytest.mark.asyncio
async def test_multi_send_receive(channel_layer):
    """
    Tests overlapping sends and receives, and ordering.
    """
    channel = await channel_layer.new_channel()
    # Queue three messages before reading any of them back.
    for idx in (1, 2, 3):
        await channel_layer.send(channel, {"type": "message.%d" % idx})
    # They must come out in FIFO order.
    for idx in (1, 2, 3):
        message = await channel_layer.receive(channel)
        assert message["type"] == "message.%d" % idx
49+
50+
51+
@pytest.mark.asyncio
async def test_groups_basic(channel_layer):
    """
    Tests basic group operation.
    """
    channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
    channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
    channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
    for name in (channel_name1, channel_name2, channel_name3):
        await channel_layer.group_add("test-group", name)
    await channel_layer.group_discard("test-group", channel_name2)
    await channel_layer.group_send("test-group", {"type": "message.1"})
    # Make sure we get the message on the two channels that were in
    async with async_timeout.timeout(1):
        assert (await channel_layer.receive(channel_name1))["type"] == "message.1"
        assert (await channel_layer.receive(channel_name3))["type"] == "message.1"
    # Make sure the removed channel did not get the message
    with pytest.raises(asyncio.TimeoutError):
        async with async_timeout.timeout(1):
            await channel_layer.receive(channel_name2)
72+
73+
74+
@pytest.mark.asyncio
async def test_groups_same_prefix(channel_layer):
    """
    Tests group_send with multiple channels with same channel prefix
    """
    channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan")
    channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan")
    channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan")
    members = (channel_name1, channel_name2, channel_name3)
    for name in members:
        await channel_layer.group_add("test-group", name)
    await channel_layer.group_send("test-group", {"type": "message.1"})

    # Make sure we get the message on the channels that were in
    async with async_timeout.timeout(1):
        for name in members:
            assert (await channel_layer.receive(name))["type"] == "message.1"
92+
93+
94+
@pytest.mark.asyncio
async def test_random_reset__channel_name(channel_layer):
    """
    Makes sure resetting random seed does not make us reuse channel names.
    """
    random.seed(1)
    first_name = await channel_layer.new_channel()
    # Re-seeding with the same value must still yield a distinct name.
    random.seed(1)
    second_name = await channel_layer.new_channel()

    assert first_name != second_name

0 commit comments

Comments
 (0)