Skip to content

Commit 1132fcc

Browse files
Kub-ATacu192
andauthored
Add serialize and deserialize method to RedisPubSubChannelLayer (#281)
* Add serializer and deserialize method to RedisPubSubChannelLayer * move serialization methods directly to RedisPubSubChannelLayer * Pass a full reference of RedisPubSubChannelLayer to RedisPubSubLoopLayer instance * Fix linter errors * Fix linter errors (again) Using the version of `black` to match the repo this time. * Not specifying these parameters (since they are the defaults) There are many _other_ parameters which are not specified. It's more clear this way to show we are using _all_ defaults. Co-authored-by: Jakub Stawowy <[email protected]> Co-authored-by: Ryan Henning <[email protected]>
1 parent 781228c commit 1132fcc

File tree

5 files changed

+134
-7
lines changed

5 files changed

+134
-7
lines changed

channels_redis/pubsub.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,29 @@ def __getattr__(self, name):
5454
else:
5555
return getattr(self._get_layer(), name)
5656

57+
def serialize(self, message):
58+
"""
59+
Serializes message to a byte string.
60+
"""
61+
return msgpack.packb(message)
62+
63+
def deserialize(self, message):
64+
"""
65+
Deserializes from a byte string.
66+
"""
67+
return msgpack.unpackb(message)
68+
5769
def _get_layer(self):
5870
loop = get_running_loop()
5971

6072
try:
6173
layer = self._layers[loop]
6274
except KeyError:
63-
layer = RedisPubSubLoopLayer(*self._args, **self._kwargs)
75+
layer = RedisPubSubLoopLayer(
76+
*self._args,
77+
**self._kwargs,
78+
channel_layer=self,
79+
)
6480
self._layers[loop] = layer
6581
_wrap_close(self, loop)
6682

@@ -80,7 +96,13 @@ class RedisPubSubLoopLayer:
8096
"""
8197

8298
def __init__(
83-
self, hosts=None, prefix="asgi", on_disconnect=None, on_reconnect=None, **kwargs
99+
self,
100+
hosts=None,
101+
prefix="asgi",
102+
on_disconnect=None,
103+
on_reconnect=None,
104+
channel_layer=None,
105+
**kwargs,
84106
):
85107
if hosts is None:
86108
hosts = [("localhost", 6379)]
@@ -92,6 +114,7 @@ def __init__(
92114

93115
self.on_disconnect = on_disconnect
94116
self.on_reconnect = on_reconnect
117+
self.channel_layer = channel_layer
95118

96119
# Each consumer gets its own *specific* channel, created with the `new_channel()` method.
97120
# This dict maps `channel_name` to a queue of messages for that channel.
@@ -133,7 +156,7 @@ async def send(self, channel, message):
133156
Send a message onto a (general or specific) channel.
134157
"""
135158
shard = self._get_shard(channel)
136-
await shard.publish(channel, message)
159+
await shard.publish(channel, self.channel_layer.serialize(message))
137160

138161
async def new_channel(self, prefix="specific."):
139162
"""
@@ -180,7 +203,7 @@ async def receive(self, channel):
180203
# We don't re-raise here because we want the CancelledError to be the one re-raised.
181204
raise
182205

183-
return msgpack.unpackb(message)
206+
return self.channel_layer.deserialize(message)
184207

185208
################################################################################
186209
# Groups extension
@@ -225,7 +248,7 @@ async def group_send(self, group, message):
225248
"""
226249
group_channel = self._get_group_channel_name(group)
227250
shard = self._get_shard(group_channel)
228-
await shard.publish(group_channel, message)
251+
await shard.publish(group_channel, self.channel_layer.serialize(message))
229252

230253
################################################################################
231254
# Flush extension
@@ -271,7 +294,7 @@ def __init__(self, host, channel_layer):
271294

272295
async def publish(self, channel, message):
273296
conn = await self._get_pub_conn()
274-
await conn.publish(channel, msgpack.packb(message))
297+
await conn.publish(channel, message)
275298

276299
async def subscribe(self, channel):
277300
if channel not in self._subscribed_to:
@@ -397,7 +420,9 @@ async def _do_receiving(self):
397420
def _notify_consumers(self, mtype):
398421
if mtype is not None:
399422
for channel in self.channel_layer.channels.values():
400-
channel.put_nowait(msgpack.packb({"type": mtype}))
423+
channel.put_nowait(
424+
self.channel_layer.channel_layer.serialize({"type": mtype})
425+
)
401426

402427
async def _ensure_redis(self):
403428
if self._redis is None:

tests/test_core.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,3 +641,26 @@ def test_receive_buffer_respects_capacity():
641641
assert buff.qsize() == capacity
642642
messages = [buff.get_nowait() for _ in range(capacity)]
643643
assert list(range(9900, 10000)) == messages
644+
645+
646+
def test_serialize():
647+
"""
648+
Test default serialization method
649+
"""
650+
message = {"a": True, "b": None, "c": {"d": []}}
651+
channel_layer = RedisChannelLayer()
652+
serialized = channel_layer.serialize(message)
653+
assert isinstance(serialized, bytes)
654+
assert serialized[12:] == b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
655+
656+
657+
def test_deserialize():
658+
"""
659+
Test default deserialization method
660+
"""
661+
message = b"Q\x0c\xbb?Q\xbc\xe3|D\xfd9\x00\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
662+
channel_layer = RedisChannelLayer()
663+
deserialized = channel_layer.deserialize(message)
664+
665+
assert isinstance(deserialized, dict)
666+
assert deserialized == {"a": True, "b": None, "c": {"d": []}}

tests/test_pubsub.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,34 @@ async def test_random_reset__channel_name(channel_layer):
129129
assert channel_name_1 != channel_name_2
130130

131131

132+
@pytest.mark.asyncio
133+
async def test_loop_instance_channel_layer_reference(channel_layer):
134+
redis_pub_sub_loop_layer = channel_layer._get_layer()
135+
136+
assert redis_pub_sub_loop_layer.channel_layer == channel_layer
137+
138+
139+
def test_serialize(channel_layer):
140+
"""
141+
Test default serialization method
142+
"""
143+
message = {"a": True, "b": None, "c": {"d": []}}
144+
serialized = channel_layer.serialize(message)
145+
assert isinstance(serialized, bytes)
146+
assert serialized == b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
147+
148+
149+
def test_deserialize(channel_layer):
150+
"""
151+
Test default deserialization method
152+
"""
153+
message = b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
154+
deserialized = channel_layer.deserialize(message)
155+
156+
assert isinstance(deserialized, dict)
157+
assert deserialized == {"a": True, "b": None, "c": {"d": []}}
158+
159+
132160
def test_multi_event_loop_garbage_collection(channel_layer):
133161
"""
134162
Test loop closure layer flushing and garbage collection

tests/test_pubsub_sentinel.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,34 @@ async def test_random_reset__channel_name(channel_layer):
130130
assert channel_name_1 != channel_name_2
131131

132132

133+
@pytest.mark.asyncio
134+
async def test_loop_instance_channel_layer_reference(channel_layer):
135+
redis_pub_sub_loop_layer = channel_layer._get_layer()
136+
137+
assert redis_pub_sub_loop_layer.channel_layer == channel_layer
138+
139+
140+
def test_serialize(channel_layer):
141+
"""
142+
Test default serialization method
143+
"""
144+
message = {"a": True, "b": None, "c": {"d": []}}
145+
serialized = channel_layer.serialize(message)
146+
assert isinstance(serialized, bytes)
147+
assert serialized == b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
148+
149+
150+
def test_deserialize(channel_layer):
151+
"""
152+
Test default deserialization method
153+
"""
154+
message = b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
155+
deserialized = channel_layer.deserialize(message)
156+
157+
assert isinstance(deserialized, dict)
158+
assert deserialized == {"a": True, "b": None, "c": {"d": []}}
159+
160+
133161
def test_multi_event_loop_garbage_collection(channel_layer):
134162
"""
135163
Test loop closure layer flushing and garbage collection

tests/test_sentinel.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,3 +644,26 @@ def test_receive_buffer_respects_capacity():
644644
assert buff.qsize() == capacity
645645
messages = [buff.get_nowait() for _ in range(capacity)]
646646
assert list(range(9900, 10000)) == messages
647+
648+
649+
def test_serialize():
650+
"""
651+
Test default serialization method
652+
"""
653+
message = {"a": True, "b": None, "c": {"d": []}}
654+
channel_layer = RedisChannelLayer()
655+
serialized = channel_layer.serialize(message)
656+
assert isinstance(serialized, bytes)
657+
assert serialized[12:] == b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
658+
659+
660+
def test_deserialize():
661+
"""
662+
Test default deserialization method
663+
"""
664+
message = b"Q\x0c\xbb?Q\xbc\xe3|D\xfd9\x00\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
665+
channel_layer = RedisChannelLayer()
666+
deserialized = channel_layer.deserialize(message)
667+
668+
assert isinstance(deserialized, dict)
669+
assert deserialized == {"a": True, "b": None, "c": {"d": []}}

0 commit comments

Comments
 (0)