Skip to content

Commit 32dd102

Browse files
committed
Allow using forkserver
1 parent 752c0c7 commit 32dd102

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

dispatcher/pool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from asyncio import Task
55
from typing import Iterator, Optional
66

7-
from dispatcher.process import ProcessManager, ProcessProxy
7+
from dispatcher.process import ForkServerManager, ProcessProxy
88
from dispatcher.utils import DuplicateBehavior, MessageAction
99

1010
logger = logging.getLogger(__name__)
@@ -98,7 +98,7 @@ def __init__(self, num_workers: int, fd_lock: Optional[asyncio.Lock] = None):
9898
self.num_workers = num_workers
9999
self.workers: dict[int, PoolWorker] = {}
100100
self.next_worker_id = 0
101-
self.process_manager = ProcessManager()
101+
self.process_manager = ForkServerManager()
102102
self.queued_messages: list[dict] = [] # TODO: use deque, invent new kinds of logging anxiety
103103
self.read_results_task: Optional[Task] = None
104104
self.start_worker_task: Optional[Task] = None

dispatcher/process.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import asyncio
22
import multiprocessing
3-
from typing import Callable, Iterable, Optional, Union
3+
from multiprocessing.context import BaseContext
4+
from typing import Callable, Iterable, Optional, Union, Sized
45

56
from dispatcher.worker.task import work_loop
67

78

89
class ProcessProxy:
9-
def __init__(self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop) -> None:
10-
self.message_queue: multiprocessing.Queue = multiprocessing.Queue()
11-
self._process = multiprocessing.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))
10+
def __init__(self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop, ctx: BaseContext = multiprocessing) -> None:
11+
self.message_queue: multiprocessing.Queue = ctx.Queue()
12+
self._process = ctx.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))
1213

1314
def start(self) -> None:
1415
self._process.start()
@@ -37,8 +38,11 @@ def terminate(self) -> None:
3738

3839

3940
class ProcessManager:
41+
mp_context = 'fork'
42+
4043
def __init__(self) -> None:
41-
self.finished_queue: multiprocessing.Queue = multiprocessing.Queue()
44+
self.ctx = multiprocessing.get_context(self.mp_context)
45+
self.finished_queue: multiprocessing.Queue = self.ctx.Queue()
4246
self._loop = None
4347

4448
def get_event_loop(self):
@@ -47,8 +51,16 @@ def get_event_loop(self):
4751
return self._loop
4852

4953
def create_process(self, args: Iterable[int | str], **kwargs) -> ProcessProxy:
50-
return ProcessProxy(args, self.finished_queue, **kwargs)
54+
return ProcessProxy(args, self.finished_queue, ctx=self.ctx, **kwargs)
5155

5256
async def read_finished(self) -> dict[str, Union[str, int]]:
5357
message = await self.get_event_loop().run_in_executor(None, self.finished_queue.get)
5458
return message
59+
60+
61+
class ForkServerManager(ProcessManager):
62+
mp_context = 'forkserver'
63+
64+
def __init__(self, preload_modules: Sized = ()):
65+
super().__init__()
66+
self.ctx.set_forkserver_preload(preload_modules)

tests/unit/service/test_process.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from multiprocessing import Queue
22

3-
from dispatcher.process import ProcessManager, ProcessProxy
3+
import pytest
4+
5+
from dispatcher.process import ProcessManager, ForkServerManager, ProcessProxy
46

57

68
def test_pass_messages_to_worker():
@@ -17,13 +19,19 @@ def work_loop(a, b, c, in_q, out_q):
1719
assert msg == 'done 1 2 3 start'
1820

1921

20-
def test_pass_messages_via_process_manager():
21-
def work_loop(var, in_q, out_q):
22-
has_read = in_q.get()
23-
out_q.put(f'done {var} {has_read}')
22+
def work_loop2(var, in_q, out_q):
23+
"""
24+
Due to the mechanics of forkserver, this can not be defined in local variables,
25+
it has to be importable, but this _is_ importable from the test module.
26+
"""
27+
has_read = in_q.get()
28+
out_q.put(f'done {var} {has_read}')
29+
2430

25-
process_manager = ProcessManager()
26-
process = process_manager.create_process(('value',), target=work_loop)
31+
@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
32+
def test_pass_messages_via_process_manager(manager_cls):
33+
process_manager = manager_cls()
34+
process = process_manager.create_process(('value',), target=work_loop2)
2735
process.start()
2836

2937
process.message_queue.put('msg1')

0 commit comments

Comments
 (0)