Skip to content

Commit 0d0e309

Browse files
authored
Allow using forkserver (#78)
* Allow using forkserver: run main tests with forkserver; add tests about pid accuracy; get tests mostly working; run linters and close connections; wait for workers to be ready before starting test; update schema; Python 3.10 compat * Fix rebase logic conflict with folder moves
1 parent 9ece818 commit 0d0e309

File tree

8 files changed

+126
-40
lines changed

8 files changed

+126
-40
lines changed

dispatcher/factories.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import inspect
22
from copy import deepcopy
3-
from typing import Iterable, Optional, Type, get_args, get_origin
3+
from typing import Iterable, Literal, Optional, Type, get_args, get_origin
44

55
from . import producers
66
from .brokers import get_broker
77
from .brokers.base import BaseBroker
88
from .config import LazySettings
99
from .config import settings as global_settings
1010
from .control import Control
11+
from .service import process
1112
from .service.main import DispatcherMain
1213
from .service.pool import WorkerPool
13-
from .service.process import ProcessManager
1414

1515
"""
1616
Creates objects from settings,
@@ -21,10 +21,16 @@
2121
# ---- Service objects ----
2222

2323

24+
def process_manager_from_settings(settings: LazySettings = global_settings):
25+
cls_name = settings.service.get('process_manager_cls', 'ForkServerManager')
26+
process_manager_cls = getattr(process, cls_name)
27+
return process_manager_cls()
28+
29+
2430
def pool_from_settings(settings: LazySettings = global_settings):
2531
kwargs = settings.service.get('pool_kwargs', {}).copy()
2632
kwargs['settings'] = settings
27-
kwargs['process_manager'] = ProcessManager() # TODO: use process_manager_cls from settings
33+
kwargs['process_manager'] = process_manager_from_settings(settings=settings)
2834
return WorkerPool(**kwargs)
2935

3036

@@ -119,6 +125,11 @@ def generate_settings_schema(settings: LazySettings = global_settings) -> dict:
119125
ret = deepcopy(settings.serialize())
120126

121127
ret['service']['pool_kwargs'] = schema_for_cls(WorkerPool)
128+
ret['service']['process_manager_kwargs'] = {}
129+
pm_classes = (process.ProcessManager, process.ForkServerManager)
130+
for pm_cls in pm_classes:
131+
ret['service']['process_manager_kwargs'].update(schema_for_cls(pm_cls))
132+
ret['service']['process_manager_cls'] = str(Literal[tuple(pm_cls.__name__ for pm_cls in pm_classes)])
122133

123134
for broker_name, broker_kwargs in settings.brokers.items():
124135
broker = get_broker(broker_name, broker_kwargs)

dispatcher/service/pool.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(self) -> None:
9393
self.work_cleared: asyncio.Event = asyncio.Event() # Totally quiet, no blocked or queued messages, no busy workers
9494
self.management_event: asyncio.Event = asyncio.Event() # Process spawning is backgrounded, so this is the kicker
9595
self.timeout_event: asyncio.Event = asyncio.Event() # Anything that might affect the timeout watcher task
96+
self.workers_ready: asyncio.Event = asyncio.Event() # min workers have started and sent ready message
9697

9798

9899
class WorkerPool:
@@ -399,6 +400,8 @@ async def read_results_forever(self) -> None:
399400

400401
if event == 'ready':
401402
worker.status = 'ready'
403+
if all(worker.status == 'ready' for worker in self.workers.values()):
404+
self.events.workers_ready.set()
402405
await self.drain_queue()
403406

404407
elif event == 'shutdown':

dispatcher/service/process.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
import asyncio
22
import multiprocessing
3+
from multiprocessing.context import BaseContext
4+
from types import ModuleType
35
from typing import Callable, Iterable, Optional, Union
46

57
from ..worker.task import work_loop
68

79

810
class ProcessProxy:
9-
def __init__(self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop) -> None:
10-
self.message_queue: multiprocessing.Queue = multiprocessing.Queue()
11-
self._process = multiprocessing.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))
11+
def __init__(
12+
self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop, ctx: Union[BaseContext, ModuleType] = multiprocessing
13+
) -> None:
14+
self.message_queue: multiprocessing.Queue = ctx.Queue()
15+
# This is intended use of multiprocessing context, but not available on BaseContext
16+
self._process = ctx.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue)) # type: ignore
1217

1318
def start(self) -> None:
1419
self._process.start()
@@ -37,8 +42,11 @@ def terminate(self) -> None:
3742

3843

3944
class ProcessManager:
45+
mp_context = 'fork'
46+
4047
def __init__(self) -> None:
41-
self.finished_queue: multiprocessing.Queue = multiprocessing.Queue()
48+
self.ctx = multiprocessing.get_context(self.mp_context)
49+
self.finished_queue: multiprocessing.Queue = self.ctx.Queue()
4250
self._loop = None
4351

4452
def get_event_loop(self):
@@ -47,8 +55,16 @@ def get_event_loop(self):
4755
return self._loop
4856

4957
def create_process(self, args: Iterable[int | str | dict], **kwargs) -> ProcessProxy:
50-
return ProcessProxy(args, self.finished_queue, **kwargs)
58+
return ProcessProxy(args, self.finished_queue, ctx=self.ctx, **kwargs)
5159

5260
async def read_finished(self) -> dict[str, Union[str, int]]:
5361
message = await self.get_event_loop().run_in_executor(None, self.finished_queue.get)
5462
return message
63+
64+
65+
class ForkServerManager(ProcessManager):
66+
mp_context = 'forkserver'
67+
68+
def __init__(self, preload_modules: Optional[list[str]] = None):
69+
super().__init__()
70+
self.ctx.set_forkserver_preload(preload_modules if preload_modules else [])

schema.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
"service": {
2121
"pool_kwargs": {
2222
"max_workers": "<class 'int'>"
23-
}
23+
},
24+
"process_manager_kwargs": {
25+
"preload_modules": "typing.Optional[list[str]]"
26+
},
27+
"process_manager_cls": "typing.Literal['ProcessManager', 'ForkServerManager']"
2428
},
2529
"publish": {
2630
"default_broker": "str"

tests/conftest.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dispatcher.service.main import DispatcherMain
1010
from dispatcher.control import Control
1111

12-
from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection
12+
from dispatcher.brokers.pg_notify import Broker, acreate_connection, connection_save
1313
from dispatcher.registry import DispatcherMethodRegistry
1414
from dispatcher.config import DispatcherSettings
1515
from dispatcher.factories import from_settings, get_control_from_settings
@@ -56,6 +56,21 @@ async def aconnection_for_test():
5656
await conn.close()
5757

5858

59+
@pytest.fixture(autouse=True)
60+
def clear_connection():
61+
"""Always close connections between tests
62+
63+
Tests will do a lot of unthoughtful forking, and connections can not
64+
be shared accross processes.
65+
"""
66+
if connection_save._connection:
67+
connection_save._connection.close()
68+
connection_save._connection = None
69+
if connection_save._async_connection:
70+
connection_save._async_connection.close()
71+
connection_save._async_connection = None
72+
73+
5974
@pytest.fixture
6075
def conn_config():
6176
return {'conninfo': CONNECTION_STRING}
@@ -73,18 +88,28 @@ def pg_dispatcher() -> DispatcherMain:
7388
def test_settings():
7489
return DispatcherSettings(BASIC_CONFIG)
7590

76-
77-
@pytest_asyncio.fixture(loop_scope="function", scope="function")
78-
async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]:
91+
@pytest_asyncio.fixture(
92+
loop_scope="function",
93+
scope="function",
94+
params=['ProcessManager', 'ForkServerManager'],
95+
ids=["fork", "forkserver"],
96+
)
97+
async def apg_dispatcher(request) -> AsyncIterator[DispatcherMain]:
7998
dispatcher = None
8099
try:
81-
dispatcher = from_settings(settings=test_settings)
100+
this_test_config = BASIC_CONFIG.copy()
101+
this_test_config.setdefault('service', {})
102+
this_test_config['service']['process_manager_cls'] = request.param
103+
this_settings = DispatcherSettings(this_test_config)
104+
dispatcher = from_settings(settings=this_settings)
82105

83106
await dispatcher.connect_signals()
84107
await dispatcher.start_working()
85108
await dispatcher.wait_for_producers_ready()
109+
await dispatcher.pool.events.workers_ready.wait()
86110

87111
assert dispatcher.pool.finished_count == 0 # sanity
112+
assert dispatcher.control_count == 0
88113

89114
yield dispatcher
90115
finally:

tests/integration/test_main.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ async def wait_to_receive(dispatcher, ct, timeout=5.0, interval=0.05):
2424

2525
@pytest.mark.asyncio
2626
async def test_run_lambda_function(apg_dispatcher, pg_message):
27-
assert apg_dispatcher.pool.finished_count == 0
28-
2927
clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait(), name='test_lambda_clear_wait')
3028
await pg_message('lambda: "This worked!"')
3129
await asyncio.wait_for(clearing_task, timeout=3)
@@ -93,7 +91,7 @@ async def test_cancel_task(apg_dispatcher, pg_message, pg_control):
9391
await pg_message(msg)
9492

9593
clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait())
96-
await asyncio.sleep(0.04)
94+
await asyncio.sleep(0.2)
9795
canceled_jobs = await asyncio.wait_for(pg_control.acontrol_with_reply('cancel', data={'uuid': 'foobar'}, timeout=1), timeout=5)
9896
worker_id, canceled_message = canceled_jobs[0][0]
9997
assert canceled_message['uuid'] == 'foobar'
@@ -125,8 +123,6 @@ async def test_message_with_delay(apg_dispatcher, pg_message, pg_control):
125123

126124
@pytest.mark.asyncio
127125
async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control):
128-
assert apg_dispatcher.pool.finished_count == 0
129-
130126
# Send message to run task with a delay
131127
msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 0.8})
132128
await pg_message(msg)
@@ -146,8 +142,6 @@ async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control):
146142

147143
@pytest.mark.asyncio
148144
async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control):
149-
assert apg_dispatcher.pool.finished_count == 0
150-
151145
# Send message to run task with a delay
152146
msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 2.0})
153147
await pg_message(msg)
@@ -164,8 +158,6 @@ async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control):
164158

165159
@pytest.mark.asyncio
166160
async def test_alive_check(apg_dispatcher, pg_control):
167-
assert apg_dispatcher.control_count == 0
168-
169161
alive = await asyncio.wait_for(pg_control.acontrol_with_reply('alive', timeout=1), timeout=5)
170162
assert alive == [None]
171163

@@ -174,8 +166,6 @@ async def test_alive_check(apg_dispatcher, pg_control):
174166

175167
@pytest.mark.asyncio
176168
async def test_task_discard(apg_dispatcher, pg_message):
177-
assert apg_dispatcher.pool.finished_count == 0
178-
179169
messages = [
180170
json.dumps(
181171
{'task': 'lambda: __import__("time").sleep(9)', 'on_duplicate': 'discard', 'uuid': f'dscd-{i}'}
@@ -192,8 +182,6 @@ async def test_task_discard(apg_dispatcher, pg_message):
192182

193183
@pytest.mark.asyncio
194184
async def test_task_discard_in_task_definition(apg_dispatcher, test_settings):
195-
assert apg_dispatcher.pool.finished_count == 0
196-
197185
for i in range(10):
198186
test_methods.sleep_discard.apply_async(args=[2], settings=test_settings)
199187

@@ -205,8 +193,6 @@ async def test_task_discard_in_task_definition(apg_dispatcher, test_settings):
205193

206194
@pytest.mark.asyncio
207195
async def test_tasks_in_serial(apg_dispatcher, test_settings):
208-
assert apg_dispatcher.pool.finished_count == 0
209-
210196
for i in range(10):
211197
test_methods.sleep_serial.apply_async(args=[2], settings=test_settings)
212198

@@ -218,8 +204,6 @@ async def test_tasks_in_serial(apg_dispatcher, test_settings):
218204

219205
@pytest.mark.asyncio
220206
async def test_tasks_queue_one(apg_dispatcher, test_settings):
221-
assert apg_dispatcher.pool.finished_count == 0
222-
223207
for i in range(10):
224208
test_methods.sleep_queue_one.apply_async(args=[2], settings=test_settings)
225209

tests/integration/test_producers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
@pytest.mark.asyncio
1111
async def test_on_start_tasks(caplog):
12+
dispatcher = None
1213
try:
1314
settings = DispatcherSettings({
1415
'version': 2,
@@ -35,5 +36,6 @@ async def test_on_start_tasks(caplog):
3536
assert dispatcher.pool.finished_count == 1
3637
assert 'result: confirmation_of_run' not in caplog.text
3738
finally:
38-
await dispatcher.shutdown()
39-
await dispatcher.cancel_tasks()
39+
if dispatcher:
40+
await dispatcher.shutdown()
41+
await dispatcher.cancel_tasks()

tests/unit/service/test_process.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from multiprocessing import Queue
2+
import os
23

3-
from dispatcher.service.process import ProcessManager, ProcessProxy
4+
import pytest
5+
6+
from dispatcher.service.process import ProcessManager, ForkServerManager, ProcessProxy
47

58

69
def test_pass_messages_to_worker():
@@ -17,15 +20,53 @@ def work_loop(a, b, c, in_q, out_q):
1720
assert msg == 'done 1 2 3 start'
1821

1922

20-
def test_pass_messages_via_process_manager():
21-
def work_loop(var, in_q, out_q):
22-
has_read = in_q.get()
23-
out_q.put(f'done {var} {has_read}')
23+
def work_loop2(var, in_q, out_q):
24+
"""
25+
Due to the mechanics of forkserver, this can not be defined in local variables,
26+
it has to be importable, but this _is_ importable from the test module.
27+
"""
28+
has_read = in_q.get()
29+
out_q.put(f'done {var} {has_read}')
2430

25-
process_manager = ProcessManager()
26-
process = process_manager.create_process(('value',), target=work_loop)
31+
32+
@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
33+
def test_pass_messages_via_process_manager(manager_cls):
34+
process_manager = manager_cls()
35+
process = process_manager.create_process(('value',), target=work_loop2)
2736
process.start()
2837

2938
process.message_queue.put('msg1')
3039
msg = process_manager.finished_queue.get()
3140
assert msg == 'done value msg1'
41+
42+
43+
@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
44+
def test_workers_have_different_pid(manager_cls):
45+
process_manager = manager_cls()
46+
processes = [process_manager.create_process((f'value{i}',), target=work_loop2) for i in range(2)]
47+
48+
for i in range(2):
49+
process = processes[i]
50+
process.start()
51+
process.message_queue.put(f'msg{i}')
52+
53+
assert processes[0].pid != processes[1].pid # title of test
54+
55+
msg1 = process_manager.finished_queue.get()
56+
msg2 = process_manager.finished_queue.get()
57+
assert set([msg1, msg2]) == set(['done value1 msg1', 'done value0 msg0'])
58+
59+
60+
61+
def return_pid(in_q, out_q):
62+
out_q.put(f'{os.getpid()}')
63+
64+
65+
@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
66+
def test_pid_is_correct(manager_cls):
67+
process_manager = manager_cls()
68+
process = process_manager.create_process((), target=return_pid)
69+
process.start()
70+
71+
msg = process_manager.finished_queue.get()
72+
assert int(msg) == process.pid

0 commit comments

Comments
 (0)