Commit 752c0c7

Separate process management out of pool (#70)
* Introduce separate process manager
* Separate process management out of pool
* Add some simple tests for interacting with subprocesses
1 parent 89993f9 commit 752c0c7

File tree: 3 files changed, +97 -19 lines changed


dispatcher/pool.py

Lines changed: 12 additions & 19 deletions
```diff
@@ -1,26 +1,19 @@
 import asyncio
 import logging
-import multiprocessing
-import os
-import signal
 import time
 from asyncio import Task
 from typing import Iterator, Optional
 
+from dispatcher.process import ProcessManager, ProcessProxy
 from dispatcher.utils import DuplicateBehavior, MessageAction
-from dispatcher.worker.task import work_loop
 
 logger = logging.getLogger(__name__)
 
 
 class PoolWorker:
-    def __init__(self, worker_id: int, finished_queue: multiprocessing.Queue):
+    def __init__(self, worker_id: int, process: ProcessProxy) -> None:
         self.worker_id = worker_id
-        # TODO: rename message_queue to call_queue, because this is what cpython ProcessPoolExecutor calls them
-        self.message_queue: multiprocessing.Queue = multiprocessing.Queue()
-        self.process = multiprocessing.Process(target=work_loop, args=(self.worker_id, self.message_queue, finished_queue))
-
-        # Info specific to the current task being ran
+        self.process = process
         self.current_task: Optional[dict] = None
         self.started_at: Optional[int] = None
         self.is_active_cancel: bool = False
@@ -38,15 +31,15 @@ async def start(self) -> None:
 
     async def start_task(self, message: dict) -> None:
         self.current_task = message  # NOTE: this marks this worker as busy
-        self.message_queue.put(message)
+        self.process.message_queue.put(message)
         self.started_at = time.monotonic_ns()
 
     async def join(self) -> None:
         logger.debug(f'Joining worker {self.worker_id} pid={self.process.pid} subprocess')
         self.process.join()
 
     async def stop(self) -> None:
-        self.message_queue.put("stop")
+        self.process.message_queue.put("stop")
         if self.current_task:
             uuid = self.current_task.get('uuid', '<unknown>')
             logger.warning(f'Worker {self.worker_id} is currently running task (uuid={uuid}), canceling for shutdown')
@@ -105,7 +98,7 @@ def __init__(self, num_workers: int, fd_lock: Optional[asyncio.Lock] = None):
         self.num_workers = num_workers
         self.workers: dict[int, PoolWorker] = {}
         self.next_worker_id = 0
-        self.finished_queue: multiprocessing.Queue = multiprocessing.Queue()
+        self.process_manager = ProcessManager()
         self.queued_messages: list[dict] = []  # TODO: use deque, invent new kinds of logging anxiety
         self.read_results_task: Optional[Task] = None
         self.start_worker_task: Optional[Task] = None
@@ -193,7 +186,8 @@ async def manage_timeout(self) -> None:
             self.events.timeout_event.clear()
 
     async def up(self) -> None:
-        worker = PoolWorker(worker_id=self.next_worker_id, finished_queue=self.finished_queue)
+        process = self.process_manager.create_process((self.next_worker_id,))
+        worker = PoolWorker(self.next_worker_id, process)
         self.workers[self.next_worker_id] = worker
         self.next_worker_id += 1
 
@@ -205,7 +199,7 @@ async def force_shutdown(self) -> None:
         for worker in self.workers.values():
             if worker.process.pid and worker.process.is_alive():
                 logger.warning(f'Force killing worker {worker.worker_id} pid={worker.process.pid}')
-                os.kill(worker.process.pid, signal.SIGKILL)
+                worker.process.kill()
 
         if self.read_results_task:
             self.read_results_task.cancel()
@@ -220,7 +214,7 @@ async def shutdown(self) -> None:
         self.events.management_event.set()
         self.events.timeout_event.set()
         await self.stop_workers()
-        self.finished_queue.put('stop')
+        self.process_manager.finished_queue.put('stop')
 
         if self.read_results_task:
             logger.info('Waiting for the finished watcher to return')
@@ -382,10 +376,9 @@ async def process_finished(self, worker, message) -> None:
 
     async def read_results_forever(self) -> None:
         """Perpetual task that continuously waits for task completions."""
-        loop = asyncio.get_event_loop()
         while True:
             # Wait for a result from the finished queue
-            message = await loop.run_in_executor(None, self.finished_queue.get)
+            message = await self.process_manager.read_finished()
 
             if message == 'stop':
                 if self.shutting_down:
@@ -396,7 +389,7 @@ async def read_results_forever(self) -> None:
                 logger.error('Results queue got stop message even through not shutting down')
                 continue
 
-            worker_id = message["worker"]
+            worker_id = int(message["worker"])
             event = message["event"]
             worker = self.workers[worker_id]
 
```
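A side effect of the refactor shows up in force_shutdown(): instead of reaching for os.kill(pid, signal.SIGKILL), the pool now calls worker.process.kill(), which ProcessProxy forwards to multiprocessing.Process.kill(). The two are equivalent on POSIX; a stdlib-only sketch (not code from this commit) to illustrate:

```python
import multiprocessing
import time


def looper() -> None:
    # Placeholder child that stays alive until killed.
    while True:
        time.sleep(1)


if __name__ == '__main__':
    p = multiprocessing.Process(target=looper)
    p.start()
    # The pool used to do this by hand: os.kill(p.pid, signal.SIGKILL).
    # Process.kill() is the stdlib equivalent; on POSIX it sends SIGKILL.
    p.kill()
    p.join()  # reap the child so it does not linger as a zombie
    print(p.exitcode)  # -9 on POSIX: terminated by signal 9 (SIGKILL)
```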
dispatcher/process.py

Lines changed: 54 additions & 0 deletions
```python
import asyncio
import multiprocessing
from typing import Callable, Iterable, Optional, Union

from dispatcher.worker.task import work_loop


class ProcessProxy:
    def __init__(self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop) -> None:
        self.message_queue: multiprocessing.Queue = multiprocessing.Queue()
        self._process = multiprocessing.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))

    def start(self) -> None:
        self._process.start()

    def join(self, timeout: Optional[int] = None) -> None:
        if timeout:
            self._process.join(timeout=timeout)
        else:
            self._process.join()

    @property
    def pid(self) -> Optional[int]:
        return self._process.pid

    def exitcode(self) -> Optional[int]:
        return self._process.exitcode

    def is_alive(self) -> bool:
        return self._process.is_alive()

    def kill(self) -> None:
        self._process.kill()

    def terminate(self) -> None:
        self._process.terminate()


class ProcessManager:
    def __init__(self) -> None:
        self.finished_queue: multiprocessing.Queue = multiprocessing.Queue()
        self._loop = None

    def get_event_loop(self):
        if not self._loop:
            self._loop = asyncio.get_event_loop()
        return self._loop

    def create_process(self, args: Iterable[int | str], **kwargs) -> ProcessProxy:
        return ProcessProxy(args, self.finished_queue, **kwargs)

    async def read_finished(self) -> dict[str, Union[str, int]]:
        message = await self.get_event_loop().run_in_executor(None, self.finished_queue.get)
        return message
```
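To see how ProcessProxy and ProcessManager compose, here is a minimal round trip mirroring what the pool now does in up() and read_results_forever(). The echo_loop target and its message payload are assumptions for illustration; only create_process(), message_queue, read_finished(), and the 'worker'/'event' keys come from this commit:

```python
import asyncio

from dispatcher.process import ProcessManager


def echo_loop(worker_id, message_queue, finished_queue):
    # Stand-in for dispatcher.worker.task.work_loop: positional args passed to
    # create_process() come first, then the two queues ProcessProxy appends.
    message = message_queue.get()
    finished_queue.put({'worker': worker_id, 'event': 'done', 'received': message})


async def main():
    manager = ProcessManager()
    process = manager.create_process((0,), target=echo_loop)
    process.start()

    process.message_queue.put({'uuid': 'abc'})
    result = await manager.read_finished()  # blocking Queue.get offloaded to an executor
    print(result['worker'], result['event'])  # 0 done
    process.join()


if __name__ == '__main__':
    asyncio.run(main())
```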

tests/unit/service/test_process.py

Lines changed: 31 additions & 0 deletions
```python
from multiprocessing import Queue

from dispatcher.process import ProcessManager, ProcessProxy


def test_pass_messages_to_worker():
    def work_loop(a, b, c, in_q, out_q):
        has_read = in_q.get()
        out_q.put(f'done {a} {b} {c} {has_read}')

    finished_q = Queue()
    process = ProcessProxy((1, 2, 3), finished_q, target=work_loop)
    process.start()

    process.message_queue.put('start')
    msg = finished_q.get()
    assert msg == 'done 1 2 3 start'


def test_pass_messages_via_process_manager():
    def work_loop(var, in_q, out_q):
        has_read = in_q.get()
        out_q.put(f'done {var} {has_read}')

    process_manager = ProcessManager()
    process = process_manager.create_process(('value',), target=work_loop)
    process.start()

    process.message_queue.put('msg1')
    msg = process_manager.finished_queue.get()
    assert msg == 'done value msg1'
```
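Neither test joins its subprocess; both worker functions return after a single message, so the children exit on their own. A stricter variant could reap them explicitly at the end of either test, for example:

```python
process.join(timeout=5)  # wait for the worker to exit instead of leaking it
assert not process.is_alive()
```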
