Skip to content

Commit 49e419d

Browse files
authored
pubsub event loop layer proxy (#262)
1 parent 9793d09 commit 49e419d

File tree

4 files changed

+135
-3
lines changed

4 files changed

+135
-3
lines changed

channels_redis/pubsub.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import asyncio
2+
import functools
23
import logging
4+
import sys
5+
import types
36
import uuid
47

58
import aioredis
@@ -8,7 +11,68 @@
811
logger = logging.getLogger(__name__)
912

1013

14+
if sys.version_info >= (3, 7):
15+
get_running_loop = asyncio.get_running_loop
16+
else:
17+
get_running_loop = asyncio.get_event_loop
18+
19+
20+
def _wrap_close(proxy, loop):
21+
original_impl = loop.close
22+
23+
def _wrapper(self, *args, **kwargs):
24+
if loop in proxy._layers:
25+
layer = proxy._layers[loop]
26+
del proxy._layers[loop]
27+
loop.run_until_complete(layer.flush())
28+
29+
self.close = original_impl
30+
return self.close(*args, **kwargs)
31+
32+
loop.close = types.MethodType(_wrapper, loop)
33+
34+
1135
class RedisPubSubChannelLayer:
36+
def __init__(self, *args, **kwargs) -> None:
37+
self._args = args
38+
self._kwargs = kwargs
39+
self._layers = {}
40+
41+
def __getattr__(self, name):
42+
if name in (
43+
"new_channel",
44+
"send",
45+
"receive",
46+
"group_add",
47+
"group_discard",
48+
"group_send",
49+
"flush",
50+
):
51+
return functools.partial(self._proxy, name)
52+
else:
53+
return getattr(self._get_layer(), name)
54+
55+
def _get_layer(self):
56+
loop = get_running_loop()
57+
58+
try:
59+
layer = self._layers[loop]
60+
except KeyError:
61+
layer = RedisPubSubLoopLayer(*self._args, **self._kwargs)
62+
self._layers[loop] = layer
63+
_wrap_close(self, loop)
64+
65+
return layer
66+
67+
def _proxy(self, name, *args, **kwargs):
68+
async def coro():
69+
layer = self._get_layer()
70+
return await getattr(layer, name)(*args, **kwargs)
71+
72+
return coro()
73+
74+
75+
class RedisPubSubLoopLayer:
1276
"""
1377
Channel Layer that uses Redis's pub/sub functionality.
1478
"""

tests/test_pubsub.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66
from async_generator import async_generator, yield_
77

8+
from asgiref.sync import async_to_sync
89
from channels_redis.pubsub import RedisPubSubChannelLayer
910

1011
TEST_HOSTS = [("localhost", 6379)]
@@ -33,6 +34,18 @@ async def test_send_receive(channel_layer):
3334
assert message["text"] == "Ahoy-hoy!"
3435

3536

37+
@pytest.mark.asyncio
38+
def test_send_receive_sync(channel_layer, event_loop):
39+
_await = event_loop.run_until_complete
40+
channel = _await(channel_layer.new_channel())
41+
async_to_sync(channel_layer.send, force_new_loop=True)(
42+
channel, {"type": "test.message", "text": "Ahoy-hoy!"}
43+
)
44+
message = _await(channel_layer.receive(channel))
45+
assert message["type"] == "test.message"
46+
assert message["text"] == "Ahoy-hoy!"
47+
48+
3649
@pytest.mark.asyncio
3750
async def test_multi_send_receive(channel_layer):
3851
"""
@@ -47,6 +60,19 @@ async def test_multi_send_receive(channel_layer):
4760
assert (await channel_layer.receive(channel))["type"] == "message.3"
4861

4962

63+
@pytest.mark.asyncio
64+
def test_multi_send_receive_sync(channel_layer, event_loop):
65+
_await = event_loop.run_until_complete
66+
channel = _await(channel_layer.new_channel())
67+
send = async_to_sync(channel_layer.send)
68+
send(channel, {"type": "message.1"})
69+
send(channel, {"type": "message.2"})
70+
send(channel, {"type": "message.3"})
71+
assert _await(channel_layer.receive(channel))["type"] == "message.1"
72+
assert _await(channel_layer.receive(channel))["type"] == "message.2"
73+
assert _await(channel_layer.receive(channel))["type"] == "message.3"
74+
75+
5076
@pytest.mark.asyncio
5177
async def test_groups_basic(channel_layer):
5278
"""
@@ -101,3 +127,12 @@ async def test_random_reset__channel_name(channel_layer):
101127
channel_name_2 = await channel_layer.new_channel()
102128

103129
assert channel_name_1 != channel_name_2
130+
131+
132+
def test_multi_event_loop_garbage_collection(channel_layer):
133+
"""
134+
Test loop closure layer flushing and garbage collection
135+
"""
136+
assert len(channel_layer._layers.values()) == 0
137+
async_to_sync(test_send_receive)(channel_layer)
138+
assert len(channel_layer._layers.values()) == 0

tests/test_pubsub_sentinel.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66
from async_generator import async_generator, yield_
77

8+
from asgiref.sync import async_to_sync
89
from channels_redis.pubsub import RedisPubSubChannelLayer
910

1011
SENTINEL_MASTER = "sentinel"
@@ -34,6 +35,18 @@ async def test_send_receive(channel_layer):
3435
assert message["text"] == "Ahoy-hoy!"
3536

3637

38+
@pytest.mark.asyncio
39+
def test_send_receive_sync(channel_layer, event_loop):
40+
_await = event_loop.run_until_complete
41+
channel = _await(channel_layer.new_channel())
42+
async_to_sync(channel_layer.send, force_new_loop=True)(
43+
channel, {"type": "test.message", "text": "Ahoy-hoy!"}
44+
)
45+
message = _await(channel_layer.receive(channel))
46+
assert message["type"] == "test.message"
47+
assert message["text"] == "Ahoy-hoy!"
48+
49+
3750
@pytest.mark.asyncio
3851
async def test_multi_send_receive(channel_layer):
3952
"""
@@ -48,6 +61,19 @@ async def test_multi_send_receive(channel_layer):
4861
assert (await channel_layer.receive(channel))["type"] == "message.3"
4962

5063

64+
@pytest.mark.asyncio
65+
def test_multi_send_receive_sync(channel_layer, event_loop):
66+
_await = event_loop.run_until_complete
67+
channel = _await(channel_layer.new_channel())
68+
send = async_to_sync(channel_layer.send)
69+
send(channel, {"type": "message.1"})
70+
send(channel, {"type": "message.2"})
71+
send(channel, {"type": "message.3"})
72+
assert _await(channel_layer.receive(channel))["type"] == "message.1"
73+
assert _await(channel_layer.receive(channel))["type"] == "message.2"
74+
assert _await(channel_layer.receive(channel))["type"] == "message.3"
75+
76+
5177
@pytest.mark.asyncio
5278
async def test_groups_basic(channel_layer):
5379
"""
@@ -102,3 +128,12 @@ async def test_random_reset__channel_name(channel_layer):
102128
channel_name_2 = await channel_layer.new_channel()
103129

104130
assert channel_name_1 != channel_name_2
131+
132+
133+
def test_multi_event_loop_garbage_collection(channel_layer):
134+
"""
135+
Test loop closure layer flushing and garbage collection
136+
"""
137+
assert len(channel_layer._layers.values()) == 0
138+
async_to_sync(test_send_receive)(channel_layer)
139+
assert len(channel_layer._layers.values()) == 0

tests/test_sentinel.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ async def channel_layer():
6262
Channel layer fixture that flushes automatically.
6363
"""
6464
channel_layer = RedisChannelLayer(
65-
hosts=TEST_HOSTS,
66-
capacity=3,
67-
channel_capacity={"tiny": 1},
65+
hosts=TEST_HOSTS, capacity=3, channel_capacity={"tiny": 1}
6866
)
6967
await yield_(channel_layer)
7068
await channel_layer.flush()

0 commit comments

Comments
 (0)