# Patch summary: keep pool workers ordered by how recently they finished a task.
#
# dispatcherd/protocols.py
#   - Adds a `move_to_end(self, worker_id: int) -> None` stub next to `get_by_id`
#     on the protocol class that declares `__iter__` / `get_by_id` (the class
#     name itself is not visible in the hunk).
#
# dispatcherd/service/pool.py
#   - Imports `OrderedDict` from `collections`.
#   - `WorkerData.__init__`: `self.workers` changes from a plain `dict` to an
#     `OrderedDict[int, PoolWorker]`.
#   - Adds `move_to_end(worker_id)` after `remove_by_id`: it calls
#     `self.workers.move_to_end(worker_id)` and, on `KeyError`, logs a warning
#     instead of raising.
#   - `process_finished`: calls `self.workers.move_to_end(worker.worker_id)`
#     immediately after `worker.mark_finished_task()`.
#
# tests/unit/service/test_worker_order.py (new file, 104 added lines)
#   - `test_workers_reorder_and_dispatch_longest_idle`: with min_workers and
#     max_workers both 2, asserts the key order of `pool.workers.workers` is
#     [0, 1] at start and [1, 0] after t1 (0.1 s) and t2 (0.05 s) clear, that
#     t3 is then assigned to worker 1, t4 to worker 0, and that the final key
#     order is [1, 0].
#   - `test_process_finished_with_removed_worker`: adds and removes worker 0
#     on a bare `WorkerData`, calls `move_to_end(0)` with the module logger
#     patched, and asserts exactly one warning containing the expected text.
#
# NOTE(review): the t3/t4 checks sleep 0.01 s while the dispatched task also
#   sleeps 0.01 s, then index `current_task["uuid"]`. This looks
#   timing-sensitive: presumably `current_task` is None once the task has
#   finished -- confirm, or assert before the task can complete.
# NOTE(review): the second test's name and docstring refer to
#   `process_finished`, but the body only exercises `WorkerData.move_to_end`;
#   the `worker.current_task` / `worker.finished_count` assignments are not
#   read by the test itself.
# NOTE(review): the line breaks of this copy of the patch are collapsed, so
#   blank context lines cannot be recovered from it. The patch text below is
#   deliberately left byte-for-byte unchanged; hunk headers encode line counts
#   and must stay in sync with any edit to the added lines.
diff --git a/dispatcherd/protocols.py b/dispatcherd/protocols.py index fd0c792..f463a73 100644 --- a/dispatcherd/protocols.py +++ b/dispatcherd/protocols.py @@ -210,6 +210,8 @@ def __iter__(self) -> Iterator[PoolWorker]: ... def get_by_id(self, worker_id: int) -> PoolWorker: ... + def move_to_end(self, worker_id: int) -> None: ... + class SharedAsyncObjects: exit_event: asyncio.Event diff --git a/dispatcherd/service/pool.py b/dispatcherd/service/pool.py index bf1b760..542b616 100644 --- a/dispatcherd/service/pool.py +++ b/dispatcherd/service/pool.py @@ -4,6 +4,7 @@ import os import signal import time +from collections import OrderedDict from typing import Any, Iterator, Literal, Optional from ..processors.blocker import Blocker @@ -173,7 +174,7 @@ def __init__(self) -> None: class WorkerData(WorkerDataProtocol): def __init__(self) -> None: - self.workers: dict[int, PoolWorker] = {} + self.workers: OrderedDict[int, PoolWorker] = OrderedDict() self.management_lock = asyncio.Lock() def __iter__(self) -> Iterator[PoolWorker]: @@ -194,6 +195,12 @@ def get_by_id(self, worker_id: int) -> PoolWorker: def remove_by_id(self, worker_id: int) -> None: del self.workers[worker_id] + def move_to_end(self, worker_id: int) -> None: + try: + self.workers.move_to_end(worker_id) + except KeyError: + logger.warning(f'Attempted to move worker_id={worker_id} to end, but worker was already removed from workers dict') + class WorkerPool(WorkerPoolProtocol): def __init__( @@ -559,6 +566,7 @@ async def process_finished(self, worker: PoolWorker, message: dict) -> None: else: self.finished_count += 1 worker.mark_finished_task() + self.workers.move_to_end(worker.worker_id) if not self.queuer.queued_messages and all(worker.current_task is None for worker in self.workers): self.events.work_cleared.set() diff --git a/tests/unit/service/test_worker_order.py b/tests/unit/service/test_worker_order.py new file mode 100644 index 0000000..95058a7 --- /dev/null +++ 
b/tests/unit/service/test_worker_order.py @@ -0,0 +1,104 @@ +import asyncio +from typing import AsyncIterator + +import pytest +import pytest_asyncio + +from dispatcherd.protocols import DispatcherMain +from dispatcherd.testing.asyncio import adispatcher_service + + +@pytest.fixture(scope='session') +def order_config(): + return { + "version": 2, + "service": { + "pool_kwargs": {"min_workers": 2, "max_workers": 2}, + "main_kwargs": {"node_id": "order-test"}, + }, + } + + +@pytest_asyncio.fixture +async def aorder_dispatcher(order_config) -> AsyncIterator[DispatcherMain]: + async with adispatcher_service(order_config) as dispatcher: + yield dispatcher + + +@pytest.mark.asyncio +async def test_workers_reorder_and_dispatch_longest_idle(aorder_dispatcher): + pool = aorder_dispatcher.pool + assert list(pool.workers.workers.keys()) == [0, 1] + + pool.events.work_cleared.clear() + await aorder_dispatcher.process_message({ + "task": "tests.data.methods.sleep_function", + "kwargs": {"seconds": 0.1}, + "uuid": "t1", + }) + await aorder_dispatcher.process_message({ + "task": "tests.data.methods.sleep_function", + "kwargs": {"seconds": 0.05}, + "uuid": "t2", + }) + await asyncio.wait_for(pool.events.work_cleared.wait(), timeout=1) + + assert list(pool.workers.workers.keys()) == [1, 0] + + pool.events.work_cleared.clear() + await aorder_dispatcher.process_message({ + "task": "tests.data.methods.sleep_function", + "kwargs": {"seconds": 0.01}, + "uuid": "t3", + }) + await asyncio.sleep(0.01) + assert pool.workers.get_by_id(1).current_task["uuid"] == "t3" + assert pool.workers.get_by_id(0).current_task is None + await asyncio.wait_for(pool.events.work_cleared.wait(), timeout=1) + + pool.events.work_cleared.clear() + await aorder_dispatcher.process_message({ + "task": "tests.data.methods.sleep_function", + "kwargs": {"seconds": 0.01}, + "uuid": "t4", + }) + await asyncio.sleep(0.01) + assert pool.workers.get_by_id(0).current_task["uuid"] == "t4" + await 
asyncio.wait_for(pool.events.work_cleared.wait(), timeout=1) + + assert list(pool.workers.workers.keys()) == [1, 0] + + +@pytest.mark.asyncio +async def test_process_finished_with_removed_worker(): + """Test that process_finished handles KeyError gracefully when worker has been removed + + This simulates the race condition where a worker finishes a task, but has been + removed from self.workers before process_finished is called. + """ + from unittest.mock import MagicMock, patch + from dispatcherd.service.pool import WorkerData, PoolWorker + + # Create a minimal WorkerData instance + worker_data = WorkerData() + + # Create a mock worker + mock_process = MagicMock() + worker = PoolWorker(worker_id=0, process=mock_process) + worker.current_task = {"uuid": "test-uuid", "task": "test.task"} + worker.finished_count = 0 + + # Add worker then remove it (simulating the race condition) + worker_data.add_worker(worker) + worker_data.remove_by_id(0) + assert 0 not in worker_data + + # Mock the logger to verify the warning is logged + with patch('dispatcherd.service.pool.logger') as mock_logger: + # Call move_to_end on the removed worker - should not raise KeyError + worker_data.move_to_end(0) + + # Verify the warning was logged + mock_logger.warning.assert_called_once() + warning_call = mock_logger.warning.call_args[0][0] + assert "Attempted to move worker_id=0 to end, but worker was already removed" in warning_call