Skip to content

Commit b210097

Browse files
authored
Add method for synchronous listening (#89)
* Sync publisher, using new psycopg features * Update more type hints
1 parent 11c18d9 commit b210097

File tree

5 files changed

+129
-14
lines changed

5 files changed

+129
-14
lines changed

dispatcher/brokers/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1-
from typing import AsyncGenerator, Optional, Protocol
1+
from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Protocol
22

33

44
class BaseBroker(Protocol):
5-
async def aprocess_notify(self, connected_callback=None) -> AsyncGenerator[tuple[str, str], None]:
5+
async def aprocess_notify(
6+
self, connected_callback: Optional[Optional[Callable[[], Coroutine[Any, Any, None]]]] = None
7+
) -> AsyncGenerator[tuple[str, str], None]:
68
yield ('', '') # yield affects CPython type https://github.com/python/mypy/pull/18422
79

810
async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: ...
911

1012
async def aclose(self) -> None: ...
1113

14+
def process_notify(self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1) -> Iterator[tuple[str, str]]: ...
15+
1216
def publish_message(self, channel=None, message=None): ...
1317

1418
def close(self): ...

dispatcher/brokers/pg_notify.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import AsyncGenerator, Callable, Optional, Union
2+
from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union
33

44
import psycopg
55

@@ -112,11 +112,21 @@ async def aget_connection(self) -> psycopg.AsyncConnection:
112112
return connection # slightly weird due to MyPY
113113
return self._async_connection
114114

115-
async def aprocess_notify(self, connected_callback: Optional[Callable] = None) -> AsyncGenerator[tuple[str, str], None]: # public
115+
def get_listen_query(self, channel: str) -> psycopg.sql.Composed:
116+
"""Returns SQL command for listening on pg_notify channel
117+
118+
This uses the psycopg utilities which ensure correct escaping so SQL injection is not possible.
119+
Return value is a valid argument for cursor.execute()
120+
"""
121+
return psycopg.sql.SQL("LISTEN {};").format(psycopg.sql.Identifier(channel))
122+
123+
async def aprocess_notify(
124+
self, connected_callback: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
125+
) -> AsyncGenerator[tuple[str, str], None]: # public
116126
connection = await self.aget_connection()
117127
async with connection.cursor() as cur:
118128
for channel in self.channels:
119-
await cur.execute(f"LISTEN {channel};")
129+
await cur.execute(self.get_listen_query(channel))
120130
logger.info(f"Set up pg_notify listening on channel '{channel}'")
121131

122132
if connected_callback:
@@ -178,6 +188,27 @@ def get_connection(self) -> psycopg.Connection:
178188
return connection
179189
return self._sync_connection
180190

191+
def process_notify(self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1) -> Iterator[tuple[str, str]]:
192+
"""Blocking method that listens for messages on subscribed pg_notify channels until timeout
193+
194+
This has two different exit conditions:
195+
- received max_messages number of messages or more
196+
- taken longer than the specified timeout condition
197+
"""
198+
connection = self.get_connection()
199+
200+
with connection.cursor() as cur:
201+
for channel in self.channels:
202+
cur.execute(self.get_listen_query(channel))
203+
logger.info(f"Set up pg_notify listening on channel '{channel}'")
204+
205+
if connected_callback:
206+
connected_callback()
207+
208+
logger.debug('Starting listening for pg_notify notifications')
209+
for notify in connection.notifies(timeout=timeout, stop_after=max_messages):
210+
yield (notify.channel, notify.payload)
211+
181212
def publish_message(self, channel: Optional[str] = None, message: str = '') -> None:
182213
connection = self.get_connection()
183214
channel = self.get_publish_channel(channel)

dispatcher/control.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,22 +116,27 @@ async def acontrol(self, command, data=None):
116116
producer = self.make_producer(Control.generate_reply_queue_name()) # reply queue not used
117117
await control_callbacks.connected_callback(producer)
118118

119-
def control_with_reply(self, command, expected_replies=1, timeout=1, data=None):
119+
def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]:
120120
logger.info('control-and-reply {} to {}'.format(command, self.queuename))
121121
start = time.time()
122122
reply_queue = Control.generate_reply_queue_name()
123123
send_data = {'control': command, 'reply_to': reply_queue}
124124
if data:
125125
send_data['control_data'] = data
126126

127-
producer = self.make_producer(reply_queue)
127+
broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
128128

129-
loop = asyncio.new_event_loop()
130-
try:
131-
replies = loop.run_until_complete(self.acontrol_with_reply_internal(producer, send_data, expected_replies, timeout))
132-
finally:
133-
loop.close()
134-
loop = None
129+
def connected_callback():
130+
payload = json.dumps(send_data)
131+
if self.queuename:
132+
broker.publish_message(channel=self.queuename, message=payload)
133+
else:
134+
broker.publish_message(message=payload)
135+
136+
replies = []
137+
for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout):
138+
reply_data = json.loads(payload)
139+
replies.append(reply_data)
135140

136141
logger.info(f'control-and-reply message returned in {time.time() - start} seconds')
137142
return replies

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ profile = "black"
5353
line_length = 160
5454

5555
[project.optional-dependencies]
56-
pg_notify = ["psycopg[binary]"]
56+
pg_notify = ["psycopg[binary]>=3.2.0"]
5757

5858
[tool.pytest.ini_options]
5959
log_cli_level = "DEBUG"

tests/integration/brokers/test_pg_notify.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import time
2+
import multiprocessing
3+
14
import pytest
25

36
from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection
@@ -16,6 +19,78 @@ def test_sync_connection_from_config_reuse(conn_config):
1619
assert conn is not create_connection(**conn_config)
1720

1821

22+
def test_sync_listen_timeout(conn_config):
23+
broker = Broker(config=conn_config)
24+
timeout_value = 0.05
25+
start = time.monotonic()
26+
assert list(broker.process_notify(timeout=timeout_value)) == []
27+
delta = time.monotonic() - start
28+
assert delta > timeout_value
29+
30+
31+
def _send_message(conn_config):
32+
broker = Broker(config=conn_config)
33+
if broker._sync_connection:
34+
broker._sync_connection.close()
35+
36+
broker.publish_message('test_sync_listen_receive', 'test_message')
37+
38+
39+
def test_sync_listen_receive(conn_config):
40+
messages = []
41+
with multiprocessing.Pool(processes=1) as pool:
42+
def send_from_subprocess():
43+
pool.apply(_send_message, args=(conn_config,))
44+
45+
broker = Broker(config=conn_config, channels=('test_sync_listen_receive',))
46+
timeout_value = 2.0
47+
start = time.monotonic()
48+
for channel, message in broker.process_notify(connected_callback=send_from_subprocess, timeout=timeout_value):
49+
messages.append(message)
50+
delta = time.monotonic() - start
51+
52+
assert messages == ['test_message']
53+
assert delta < timeout_value
54+
55+
56+
def test_sync_listen_receive_multi_message(conn_config):
57+
"""Tests that the expected messages exit condition works, we get 3 messages, not just 1"""
58+
messages = []
59+
with multiprocessing.Pool(processes=1) as pool:
60+
def send_from_subprocess():
61+
pool.apply(_send_message, args=(conn_config,))
62+
pool.apply(_send_message, args=(conn_config,))
63+
pool.apply(_send_message, args=(conn_config,))
64+
65+
broker = Broker(config=conn_config, channels=('test_sync_listen_receive',))
66+
timeout_value = 2.0
67+
start = time.monotonic()
68+
for channel, message in broker.process_notify(connected_callback=send_from_subprocess, max_messages=3):
69+
messages.append(message)
70+
delta = time.monotonic() - start
71+
72+
assert messages == ['test_message' for i in range(3)]
73+
assert delta < timeout_value
74+
75+
76+
def test_get_message_then_timeout(conn_config):
77+
"""Tests that the expected messages exit condition works, we get 3 messages, not just 1"""
78+
messages = []
79+
with multiprocessing.Pool(processes=1) as pool:
80+
def send_from_subprocess():
81+
pool.apply(_send_message, args=(conn_config,))
82+
83+
broker = Broker(config=conn_config, channels=('test_sync_listen_receive',))
84+
timeout_value = 0.5
85+
start = time.monotonic()
86+
for channel, message in broker.process_notify(connected_callback=send_from_subprocess, timeout=timeout_value, max_messages=2):
87+
messages.append(message)
88+
delta = time.monotonic() - start
89+
90+
assert messages == ['test_message']
91+
assert delta > timeout_value # goes until timeout
92+
93+
1994
@pytest.mark.asyncio
2095
async def test_async_connection_from_config_reuse(conn_config):
2196
broker = Broker(config=conn_config)

0 commit comments

Comments
 (0)