Skip to content
This repository was archived by the owner on Feb 21, 2023. It is now read-only.

Commit 659a14d

Browse files
authored
Merge pull request #1284 from stephanm/async_callback
Support for async callbacks
2 parents 8a38f98 + 359a741 commit 659a14d

File tree

4 files changed

+32
-5
lines changed

4 files changed

+32
-5
lines changed

CHANGES/1284.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Using coroutines as callbacks is now possible.

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ Sean Stewart <seandstewart>
6363
Sergey Miletskiy
6464
SeungHyun Hwang
6565
Stanislau Arkhipenka
66+
Stephan Meier
6667
Taku Fukada
6768
Taras Voinarovskyi
6869
Thanos Lefteris

aioredis/client.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview)
7575
AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview)
7676
AnyChannelT = ChannelT
77-
PubSubHandler = Callable[[Dict[str, str]], None]
77+
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
7878

7979
SYM_EMPTY = b""
8080
EMPTY_RESPONSE = "EMPTY_RESPONSE"
@@ -4176,7 +4176,7 @@ def unsubscribe(self, *args) -> Awaitable:
41764176
async def listen(self) -> AsyncIterator:
41774177
"""Listen for messages on channels this client has been subscribed to"""
41784178
while self.subscribed:
4179-
response = self.handle_message(await self.parse_response(block=True))
4179+
response = await self.handle_message(await self.parse_response(block=True))
41804180
if response is not None:
41814181
yield response
41824182

@@ -4192,7 +4192,7 @@ async def get_message(
41924192
"""
41934193
response = await self.parse_response(block=False, timeout=timeout)
41944194
if response:
4195-
return self.handle_message(response, ignore_subscribe_messages)
4195+
return await self.handle_message(response, ignore_subscribe_messages)
41964196
return None
41974197

41984198
def ping(self, message=None) -> Awaitable:
@@ -4202,7 +4202,7 @@ def ping(self, message=None) -> Awaitable:
42024202
message = "" if message is None else message
42034203
return self.execute_command("PING", message)
42044204

4205-
def handle_message(self, response, ignore_subscribe_messages=False):
4205+
async def handle_message(self, response, ignore_subscribe_messages=False):
42064206
"""
42074207
Parses a pub/sub message. If the channel or pattern was subscribed to
42084208
with a message handler, the handler is invoked instead of a parsed
@@ -4251,7 +4251,10 @@ def handle_message(self, response, ignore_subscribe_messages=False):
42514251
else:
42524252
handler = self.channels.get(message["channel"], None)
42534253
if handler:
4254-
handler(message)
4254+
if inspect.iscoroutinefunction(handler):
4255+
await handler(message)
4256+
else:
4257+
handler(message)
42554258
return None
42564259
elif message_type != "pong":
42574260
# this is a subscribe/unsubscribe message. ignore if we don't

tests/test_pubsub.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ def setup_method(self, method):
270270
def message_handler(self, message):
271271
self.message = message
272272

273+
async def async_message_handler(self, message):
274+
self.async_message = message
275+
273276
async def test_published_message_to_channel(self, r):
274277
p = r.pubsub()
275278
await p.subscribe("foo")
@@ -311,6 +314,25 @@ async def test_channel_message_handler(self, r):
311314
assert await wait_for_message(p) is None
312315
assert self.message == make_message("message", "foo", "test message")
313316

317+
async def test_channel_async_message_handler(self, r):
318+
p = r.pubsub(ignore_subscribe_messages=True)
319+
await p.subscribe(foo=self.async_message_handler)
320+
assert await wait_for_message(p) is None
321+
assert await r.publish("foo", "test message") == 1
322+
assert await wait_for_message(p) is None
323+
assert self.async_message == make_message("message", "foo", "test message")
324+
325+
async def test_channel_sync_async_message_handler(self, r):
326+
p = r.pubsub(ignore_subscribe_messages=True)
327+
await p.subscribe(foo=self.message_handler)
328+
await p.subscribe(bar=self.async_message_handler)
329+
assert await wait_for_message(p) is None
330+
assert await r.publish("foo", "test message") == 1
331+
assert await r.publish("bar", "test message 2") == 1
332+
assert await wait_for_message(p) is None
333+
assert self.message == make_message("message", "foo", "test message")
334+
assert self.async_message == make_message("message", "bar", "test message 2")
335+
314336
async def test_pattern_message_handler(self, r):
315337
p = r.pubsub(ignore_subscribe_messages=True)
316338
await p.psubscribe(**{"f*": self.message_handler})

0 commit comments

Comments
 (0)