diff --git a/dispatcher/factories.py b/dispatcher/factories.py
index 64116717..ae9ddb2c 100644
--- a/dispatcher/factories.py
+++ b/dispatcher/factories.py
@@ -1,6 +1,6 @@
 import inspect
 from copy import deepcopy
-from typing import Iterable, Optional, Type, get_args, get_origin
+from typing import Iterable, Literal, Optional, Type, get_args, get_origin
 
 from . import producers
 from .brokers import get_broker
@@ -8,9 +8,9 @@
 from .config import LazySettings
 from .config import settings as global_settings
 from .control import Control
+from .service import process
 from .service.main import DispatcherMain
 from .service.pool import WorkerPool
-from .service.process import ProcessManager
 
 """
 Creates objects from settings,
@@ -21,10 +21,16 @@
 # ---- Service objects ----
 
 
+def process_manager_from_settings(settings: LazySettings = global_settings):
+    cls_name = settings.service.get('process_manager_cls', 'ForkServerManager')
+    process_manager_cls = getattr(process, cls_name)
+    return process_manager_cls()
+
+
 def pool_from_settings(settings: LazySettings = global_settings):
     kwargs = settings.service.get('pool_kwargs', {}).copy()
     kwargs['settings'] = settings
-    kwargs['process_manager'] = ProcessManager()  # TODO: use process_manager_cls from settings
+    kwargs['process_manager'] = process_manager_from_settings(settings=settings)
     return WorkerPool(**kwargs)
 
 
@@ -119,6 +125,11 @@ def generate_settings_schema(settings: LazySettings = global_settings) -> dict:
     ret = deepcopy(settings.serialize())
 
     ret['service']['pool_kwargs'] = schema_for_cls(WorkerPool)
+    ret['service']['process_manager_kwargs'] = {}
+    pm_classes = (process.ProcessManager, process.ForkServerManager)
+    for pm_cls in pm_classes:
+        ret['service']['process_manager_kwargs'].update(schema_for_cls(pm_cls))
+    ret['service']['process_manager_cls'] = str(Literal[tuple(pm_cls.__name__ for pm_cls in pm_classes)])
 
     for broker_name, broker_kwargs in settings.brokers.items():
         broker = get_broker(broker_name, broker_kwargs)
diff --git a/dispatcher/service/pool.py b/dispatcher/service/pool.py
index 94444047..ff1f59a8 100644
--- a/dispatcher/service/pool.py
+++ b/dispatcher/service/pool.py
@@ -93,6 +93,7 @@ def __init__(self) -> None:
         self.work_cleared: asyncio.Event = asyncio.Event()  # Totally quiet, no blocked or queued messages, no busy workers
         self.management_event: asyncio.Event = asyncio.Event()  # Process spawning is backgrounded, so this is the kicker
         self.timeout_event: asyncio.Event = asyncio.Event()  # Anything that might affect the timeout watcher task
+        self.workers_ready: asyncio.Event = asyncio.Event()  # min workers have started and sent ready message
 
 
 class WorkerPool:
@@ -399,6 +400,8 @@ async def read_results_forever(self) -> None:
 
             if event == 'ready':
                 worker.status = 'ready'
+                if all(worker.status == 'ready' for worker in self.workers.values()):
+                    self.events.workers_ready.set()
                 await self.drain_queue()
 
             elif event == 'shutdown':
diff --git a/dispatcher/service/process.py b/dispatcher/service/process.py
index 868d6227..fec72f1d 100644
--- a/dispatcher/service/process.py
+++ b/dispatcher/service/process.py
@@ -1,14 +1,19 @@
 import asyncio
 import multiprocessing
+from multiprocessing.context import BaseContext
+from types import ModuleType
 from typing import Callable, Iterable, Optional, Union
 
 from ..worker.task import work_loop
 
 
 class ProcessProxy:
-    def __init__(self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop) -> None:
-        self.message_queue: multiprocessing.Queue = multiprocessing.Queue()
-        self._process = multiprocessing.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))
+    def __init__(
+        self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop, ctx: Union[BaseContext, ModuleType] = multiprocessing
+    ) -> None:
+        self.message_queue: multiprocessing.Queue = ctx.Queue()
+        # Calling Process() on the context is the intended use of multiprocessing contexts, but it is not declared on BaseContext
+        self._process = ctx.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))  # type: ignore
 
     def start(self) -> None:
         self._process.start()
@@ -37,8 +42,11 @@ def terminate(self) -> None:
 
 
 class ProcessManager:
+    mp_context = 'fork'
+
     def __init__(self) -> None:
-        self.finished_queue: multiprocessing.Queue = multiprocessing.Queue()
+        self.ctx = multiprocessing.get_context(self.mp_context)
+        self.finished_queue: multiprocessing.Queue = self.ctx.Queue()
         self._loop = None
 
     def get_event_loop(self):
@@ -47,8 +55,16 @@ def get_event_loop(self):
         return self._loop
 
     def create_process(self, args: Iterable[int | str | dict], **kwargs) -> ProcessProxy:
-        return ProcessProxy(args, self.finished_queue, **kwargs)
+        return ProcessProxy(args, self.finished_queue, ctx=self.ctx, **kwargs)
 
     async def read_finished(self) -> dict[str, Union[str, int]]:
         message = await self.get_event_loop().run_in_executor(None, self.finished_queue.get)
         return message
+
+
+class ForkServerManager(ProcessManager):
+    mp_context = 'forkserver'
+
+    def __init__(self, preload_modules: Optional[list[str]] = None):
+        super().__init__()
+        self.ctx.set_forkserver_preload(preload_modules if preload_modules else [])
diff --git a/schema.json b/schema.json
index c4630614..b7015b53 100644
--- a/schema.json
+++ b/schema.json
@@ -20,7 +20,11 @@
     "service": {
         "pool_kwargs": {
             "max_workers": ""
-        }
+        },
+        "process_manager_kwargs": {
+            "preload_modules": "typing.Optional[list[str]]"
+        },
+        "process_manager_cls": "typing.Literal['ProcessManager', 'ForkServerManager']"
     },
     "publish": {
         "default_broker": "str"
diff --git a/tests/conftest.py b/tests/conftest.py
index 6e003494..a864dc43 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,7 +9,7 @@
 from dispatcher.service.main import DispatcherMain
 from dispatcher.control import Control
-from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection
+from dispatcher.brokers.pg_notify import Broker, acreate_connection, connection_save
 from dispatcher.registry import DispatcherMethodRegistry
 from dispatcher.config import DispatcherSettings
 from dispatcher.factories import from_settings, get_control_from_settings
@@ -56,6 +56,21 @@ async def aconnection_for_test():
     await conn.close()
 
 
+@pytest.fixture(autouse=True)
+def clear_connection():
+    """Always close connections between tests
+
+    Tests will do a lot of unthoughtful forking, and connections cannot
+    be shared across processes.
+ """ + if connection_save._connection: + connection_save._connection.close() + connection_save._connection = None + if connection_save._async_connection: + connection_save._async_connection.close() + connection_save._async_connection = None + + @pytest.fixture def conn_config(): return {'conninfo': CONNECTION_STRING} @@ -73,18 +88,28 @@ def pg_dispatcher() -> DispatcherMain: def test_settings(): return DispatcherSettings(BASIC_CONFIG) - -@pytest_asyncio.fixture(loop_scope="function", scope="function") -async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]: +@pytest_asyncio.fixture( + loop_scope="function", + scope="function", + params=['ProcessManager', 'ForkServerManager'], + ids=["fork", "forkserver"], +) +async def apg_dispatcher(request) -> AsyncIterator[DispatcherMain]: dispatcher = None try: - dispatcher = from_settings(settings=test_settings) + this_test_config = BASIC_CONFIG.copy() + this_test_config.setdefault('service', {}) + this_test_config['service']['process_manager_cls'] = request.param + this_settings = DispatcherSettings(this_test_config) + dispatcher = from_settings(settings=this_settings) await dispatcher.connect_signals() await dispatcher.start_working() await dispatcher.wait_for_producers_ready() + await dispatcher.pool.events.workers_ready.wait() assert dispatcher.pool.finished_count == 0 # sanity + assert dispatcher.control_count == 0 yield dispatcher finally: diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index ae8c649b..93677a6d 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -24,8 +24,6 @@ async def wait_to_receive(dispatcher, ct, timeout=5.0, interval=0.05): @pytest.mark.asyncio async def test_run_lambda_function(apg_dispatcher, pg_message): - assert apg_dispatcher.pool.finished_count == 0 - clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait(), name='test_lambda_clear_wait') await pg_message('lambda: "This worked!"') await asyncio.wait_for(clearing_task, timeout=3) @@ -93,7 +91,7 @@ async def test_cancel_task(apg_dispatcher, pg_message, pg_control): await pg_message(msg) clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait()) - await asyncio.sleep(0.04) + await asyncio.sleep(0.2) canceled_jobs = await asyncio.wait_for(pg_control.acontrol_with_reply('cancel', data={'uuid': 'foobar'}, timeout=1), timeout=5) worker_id, canceled_message = canceled_jobs[0][0] assert canceled_message['uuid'] == 'foobar' @@ -125,8 +123,6 @@ async def test_message_with_delay(apg_dispatcher, pg_message, pg_control): @pytest.mark.asyncio async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control): - assert apg_dispatcher.pool.finished_count == 0 - # Send message to run task with a delay msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 0.8}) await pg_message(msg) @@ -146,8 +142,6 @@ async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control): @pytest.mark.asyncio async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control): - assert apg_dispatcher.pool.finished_count == 0 - # Send message to run task with a delay msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 2.0}) await pg_message(msg) @@ -164,8 +158,6 @@ async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control): @pytest.mark.asyncio async def 
 async def test_alive_check(apg_dispatcher, pg_control):
-    assert apg_dispatcher.control_count == 0
-
     alive = await asyncio.wait_for(pg_control.acontrol_with_reply('alive', timeout=1), timeout=5)
     assert alive == [None]
 
@@ -174,8 +166,6 @@
 
 @pytest.mark.asyncio
 async def test_task_discard(apg_dispatcher, pg_message):
-    assert apg_dispatcher.pool.finished_count == 0
-
     messages = [
         json.dumps(
             {'task': 'lambda: __import__("time").sleep(9)', 'on_duplicate': 'discard', 'uuid': f'dscd-{i}'}
@@ -192,8 +182,6 @@
 
 @pytest.mark.asyncio
 async def test_task_discard_in_task_definition(apg_dispatcher, test_settings):
-    assert apg_dispatcher.pool.finished_count == 0
-
     for i in range(10):
         test_methods.sleep_discard.apply_async(args=[2], settings=test_settings)
@@ -205,8 +193,6 @@
 
 @pytest.mark.asyncio
 async def test_tasks_in_serial(apg_dispatcher, test_settings):
-    assert apg_dispatcher.pool.finished_count == 0
-
     for i in range(10):
         test_methods.sleep_serial.apply_async(args=[2], settings=test_settings)
@@ -218,8 +204,6 @@
 
 @pytest.mark.asyncio
 async def test_tasks_queue_one(apg_dispatcher, test_settings):
-    assert apg_dispatcher.pool.finished_count == 0
-
     for i in range(10):
         test_methods.sleep_queue_one.apply_async(args=[2], settings=test_settings)
diff --git a/tests/integration/test_producers.py b/tests/integration/test_producers.py
index db393a46..7449ed7d 100644
--- a/tests/integration/test_producers.py
+++ b/tests/integration/test_producers.py
@@ -9,6 +9,7 @@
 
 @pytest.mark.asyncio
 async def test_on_start_tasks(caplog):
+    dispatcher = None
     try:
         settings = DispatcherSettings({
             'version': 2,
@@ -35,5 +36,6 @@ async def test_on_start_tasks(caplog):
         assert dispatcher.pool.finished_count == 1
         assert 'result: confirmation_of_run' not in caplog.text
     finally:
-        await dispatcher.shutdown()
-        await dispatcher.cancel_tasks()
+        if dispatcher:
+            await dispatcher.shutdown()
+            await dispatcher.cancel_tasks()
diff --git a/tests/unit/service/test_process.py b/tests/unit/service/test_process.py
index 533f069b..89a1a08d 100644
--- a/tests/unit/service/test_process.py
+++ b/tests/unit/service/test_process.py
@@ -1,6 +1,9 @@
 from multiprocessing import Queue
+import os
 
-from dispatcher.service.process import ProcessManager, ProcessProxy
+import pytest
+
+from dispatcher.service.process import ProcessManager, ForkServerManager, ProcessProxy
 
 
 def test_pass_messages_to_worker():
@@ -17,15 +20,53 @@ def work_loop(a, b, c, in_q, out_q):
     assert msg == 'done 1 2 3 start'
 
 
-def test_pass_messages_via_process_manager():
-    def work_loop(var, in_q, out_q):
-        has_read = in_q.get()
-        out_q.put(f'done {var} {has_read}')
+def work_loop2(var, in_q, out_q):
+    """
+    Due to the mechanics of forkserver, this cannot be defined as a local function;
+    it has to be importable, and it _is_ importable from the test module.
+ """ + has_read = in_q.get() + out_q.put(f'done {var} {has_read}') - process_manager = ProcessManager() - process = process_manager.create_process(('value',), target=work_loop) + +@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager]) +def test_pass_messages_via_process_manager(manager_cls): + process_manager = manager_cls() + process = process_manager.create_process(('value',), target=work_loop2) process.start() process.message_queue.put('msg1') msg = process_manager.finished_queue.get() assert msg == 'done value msg1' + + +@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager]) +def test_workers_have_different_pid(manager_cls): + process_manager = manager_cls() + processes = [process_manager.create_process((f'value{i}',), target=work_loop2) for i in range(2)] + + for i in range(2): + process = processes[i] + process.start() + process.message_queue.put(f'msg{i}') + + assert processes[0].pid != processes[1].pid # title of test + + msg1 = process_manager.finished_queue.get() + msg2 = process_manager.finished_queue.get() + assert set([msg1, msg2]) == set(['done value1 msg1', 'done value0 msg0']) + + + +def return_pid(in_q, out_q): + out_q.put(f'{os.getpid()}') + + +@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager]) +def test_pid_is_correct(manager_cls): + process_manager = manager_cls() + process = process_manager.create_process((), target=return_pid) + process.start() + + msg = process_manager.finished_queue.get() + assert int(msg) == process.pid