Skip to content

Commit 11c18d9

Browse files
authored
Send reply messages from the main process (#88)
* Send reply messages directly
* Always force autocommit
* Clean up SQL
* Log cleanup
* Remove tasks file no longer needed
* Add type hinting and rename from review
1 parent 958ff20 commit 11c18d9

File tree

8 files changed

+64
-51
lines changed

8 files changed

+64
-51
lines changed

dispatcher/brokers/pg_notify.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,21 @@
1818

1919
async def acreate_connection(**config) -> psycopg.AsyncConnection:
2020
"Create a new asyncio connection"
21-
return await psycopg.AsyncConnection.connect(**config)
21+
connection = await psycopg.AsyncConnection.connect(**config)
22+
if not connection.autocommit:
23+
await connection.set_autocommit(True)
24+
return connection
2225

2326

2427
def create_connection(**config) -> psycopg.Connection:
25-
return psycopg.Connection.connect(**config)
28+
connection = psycopg.Connection.connect(**config)
29+
if not connection.autocommit:
30+
connection.set_autocommit(True)
31+
return connection
2632

2733

2834
class Broker:
35+
NOTIFY_QUERY_TEMPLATE = 'SELECT pg_notify(%s, %s);'
2936

3037
def __init__(
3138
self,
@@ -63,13 +70,18 @@ def __init__(
6370

6471
if config:
6572
self._config: dict = config.copy()
66-
self._config['autocommit'] = True
6773
else:
6874
self._config = {}
6975

7076
self.channels = channels
7177
self.default_publish_channel = default_publish_channel
7278

79+
# If we are in the notification loop (receiving messages),
80+
# then we have to break out before sending messages
81+
# These variables track things so that we can exit, send, and re-enter
82+
self.notify_loop_active: bool = False
83+
self.notify_queue: list = []
84+
7385
def get_publish_channel(self, channel: Optional[str] = None) -> str:
7486
"Handle default for the publishing channel for calls to publish_message, shared sync and async"
7587
if channel is not None:
@@ -112,23 +124,35 @@ async def aprocess_notify(self, connected_callback: Optional[Callable] = None) -
112124

113125
while True:
114126
logger.debug('Starting listening for pg_notify notifications')
127+
self.notify_loop_active = True
115128
async for notify in connection.notifies():
116129
yield notify.channel, notify.payload
130+
if self.notify_queue:
131+
break
132+
self.notify_loop_active = False
133+
for reply_to, reply_message in self.notify_queue:
134+
await self.apublish_message_from_cursor(cur, channel=reply_to, message=reply_message)
135+
self.notify_queue = []
136+
137+
async def apublish_message_from_cursor(self, cursor: psycopg.AsyncCursor, channel: Optional[str] = None, message: str = '') -> None:
138+
"""The inner logic of async message publishing where we already have a cursor"""
139+
await cursor.execute(self.NOTIFY_QUERY_TEMPLATE, (channel, message))
117140

118141
async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: # public
119142
"""asyncio way to publish a message, used to send control in control-and-reply
120143
121144
Not strictly necessary for the service itself if it sends replies in the workers,
122145
but this may change in the future.
123146
"""
147+
if self.notify_loop_active:
148+
self.notify_queue.append((channel, message))
149+
return
150+
124151
connection = await self.aget_connection()
125152
channel = self.get_publish_channel(channel)
126153

127154
async with connection.cursor() as cur:
128-
if not message:
129-
await cur.execute(f'NOTIFY {channel};')
130-
else:
131-
await cur.execute(f"NOTIFY {channel}, '{message}';")
155+
await self.apublish_message_from_cursor(cur, channel=channel, message=message)
132156

133157
logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}')
134158

@@ -159,10 +183,7 @@ def publish_message(self, channel: Optional[str] = None, message: str = '') -> N
159183
channel = self.get_publish_channel(channel)
160184

161185
with connection.cursor() as cur:
162-
if message:
163-
cur.execute('SELECT pg_notify(%s, %s);', (channel, message))
164-
else:
165-
cur.execute(f'NOTIFY {channel};')
186+
cur.execute(self.NOTIFY_QUERY_TEMPLATE, (channel, message))
166187

167188
logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}')
168189

@@ -189,7 +210,6 @@ def connection_saver(**config) -> psycopg.Connection:
189210
Dispatcher does not manage connections, so this is a simulation of that.
190211
"""
191212
if connection_save._connection is None:
192-
config['autocommit'] = True
193213
connection_save._connection = create_connection(**config)
194214
return connection_save._connection
195215

@@ -202,6 +222,5 @@ async def async_connection_saver(**config) -> psycopg.AsyncConnection:
202222
Dispatcher does not manage connections, so this is a simulation of that.
203223
"""
204224
if connection_save._async_connection is None:
205-
config['autocommit'] = True
206225
connection_save._async_connection = await acreate_connection(**config)
207226
return connection_save._async_connection

dispatcher/control.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,29 @@ class ControlCallbacks:
2222
it exists to interact with producers, using variables relevant to the particular
2323
control message being sent"""
2424

25-
def __init__(self, queuename, send_data, expected_replies):
25+
def __init__(self, queuename, send_data, expected_replies) -> None:
2626
self.queuename = queuename
2727
self.send_data = send_data
2828
self.expected_replies = expected_replies
2929

30-
self.received_replies = []
30+
# received_replies only tracks the reply message, not the channel name
31+
# because they come via a temporary reply_to channel and that is not user-facing
32+
self.received_replies: list[str] = []
3133
self.events = ControlEvents()
3234
self.shutting_down = False
3335

34-
async def process_message(self, payload, producer=None, channel=None):
36+
async def process_message(self, payload, producer=None, channel=None) -> tuple[Optional[str], Optional[str]]:
3537
self.received_replies.append(payload)
3638
if self.expected_replies and (len(self.received_replies) >= self.expected_replies):
3739
self.events.exit_event.set()
40+
return (None, None)
3841

3942
async def connected_callback(self, producer) -> None:
4043
payload = json.dumps(self.send_data)
4144
await producer.notify(channel=self.queuename, message=payload)
4245
logger.info('Sent control message, expecting replies soon')
4346

44-
def fatal_error_callback(self, *args):
47+
def fatal_error_callback(self, *args) -> None:
4548
if self.shutting_down:
4649
return
4750

dispatcher/producers/brokered.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ async def produce_forever(self, dispatcher) -> None:
3636
self.dispatcher = dispatcher
3737
async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback):
3838
self.produced_count += 1
39-
await dispatcher.process_message(payload, producer=self, channel=channel)
39+
reply_to, reply_payload = await dispatcher.process_message(payload, producer=self, channel=channel)
40+
if reply_to:
41+
await self.notify(channel=reply_to, message=reply_payload)
4042

4143
async def notify(self, channel: Optional[str] = None, message: str = '') -> None:
4244
await self.broker.apublish_message(channel=channel, message=message)

dispatcher/service/main.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,16 @@ def create_delayed_task(self, message: dict) -> None:
170170
capsule.task = new_task
171171
self.delayed_messages.append(capsule)
172172

173-
async def process_message(self, payload: dict, producer: Optional[BaseProducer] = None, channel: Optional[str] = None) -> None:
174-
# Convert payload from client into python dict
173+
async def process_message(
174+
self, payload: dict, producer: Optional[BaseProducer] = None, channel: Optional[str] = None
175+
) -> tuple[Optional[str], Optional[str]]:
176+
"""Called by producers to trigger a new task
177+
178+
Convert payload from producer into python dict
179+
Process uuid default
180+
Delay tasks when applicable
181+
Send to next layer of internal processing
182+
"""
175183
# TODO: more structured validation of the incoming payload from publishers
176184
if isinstance(payload, str):
177185
try:
@@ -182,7 +190,7 @@ async def process_message(self, payload: dict, producer: Optional[BaseProducer]
182190
message = payload
183191
else:
184192
logger.error(f'Received unprocessable type {type(payload)}')
185-
return
193+
return (None, None)
186194

187195
# A client may provide a task uuid (hope they do it correctly), if not add it
188196
if 'uuid' not in message:
@@ -195,28 +203,24 @@ async def process_message(self, payload: dict, producer: Optional[BaseProducer]
195203
# NOTE: control messages with reply should never be delayed, document this for users
196204
self.create_delayed_task(message)
197205
else:
198-
await self.process_message_internal(message, producer=producer)
206+
return await self.process_message_internal(message, producer=producer)
207+
return (None, None)
199208

200-
async def process_message_internal(self, message: dict, producer=None) -> None:
209+
async def process_message_internal(self, message: dict, producer=None) -> tuple[Optional[str], Optional[str]]:
210+
"""Route message based on needed action - delay for later, return reply, or dispatch to worker"""
201211
if 'control' in message:
202212
method = getattr(self.ctl_tasks, message['control'])
203213
control_data = message.get('control_data', {})
204214
returned = await method(self, **control_data)
205215
if 'reply_to' in message:
206-
logger.info(f"Control action {message['control']} returned {returned}, sending via worker")
216+
logger.info(f"Control action {message['control']} returned {returned}, sending back reply")
207217
self.control_count += 1
208-
await self.pool.dispatch_task(
209-
{
210-
'task': 'dispatcher.service.tasks.reply_to_control',
211-
'args': [message['reply_to'], json.dumps(returned)],
212-
'uuid': f'control-{self.control_count}',
213-
'control': 'reply', # for record keeping
214-
}
215-
)
218+
return (message['reply_to'], json.dumps(returned))
216219
else:
217220
logger.info(f"Control action {message['control']} returned {returned}, done")
218221
else:
219222
await self.pool.dispatch_task(message)
223+
return (None, None)
220224

221225
async def start_working(self) -> None:
222226
logger.debug('Filling the worker pool')

dispatcher/service/pool.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def __init__(self, max_workers: int, process_manager: ProcessManager, settings:
107107
self.start_worker_task: Optional[Task] = None
108108
self.shutting_down = False
109109
self.finished_count: int = 0
110-
self.control_count: int = 0
111110
self.canceled_count: int = 0
112111
self.discard_count: int = 0
113112
self.shutdown_timeout = 3
@@ -369,8 +368,6 @@ async def process_finished(self, worker, message) -> None:
369368
async with self.management_lock:
370369
if worker.is_active_cancel and result == '<cancel>':
371370
self.canceled_count += 1
372-
elif 'control' in worker.current_task:
373-
self.control_count += 1
374371
else:
375372
self.finished_count += 1
376373
worker.mark_finished_task()

dispatcher/service/tasks.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

tests/integration/test_main.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ async def test_cancel_task(apg_dispatcher, pg_message, pg_control):
100100
await asyncio.wait_for(clearing_task, timeout=3)
101101

102102
pool = apg_dispatcher.pool
103-
assert [pool.finished_count, pool.canceled_count, pool.control_count] == [0, 1, 1], 'cts: [finished, canceled, control]'
103+
assert [pool.finished_count, pool.canceled_count, apg_dispatcher.control_count] == [0, 1, 1], 'cts: [finished, canceled, control]'
104104

105105

106106
@pytest.mark.asyncio
@@ -117,13 +117,10 @@ async def test_message_with_delay(apg_dispatcher, pg_message, pg_control):
117117
assert running_job['uuid'] == 'delay_task'
118118
await asyncio.wait_for(apg_dispatcher.pool.events.work_cleared.wait(), timeout=3)
119119
pool = apg_dispatcher.pool
120-
assert [pool.finished_count, pool.canceled_count, pool.control_count] == [0, 0, 1], 'cts: [finished, canceled, control]'
121-
# Completing the reply itself will be a work_cleared event, so we have to clear the event
122-
apg_dispatcher.pool.events.work_cleared.clear()
123120

124121
# Wait for task to finish, assertions after completion
125122
await asyncio.wait_for(apg_dispatcher.pool.events.work_cleared.wait(), timeout=3)
126-
assert [pool.finished_count, pool.canceled_count, pool.control_count] == [1, 0, 1], 'cts: [finished, canceled, control]'
123+
assert [pool.finished_count, pool.canceled_count, apg_dispatcher.control_count] == [1, 0, 1], 'cts: [finished, canceled, control]'
127124

128125

129126
@pytest.mark.asyncio

tools/write_messages.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import os
55
import sys
66

7-
from dispatcher.factories import get_publisher_from_settings
8-
from dispatcher.control import Control
7+
from dispatcher.factories import get_publisher_from_settings, get_control_from_settings
98
from dispatcher.utils import MODULE_METHOD_DELIMITER
109
from dispatcher.config import setup
1110

@@ -49,7 +48,7 @@ def main():
4948
print('performing a task cancel')
5049
# submit a task we will "find" two different ways
5150
broker.publish_message(message=json.dumps({'task': 'lambda: __import__("time").sleep(3.1415)', 'uuid': 'foobar'}))
52-
ctl = Control('test_channel')
51+
ctl = get_control_from_settings()
5352
canceled_jobs = ctl.control_with_reply('cancel', data={'uuid': 'foobar'})
5453
print(json.dumps(canceled_jobs, indent=2))
5554

0 commit comments

Comments (0)