17 changes: 14 additions & 3 deletions dispatcher/factories.py
@@ -1,16 +1,16 @@
import inspect
from copy import deepcopy
from typing import Iterable, Optional, Type, get_args, get_origin
from typing import Iterable, Literal, Optional, Type, get_args, get_origin

from . import producers
from .brokers import get_broker
from .brokers.base import BaseBroker
from .config import LazySettings
from .config import settings as global_settings
from .control import Control
from .service import process
from .service.main import DispatcherMain
from .service.pool import WorkerPool
from .service.process import ProcessManager

"""
Creates objects from settings,
@@ -21,10 +21,16 @@
# ---- Service objects ----


def process_manager_from_settings(settings: LazySettings = global_settings):
cls_name = settings.service.get('process_manager_cls', 'ForkServerManager')
process_manager_cls = getattr(process, cls_name)
return process_manager_cls()


def pool_from_settings(settings: LazySettings = global_settings):
kwargs = settings.service.get('pool_kwargs', {}).copy()
kwargs['settings'] = settings
kwargs['process_manager'] = ProcessManager() # TODO: use process_manager_cls from settings
kwargs['process_manager'] = process_manager_from_settings(settings=settings)
return WorkerPool(**kwargs)


@@ -119,6 +125,11 @@ def generate_settings_schema(settings: LazySettings = global_settings) -> dict:
ret = deepcopy(settings.serialize())

ret['service']['pool_kwargs'] = schema_for_cls(WorkerPool)
ret['service']['process_manager_kwargs'] = {}
pm_classes = (process.ProcessManager, process.ForkServerManager)
for pm_cls in pm_classes:
ret['service']['process_manager_kwargs'].update(schema_for_cls(pm_cls))
ret['service']['process_manager_cls'] = str(Literal[tuple(pm_cls.__name__ for pm_cls in pm_classes)])

for broker_name, broker_kwargs in settings.brokers.items():
broker = get_broker(broker_name, broker_kwargs)
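
For illustration, a minimal sketch of how the new factory resolves this setting (the config dict here is hypothetical and omits brokers and other keys a real config would carry):

    from dispatcher.config import DispatcherSettings
    from dispatcher.factories import process_manager_from_settings

    # Hypothetical minimal config; the class name must match one advertised
    # by the generated schema ('ProcessManager' or 'ForkServerManager').
    settings = DispatcherSettings({'service': {'process_manager_cls': 'ForkServerManager'}})

    # Looks the name up with getattr() on dispatcher.service.process
    # and instantiates the class with no arguments.
    pm = process_manager_from_settings(settings=settings)
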
3 changes: 3 additions & 0 deletions dispatcher/service/pool.py
@@ -93,6 +93,7 @@ def __init__(self) -> None:
self.work_cleared: asyncio.Event = asyncio.Event() # Totally quiet, no blocked or queued messages, no busy workers
self.management_event: asyncio.Event = asyncio.Event() # Process spawning is backgrounded, so this is the kicker
self.timeout_event: asyncio.Event = asyncio.Event() # Anything that might affect the timeout watcher task
self.workers_ready: asyncio.Event = asyncio.Event() # min workers have started and sent ready message


class WorkerPool:
@@ -399,6 +400,8 @@ async def read_results_forever(self) -> None:

if event == 'ready':
worker.status = 'ready'
if all(worker.status == 'ready' for worker in self.workers.values()):
self.events.workers_ready.set()
await self.drain_queue()

elif event == 'shutdown':
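
The new workers_ready event follows the same asyncio.Event convention as the neighboring pool events; a minimal consumer-side sketch (the helper name and timeout are illustrative, not part of this diff):

    import asyncio

    async def wait_for_full_pool(pool, timeout: float = 5.0) -> None:
        # Resolves once every worker has sent its 'ready' message;
        # raises TimeoutError if the pool never fully starts.
        await asyncio.wait_for(pool.events.workers_ready.wait(), timeout=timeout)
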
26 changes: 21 additions & 5 deletions dispatcher/service/process.py
@@ -1,14 +1,19 @@
import asyncio
import multiprocessing
from multiprocessing.context import BaseContext
from types import ModuleType
from typing import Callable, Iterable, Optional, Union

from ..worker.task import work_loop


class ProcessProxy:
def __init__(self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop) -> None:
self.message_queue: multiprocessing.Queue = multiprocessing.Queue()
self._process = multiprocessing.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue))
def __init__(
self, args: Iterable, finished_queue: multiprocessing.Queue, target: Callable = work_loop, ctx: Union[BaseContext, ModuleType] = multiprocessing
) -> None:
self.message_queue: multiprocessing.Queue = ctx.Queue()
# This is the intended use of a multiprocessing context, but Process() is not declared on BaseContext
self._process = ctx.Process(target=target, args=tuple(args) + (self.message_queue, finished_queue)) # type: ignore

def start(self) -> None:
self._process.start()
@@ -37,8 +42,11 @@ def terminate(self) -> None:


class ProcessManager:
mp_context = 'fork'

def __init__(self) -> None:
self.finished_queue: multiprocessing.Queue = multiprocessing.Queue()
self.ctx = multiprocessing.get_context(self.mp_context)
self.finished_queue: multiprocessing.Queue = self.ctx.Queue()
self._loop = None

def get_event_loop(self):
@@ -47,8 +55,16 @@ def get_event_loop(self):
return self._loop

def create_process(self, args: Iterable[int | str | dict], **kwargs) -> ProcessProxy:
return ProcessProxy(args, self.finished_queue, **kwargs)
return ProcessProxy(args, self.finished_queue, ctx=self.ctx, **kwargs)

async def read_finished(self) -> dict[str, Union[str, int]]:
message = await self.get_event_loop().run_in_executor(None, self.finished_queue.get)
return message


class ForkServerManager(ProcessManager):
mp_context = 'forkserver'

def __init__(self, preload_modules: Optional[list[str]] = None):
super().__init__()
self.ctx.set_forkserver_preload(preload_modules if preload_modules else [])
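
For background on the mp_context class attribute: both start methods and set_forkserver_preload() are plain stdlib multiprocessing API, as in this stdlib-only sketch (the preloaded module name is illustrative):

    import multiprocessing

    fork_ctx = multiprocessing.get_context('fork')
    fs_ctx = multiprocessing.get_context('forkserver')

    # Queues and Processes should come from the same context family.
    queue = fs_ctx.Queue()

    # A hint listing modules the fork server imports before forking workers.
    fs_ctx.set_forkserver_preload(['dispatcher.worker.task'])
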
6 changes: 5 additions & 1 deletion schema.json
@@ -20,7 +20,11 @@
"service": {
"pool_kwargs": {
"max_workers": "<class 'int'>"
}
},
"process_manager_kwargs": {
"preload_modules": "typing.Optional[list[str]]"
},
"process_manager_cls": "typing.Literal['ProcessManager', 'ForkServerManager']"
},
"publish": {
"default_broker": "str"
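
The process_manager_cls entry above is just the stringified typing.Literal built in generate_settings_schema; for reference:

    from typing import Literal

    names = ('ProcessManager', 'ForkServerManager')
    print(str(Literal[names]))
    # typing.Literal['ProcessManager', 'ForkServerManager']
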
35 changes: 30 additions & 5 deletions tests/conftest.py
@@ -9,7 +9,7 @@
from dispatcher.service.main import DispatcherMain
from dispatcher.control import Control

from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection
from dispatcher.brokers.pg_notify import Broker, acreate_connection, connection_save
from dispatcher.registry import DispatcherMethodRegistry
from dispatcher.config import DispatcherSettings
from dispatcher.factories import from_settings, get_control_from_settings
@@ -56,6 +56,21 @@ async def aconnection_for_test():
await conn.close()


@pytest.fixture(autouse=True)
def clear_connection():
"""Always close connections between tests

Tests will do a lot of unthoughtful forking, and connections cannot
be shared across processes.
"""
if connection_save._connection:
connection_save._connection.close()
connection_save._connection = None
if connection_save._async_connection:
connection_save._async_connection.close()
connection_save._async_connection = None


@pytest.fixture
def conn_config():
return {'conninfo': CONNECTION_STRING}
@@ -73,18 +88,28 @@ def pg_dispatcher() -> DispatcherMain:
def test_settings():
return DispatcherSettings(BASIC_CONFIG)


@pytest_asyncio.fixture(loop_scope="function", scope="function")
async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]:
@pytest_asyncio.fixture(
loop_scope="function",
scope="function",
params=['ProcessManager', 'ForkServerManager'],
ids=["fork", "forkserver"],
)
async def apg_dispatcher(request) -> AsyncIterator[DispatcherMain]:
dispatcher = None
try:
dispatcher = from_settings(settings=test_settings)
this_test_config = BASIC_CONFIG.copy()
this_test_config.setdefault('service', {})
this_test_config['service']['process_manager_cls'] = request.param
this_settings = DispatcherSettings(this_test_config)
dispatcher = from_settings(settings=this_settings)

await dispatcher.connect_signals()
await dispatcher.start_working()
await dispatcher.wait_for_producers_ready()
await dispatcher.pool.events.workers_ready.wait()

assert dispatcher.pool.finished_count == 0 # sanity
assert dispatcher.control_count == 0

yield dispatcher
finally:
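
The fixture change uses ordinary pytest parametrization, so every test requesting apg_dispatcher now runs once per process-manager class. A stripped-down sketch of the mechanism (fixture and test names are illustrative):

    import pytest

    @pytest.fixture(params=['ProcessManager', 'ForkServerManager'], ids=['fork', 'forkserver'])
    def manager_name(request):
        # request.param carries the value for the current parametrized run
        return request.param

    def test_runs_once_per_manager(manager_name):
        assert manager_name in ('ProcessManager', 'ForkServerManager')
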
18 changes: 1 addition & 17 deletions tests/integration/test_main.py
@@ -24,8 +24,6 @@ async def wait_to_receive(dispatcher, ct, timeout=5.0, interval=0.05):

@pytest.mark.asyncio
async def test_run_lambda_function(apg_dispatcher, pg_message):
assert apg_dispatcher.pool.finished_count == 0

clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait(), name='test_lambda_clear_wait')
await pg_message('lambda: "This worked!"')
await asyncio.wait_for(clearing_task, timeout=3)
@@ -93,7 +91,7 @@ async def test_cancel_task(apg_dispatcher, pg_message, pg_control):
await pg_message(msg)

clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait())
await asyncio.sleep(0.04)
await asyncio.sleep(0.2)
canceled_jobs = await asyncio.wait_for(pg_control.acontrol_with_reply('cancel', data={'uuid': 'foobar'}, timeout=1), timeout=5)
worker_id, canceled_message = canceled_jobs[0][0]
assert canceled_message['uuid'] == 'foobar'
@@ -125,8 +123,6 @@ async def test_message_with_delay(apg_dispatcher, pg_message, pg_control):

@pytest.mark.asyncio
async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control):
assert apg_dispatcher.pool.finished_count == 0

# Send message to run task with a delay
msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 0.8})
await pg_message(msg)
@@ -146,8 +142,6 @@ async def test_cancel_delayed_task(apg_dispatcher, pg_message, pg_control):

@pytest.mark.asyncio
async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control):
assert apg_dispatcher.pool.finished_count == 0

# Send message to run task with a delay
msg = json.dumps({'task': 'lambda: print("This task should be canceled before start")', 'uuid': 'delay_task_will_cancel', 'delay': 2.0})
await pg_message(msg)
@@ -164,8 +158,6 @@ async def test_cancel_with_no_reply(apg_dispatcher, pg_message, pg_control):

@pytest.mark.asyncio
async def test_alive_check(apg_dispatcher, pg_control):
assert apg_dispatcher.control_count == 0

alive = await asyncio.wait_for(pg_control.acontrol_with_reply('alive', timeout=1), timeout=5)
assert alive == [None]

@@ -174,8 +166,6 @@ async def test_alive_check(apg_dispatcher, pg_control):

@pytest.mark.asyncio
async def test_task_discard(apg_dispatcher, pg_message):
assert apg_dispatcher.pool.finished_count == 0

messages = [
json.dumps(
{'task': 'lambda: __import__("time").sleep(9)', 'on_duplicate': 'discard', 'uuid': f'dscd-{i}'}
@@ -192,8 +182,6 @@ async def test_task_discard(apg_dispatcher, pg_message):

@pytest.mark.asyncio
async def test_task_discard_in_task_definition(apg_dispatcher, test_settings):
assert apg_dispatcher.pool.finished_count == 0

for i in range(10):
test_methods.sleep_discard.apply_async(args=[2], settings=test_settings)

@@ -205,8 +193,6 @@ async def test_task_discard_in_task_definition(apg_dispatcher, test_settings):

@pytest.mark.asyncio
async def test_tasks_in_serial(apg_dispatcher, test_settings):
assert apg_dispatcher.pool.finished_count == 0

for i in range(10):
test_methods.sleep_serial.apply_async(args=[2], settings=test_settings)

@@ -218,8 +204,6 @@ async def test_tasks_queue_one(apg_dispatcher, test_settings):

@pytest.mark.asyncio
async def test_tasks_queue_one(apg_dispatcher, test_settings):
assert apg_dispatcher.pool.finished_count == 0

for i in range(10):
test_methods.sleep_queue_one.apply_async(args=[2], settings=test_settings)

55 changes: 48 additions & 7 deletions tests/unit/service/test_process.py
@@ -1,6 +1,9 @@
from multiprocessing import Queue
import os

from dispatcher.service.process import ProcessManager, ProcessProxy
import pytest

from dispatcher.service.process import ProcessManager, ForkServerManager, ProcessProxy


def test_pass_messages_to_worker():
@@ -17,15 +20,53 @@ def work_loop(a, b, c, in_q, out_q):
assert msg == 'done 1 2 3 start'


def test_pass_messages_via_process_manager():
def work_loop(var, in_q, out_q):
has_read = in_q.get()
out_q.put(f'done {var} {has_read}')
def work_loop2(var, in_q, out_q):
"""
Due to the mechanics of forkserver, this cannot be defined as a local function;
it has to be importable, and this _is_ importable from the test module.
"""
has_read = in_q.get()
out_q.put(f'done {var} {has_read}')

process_manager = ProcessManager()
process = process_manager.create_process(('value',), target=work_loop)

@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
def test_pass_messages_via_process_manager(manager_cls):
process_manager = manager_cls()
process = process_manager.create_process(('value',), target=work_loop2)
process.start()

process.message_queue.put('msg1')
msg = process_manager.finished_queue.get()
assert msg == 'done value msg1'


@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
def test_workers_have_different_pid(manager_cls):
process_manager = manager_cls()
processes = [process_manager.create_process((f'value{i}',), target=work_loop2) for i in range(2)]

for i in range(2):
process = processes[i]
process.start()
process.message_queue.put(f'msg{i}')

assert processes[0].pid != processes[1].pid # title of test

msg1 = process_manager.finished_queue.get()
msg2 = process_manager.finished_queue.get()
assert set([msg1, msg2]) == set(['done value1 msg1', 'done value0 msg0'])


def return_pid(in_q, out_q):
out_q.put(f'{os.getpid()}')


@pytest.mark.parametrize('manager_cls', [ProcessManager, ForkServerManager])
def test_pid_is_correct(manager_cls):
process_manager = manager_cls()
process = process_manager.create_process((), target=return_pid)
process.start()

msg = process_manager.finished_queue.get()
assert int(msg) == process.pid
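
The importability requirement in work_loop2's docstring is a pickling constraint: forkserver (like spawn) pickles the target by qualified name, so locally defined functions fail. A minimal stdlib illustration (script layout assumed):

    import multiprocessing

    def importable_target(q):
        # Picklable: resolvable as <module>.importable_target
        q.put('ok')

    if __name__ == '__main__':
        ctx = multiprocessing.get_context('forkserver')
        q = ctx.Queue()

        ctx.Process(target=importable_target, args=(q,)).start()
        print(q.get())  # -> 'ok'

        def local_target(q):
            q.put('never')

        # Would raise at start(): AttributeError: Can't pickle local object
        # ctx.Process(target=local_target, args=(q,)).start()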