From 1064281458ea092ed2beadc6b8edae61860a59bd Mon Sep 17 00:00:00 2001
From: Alan Rominger
Date: Sat, 15 Feb 2025 10:52:57 -0500
Subject: [PATCH 1/3] Allow using forkserver

Run main tests with forkserver

Add tests about pid accuracy

Get tests mostly working

Run linters and close connections

Wait for workers to be ready before starting test

wrap up linters

Update schema

run linters

Python 3.10 compat
---
 dispatcher/factories.py            | 17 +++++++--
 dispatcher/service/pool.py         |  3 ++
 dispatcher/service/process.py      | 26 +++++++++++---
 schema.json                        |  6 +++-
 tests/conftest.py                  | 35 ++++++++++++++++---
 tests/integration/test_main.py     | 18 +---------
 tests/unit/service/test_process.py | 55 ++++++++++++++++++++++++++----
 7 files changed, 122 insertions(+), 38 deletions(-)

diff --git a/dispatcher/factories.py b/dispatcher/factories.py
index 64116717..d3748b2d 100644
--- a/dispatcher/factories.py
+++ b/dispatcher/factories.py
@@ -1,6 +1,6 @@
 import inspect
 from copy import deepcopy
-from typing import Iterable, Optional, Type, get_args, get_origin
+from typing import Iterable, Literal, Optional, Type, get_args, get_origin
 
 from . import producers
 from .brokers import get_broker
@@ -10,7 +10,7 @@
 from .control import Control
 from .service.main import DispatcherMain
 from .service.pool import WorkerPool
-from .service.process import ProcessManager
+from .service import process
 
 """
 Creates objects from settings,
@@ -21,10 +21,16 @@
 # ---- Service objects ----
 
 
+def process_manager_from_settings(settings: LazySettings = global_settings):
+    cls_name = settings.service.get('process_manager_cls', 'ForkServer')
+    process_manager_cls = getattr(process, cls_name)
+    return process_manager_cls()
+
+
 def pool_from_settings(settings: LazySettings = global_settings):
     kwargs = settings.service.get('pool_kwargs', {}).copy()
     kwargs['settings'] = settings
-    kwargs['process_manager'] = ProcessManager()  # TODO: use process_manager_cls from settings
+    kwargs['process_manager'] = process_manager_from_settings(settings=settings)
     return WorkerPool(**kwargs)
 
 
@@ -119,6 +125,11 @@ def generate_settings_schema(settings: LazySettings = global_settings) -> dict:
     ret = deepcopy(settings.serialize())
 
     ret['service']['pool_kwargs'] = schema_for_cls(WorkerPool)
+    ret['service']['process_manager_kwargs'] = {}
+    pm_classes = (process.ProcessManager, process.ForkServerManager)
+    for pm_cls in pm_classes:
+        ret['service']['process_manager_kwargs'].update(schema_for_cls(pm_cls))
+    ret['service']['process_manager_cls'] = str(Literal[tuple(pm_cls.__name__ for pm_cls in pm_classes)])
 
     for broker_name, broker_kwargs in settings.brokers.items():
         broker = get_broker(broker_name, broker_kwargs)
diff --git a/dispatcher/service/pool.py b/dispatcher/service/pool.py
index 94444047..ff1f59a8 100644
--- a/dispatcher/service/pool.py
+++ b/dispatcher/service/pool.py
@@ -93,6 +93,7 @@ def __init__(self) -> None:
         self.work_cleared: asyncio.Event = asyncio.Event()  # Totally quiet, no blocked or queued messages, no busy workers
         self.management_event: asyncio.Event = asyncio.Event()  # Process spawning is backgrounded, so this is the kicker
         self.timeout_event: asyncio.Event = asyncio.Event()  # Anything that might affect the timeout watcher task
+        self.workers_ready: asyncio.Event = asyncio.Event()  # min workers have started and sent ready message
 
 
 class WorkerPool:
@@ -399,6 +400,8 @@ async def read_results_forever(self) -> None:
 
             if event == 'ready':
                 worker.status = 'ready'
+                if all(worker.status == 'ready' for worker in self.workers.values()):
+                    self.events.workers_ready.set()
                 await self.drain_queue()
 
             elif event == 'shutdown':
diff --git a/dispatcher/service/process.py b/dispatcher/service/process.py
index 868d6227..fec72f1d 100644
--- a/dispatcher/service/process.py
+++ b/dispatcher/service/process.py
@@ -1,14 +1,19 @@
 import asyncio
 import multiprocessing
+from multiprocessing.context import BaseContext
+from types import ModuleType
 from typing import Callable, Iterable, Optional, Union
 
 from ..worker.task import work_loop
 
 
 class ProcessProxy:
-    def __init__(self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop) -> None:
-        self.message_queue: multiprocessing.Queue = multiprocessing.Queue()
-        self._process = multiprocessing.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))
+    def __init__(
+        self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop, ctx: Union[BaseContext, ModuleType] = multiprocessing
+    ) -> None:
+        self.message_queue: multiprocessing.Queue = ctx.Queue()
+        # This is the intended use of a multiprocessing context, but Process is not available on BaseContext
+        self._process = ctx.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))  # type: ignore
 
     def start(self) -> None:
         self._process.start()
@@ -37,8 +42,11 @@ def terminate(self) -> None:
 
 
 class ProcessManager:
+    mp_context = 'fork'
+
     def __init__(self) -> None:
-        self.finished_queue: multiprocessing.Queue = multiprocessing.Queue()
+        self.ctx = multiprocessing.get_context(self.mp_context)
+        self.finished_queue: multiprocessing.Queue = self.ctx.Queue()
         self._loop = None
 
     def get_event_loop(self):
@@ -47,8 +55,16 @@ def get_event_loop(self):
         return self._loop
 
     def create_process(self, args: Iterable[int | str | dict], **kwargs) -> ProcessProxy:
-        return ProcessProxy(args, self.finished_queue, **kwargs)
+        return ProcessProxy(args, self.finished_queue, ctx=self.ctx, **kwargs)
 
     async def read_finished(self) -> dict[str, Union[str, int]]:
         message = await self.get_event_loop().run_in_executor(None, self.finished_queue.get)
         return message
+
+
+class ForkServerManager(ProcessManager):
+    mp_context = 'forkserver'
+
+    def __init__(self, preload_modules: Optional[list[str]] = None):
+        super().__init__()
+        self.ctx.set_forkserver_preload(preload_modules if preload_modules else [])
diff --git a/schema.json b/schema.json
index c4630614..b7015b53 100644
--- a/schema.json
+++ b/schema.json
@@ -20,7 +20,11 @@
     "service": {
         "pool_kwargs": {
             "max_workers": ""
-        }
+        },
+        "process_manager_kwargs": {
+            "preload_modules": "typing.Optional[list[str]]"
+        },
+        "process_manager_cls": "typing.Literal['ProcessManager', 'ForkServerManager']"
     },
     "publish": {
         "default_broker": "str"
diff --git a/tests/conftest.py b/tests/conftest.py
index 6e003494..a864dc43 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,7 +9,7 @@
 
 from dispatcher.service.main import DispatcherMain
 from dispatcher.control import Control
-from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection
+from dispatcher.brokers.pg_notify import Broker, acreate_connection, connection_save
 from dispatcher.registry import DispatcherMethodRegistry
 from dispatcher.config import DispatcherSettings
 from dispatcher.factories import from_settings, get_control_from_settings
@@ -56,6 +56,21 @@ async def aconnection_for_test():
         await conn.close()
 
 
+@pytest.fixture(autouse=True)
+def clear_connection():
+    """Always close connections between tests
+
+    Tests will do a lot of unthoughtful forking, and connections cannot
+    be shared across processes.
+    """
+    if connection_save._connection:
+        connection_save._connection.close()
+        connection_save._connection = None
+    if connection_save._async_connection:
+        connection_save._async_connection.close()
+        connection_save._async_connection = None
+
+
 @pytest.fixture
 def conn_config():
     return {'conninfo': CONNECTION_STRING}
@@ -73,18 +88,28 @@ def pg_dispatcher() -> DispatcherMain:
 def test_settings():
     return DispatcherSettings(BASIC_CONFIG)
 
-
-@pytest_asyncio.fixture(loop_scope="function", scope="function")
-async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]:
+@pytest_asyncio.fixture(
+    loop_scope="function",
+    scope="function",
+    params=['ProcessManager', 'ForkServerManager'],
+    ids=["fork", "forkserver"],
+)
+async def apg_dispatcher(request) -> AsyncIterator[DispatcherMain]:
     dispatcher = None
     try:
-        dispatcher = from_settings(settings=test_settings)
+        this_test_config = BASIC_CONFIG.copy()
+        this_test_config.setdefault('service', {})
+        this_test_config['service']['process_manager_cls'] = request.param
+        this_settings = DispatcherSettings(this_test_config)
+        dispatcher = from_settings(settings=this_settings)
 
         await dispatcher.connect_signals()
         await dispatcher.start_working()
         await dispatcher.wait_for_producers_ready()
+        await dispatcher.pool.events.workers_ready.wait()
 
         assert dispatcher.pool.finished_count == 0  # sanity
+        assert dispatcher.control_count == 0
 
         yield dispatcher
     finally:
diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py
index ae8c649b..93677a6d 100644
--- a/tests/integration/test_main.py
+++ b/tests/integration/test_main.py
@@ -24,8 +24,6 @@ async def wait_to_receive(dispatcher, ct, timeout=5.0, interval=0.05):
 
 @pytest.mark.asyncio
 async def test_run_lambda_function(apg_dispatcher, pg_message):
-    assert apg_dispatcher.pool.finished_count == 0
-
     clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait(), name='test_lambda_clear_wait')
     await pg_message('lambda: "This worked!"')
     await asyncio.wait_for(clearing_task, timeout=3)
@@ -93,7 +91,7 @@ async def test_cancel_task(apg_dispatcher, pg_message, pg_control):
     await pg_message(msg)
 
     clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait())
-    await asyncio.sleep(0.04)
+    await asyncio.sleep(0.2)
     canceled_jobs = await asyncio.wait_for(pg_control.acontrol_with_reply('cancel', data={'uuid': 'foobar'}, timeout=1), timeout=5)
     worker_id, canceled_message = canceled_jobs[0][0]
     assert canceled_message['uuid'] == 'foobar'
@@ -125,8 +123,6 @@ async def test_message_with_delay(apg_dispatcher, pg_message, pg_control):
 
 @pytest.mark.asyncio
 async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control):
-    assert apg_dispatcher.pool.finished_count == 0
-
     # Send message to run task with a delay
     msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 0.8})
     await pg_message(msg)
@@ -146,8 +142,6 @@ async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control):
 
 @pytest.mark.asyncio
 async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control):
-    assert apg_dispatcher.pool.finished_count == 0
-
     # Send message to run task with a delay
     msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 2.0})
     await pg_message(msg)
@@ -164,8 +158,6 @@ async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control):
 
 @pytest.mark.asyncio
 async def test_alive_check(apg_dispatcher, pg_control):
-    assert apg_dispatcher.control_count == 0
-
     alive = await asyncio.wait_for(pg_control.acontrol_with_reply('alive', timeout=1), timeout=5)
     assert alive == [None]
 
@@ -174,8 +166,6 @@
 
 @pytest.mark.asyncio
 async def test_task_discard(apg_dispatcher, pg_message):
-    assert apg_dispatcher.pool.finished_count == 0
-
     messages = [
         json.dumps(
             {'task': 'lambda: __import__("time").sleep(9)', 'on_duplicate': 'discard', 'uuid': f'dscd-{i}'}
@@ -192,8 +182,6 @@
 
 @pytest.mark.asyncio
 async def test_task_discard_in_task_definition(apg_dispatcher, test_settings):
-    assert apg_dispatcher.pool.finished_count == 0
-
     for i in range(10):
         test_methods.sleep_discard.apply_async(args=[2], settings=test_settings)
 
@@ -205,8 +193,6 @@
 
 @pytest.mark.asyncio
 async def test_tasks_in_serial(apg_dispatcher, test_settings):
-    assert apg_dispatcher.pool.finished_count == 0
-
     for i in range(10):
         test_methods.sleep_serial.apply_async(args=[2], settings=test_settings)
 
@@ -218,8 +204,6 @@
 
 @pytest.mark.asyncio
 async def test_tasks_queue_one(apg_dispatcher, test_settings):
-    assert apg_dispatcher.pool.finished_count == 0
-
     for i in range(10):
         test_methods.sleep_queue_one.apply_async(args=[2], settings=test_settings)
 
diff --git a/tests/unit/service/test_process.py b/tests/unit/service/test_process.py
index 533f069b..89a1a08d 100644
--- a/tests/unit/service/test_process.py
+++ b/tests/unit/service/test_process.py
@@ -1,6 +1,9 @@
 from multiprocessing import Queue
+import os
 
-from dispatcher.service.process import ProcessManager, ProcessProxy
+import pytest
+
+from dispatcher.service.process import ProcessManager, ForkServerManager, ProcessProxy
 
 
 def test_pass_messages_to_worker():
@@ -17,15 +20,53 @@ def work_loop(a, b, c, in_q, out_q):
     assert msg == 'done 1 2 3 start'
 
 
-def test_pass_messages_via_process_manager():
-    def work_loop(var, in_q, out_q):
-        has_read = in_q.get()
-        out_q.put(f'done {var} {has_read}')
+def work_loop2(var, in_q, out_q):
+    """
+    Due to the mechanics of forkserver, this cannot be defined as a local function;
+    it has to be importable, and this _is_ importable from the test module.
+    """
+    has_read = in_q.get()
+    out_q.put(f'done {var} {has_read}')
 
-    process_manager = ProcessManager()
-    process = process_manager.create_process(('value',), target=work_loop)
+
+@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
+def test_pass_messages_via_process_manager(manager_cls):
+    process_manager = manager_cls()
+    process = process_manager.create_process(('value',), target=work_loop2)
     process.start()
 
     process.message_queue.put('msg1')
     msg = process_manager.finished_queue.get()
     assert msg == 'done value msg1'
+
+
+@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
+def test_workers_have_different_pid(manager_cls):
+    process_manager = manager_cls()
+    processes = [process_manager.create_process((f'value{i}',), target=work_loop2) for i in range(2)]
+
+    for i in range(2):
+        process = processes[i]
+        process.start()
+        process.message_queue.put(f'msg{i}')
+
+    assert processes[0].pid != processes[1].pid  # the point of this test
+
+    msg1 = process_manager.finished_queue.get()
+    msg2 = process_manager.finished_queue.get()
+    assert set([msg1, msg2]) == set(['done value1 msg1', 'done value0 msg0'])
+
+
+
+def return_pid(in_q, out_q):
+    out_q.put(f'{os.getpid()}')
+
+
+@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
+def test_pid_is_correct(manager_cls):
+    process_manager = manager_cls()
+    process = process_manager.create_process((), target=return_pid)
+    process.start()
+
+    msg = process_manager.finished_queue.get()
+    assert int(msg) == process.pid

From a419a48e8baa403c89651ba941a741c9c8004673 Mon Sep 17 00:00:00 2001
From: Alan Rominger
Date: Mon, 24 Feb 2025 14:42:50 -0500
Subject: [PATCH 2/3] run isort

---
 dispatcher/factories.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dispatcher/factories.py b/dispatcher/factories.py
index d3748b2d..e62f2667 100644
--- a/dispatcher/factories.py
+++ b/dispatcher/factories.py
@@ -8,9 +8,9 @@
 from .config import LazySettings
 from .config import settings as global_settings
 from .control import Control
+from .service import process
 from .service.main import DispatcherMain
 from .service.pool import WorkerPool
-from .service import process
 
 """
 Creates objects from settings,

From 45786498caf87eb2ed7d85299f21469f4fa4e42d Mon Sep 17 00:00:00 2001
From: Alan Rominger
Date: Wed, 26 Feb 2025 22:47:36 -0500
Subject: [PATCH 3/3] Fix rebase logic conflict with folder moves

---
 dispatcher/factories.py             | 2 +-
 tests/integration/test_producers.py | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/dispatcher/factories.py b/dispatcher/factories.py
index e62f2667..ae9ddb2c 100644
--- a/dispatcher/factories.py
+++ b/dispatcher/factories.py
@@ -22,7 +22,7 @@
 
 
 def process_manager_from_settings(settings: LazySettings = global_settings):
-    cls_name = settings.service.get('process_manager_cls', 'ForkServer')
+    cls_name = settings.service.get('process_manager_cls', 'ForkServerManager')
     process_manager_cls = getattr(process, cls_name)
     return process_manager_cls()
 
diff --git a/tests/integration/test_producers.py b/tests/integration/test_producers.py
index db393a46..7449ed7d 100644
--- a/tests/integration/test_producers.py
+++ b/tests/integration/test_producers.py
@@ -9,6 +9,7 @@
 
 @pytest.mark.asyncio
 async def test_on_start_tasks(caplog):
+    dispatcher = None
     try:
         settings = DispatcherSettings({
             'version': 2,
@@ -35,5 +36,6 @@ async def test_on_start_tasks(caplog):
         assert dispatcher.pool.finished_count == 1
         assert 'result: confirmation_of_run' not in caplog.text
     finally:
-        await dispatcher.shutdown()
-        await dispatcher.cancel_tasks()
+        if dispatcher:
+            await dispatcher.shutdown()
+            await dispatcher.cancel_tasks()
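
A usage sketch to accompany the series (editor's note, not part of any patch): this shows how the process-manager classes added above are expected to be driven directly, assuming the module layout in this series (`dispatcher.service.process`). The `say_hello` function and the `preload_modules` value are illustrative inventions for the example; the `ForkServerManager(preload_modules=...)` signature and the `create_process`/`finished_queue` usage are taken from the diff.

```python
from dispatcher.service.process import ForkServerManager


def say_hello(name, message_queue, finished_queue):
    # ProcessProxy appends (message_queue, finished_queue) to the positional
    # args, so every target receives these two queues as trailing parameters.
    finished_queue.put(f'hello {name}')


if __name__ == '__main__':
    # 'forkserver' workers start from a fresh interpreter, so preloading
    # heavy modules once in the fork server pays off on every worker spawn.
    # The preload list here is hypothetical, not from the patch.
    manager = ForkServerManager(preload_modules=['dispatcher.worker.task'])
    process = manager.create_process(('world',), target=say_hello)
    process.start()
    print(manager.finished_queue.get())  # -> hello world
```

In service settings the same choice is made declaratively: `process_manager_cls` accepts `'ProcessManager'` (plain fork) or `'ForkServerManager'`, and `process_manager_from_settings()` resolves it with `getattr(process, cls_name)`, as patch 3 finalizes.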