Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dispatcherd/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def __iter__(self) -> Iterator[PoolWorker]: ...

def get_by_id(self, worker_id: int) -> PoolWorker: ...

def move_to_end(self, worker_id: int) -> None: ...


class SharedAsyncObjects:
exit_event: asyncio.Event
Expand Down
10 changes: 9 additions & 1 deletion dispatcherd/service/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import signal
import time
from collections import OrderedDict
from typing import Any, Iterator, Literal, Optional

from ..processors.blocker import Blocker
Expand Down Expand Up @@ -173,7 +174,7 @@ def __init__(self) -> None:

class WorkerData(WorkerDataProtocol):
def __init__(self) -> None:
self.workers: dict[int, PoolWorker] = {}
self.workers: OrderedDict[int, PoolWorker] = OrderedDict()
self.management_lock = asyncio.Lock()

def __iter__(self) -> Iterator[PoolWorker]:
Expand All @@ -194,6 +195,12 @@ def get_by_id(self, worker_id: int) -> PoolWorker:
def remove_by_id(self, worker_id: int) -> None:
del self.workers[worker_id]

def move_to_end(self, worker_id: int) -> None:
try:
self.workers.move_to_end(worker_id)
except KeyError:
logger.warning(f'Attempted to move worker_id={worker_id} to end, but worker was already removed from workers dict')


class WorkerPool(WorkerPoolProtocol):
def __init__(
Expand Down Expand Up @@ -559,6 +566,7 @@ async def process_finished(self, worker: PoolWorker, message: dict) -> None:
else:
self.finished_count += 1
worker.mark_finished_task()
self.workers.move_to_end(worker.worker_id)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Race Condition in Worker Management

A KeyError can occur at line 566 when mark_worker_done() calls self.workers.move_to_end(). This is a race condition where manage_old_workers() removes a worker from self.workers after it's been retrieved for processing, but before move_to_end() is called, even though the call is within the management_lock.

Fix in Cursor Fix in Web


if not self.queuer.queued_messages and all(worker.current_task is None for worker in self.workers):
self.events.work_cleared.set()
Expand Down
104 changes: 104 additions & 0 deletions tests/unit/service/test_worker_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import asyncio
from typing import AsyncIterator

import pytest
import pytest_asyncio

from dispatcherd.protocols import DispatcherMain
from dispatcherd.testing.asyncio import adispatcher_service


@pytest.fixture(scope='session')
def order_config():
return {
"version": 2,
"service": {
"pool_kwargs": {"min_workers": 2, "max_workers": 2},
"main_kwargs": {"node_id": "order-test"},
},
}


@pytest_asyncio.fixture
async def aorder_dispatcher(order_config) -> AsyncIterator[DispatcherMain]:
async with adispatcher_service(order_config) as dispatcher:
yield dispatcher


@pytest.mark.asyncio
async def test_workers_reorder_and_dispatch_longest_idle(aorder_dispatcher):
pool = aorder_dispatcher.pool
assert list(pool.workers.workers.keys()) == [0, 1]

pool.events.work_cleared.clear()
await aorder_dispatcher.process_message({
"task": "tests.data.methods.sleep_function",
"kwargs": {"seconds": 0.1},
"uuid": "t1",
})
await aorder_dispatcher.process_message({
"task": "tests.data.methods.sleep_function",
"kwargs": {"seconds": 0.05},
"uuid": "t2",
})
await asyncio.wait_for(pool.events.work_cleared.wait(), timeout=1)

assert list(pool.workers.workers.keys()) == [1, 0]

pool.events.work_cleared.clear()
await aorder_dispatcher.process_message({
"task": "tests.data.methods.sleep_function",
"kwargs": {"seconds": 0.01},
"uuid": "t3",
})
await asyncio.sleep(0.01)
assert pool.workers.get_by_id(1).current_task["uuid"] == "t3"
assert pool.workers.get_by_id(0).current_task is None
await asyncio.wait_for(pool.events.work_cleared.wait(), timeout=1)

pool.events.work_cleared.clear()
await aorder_dispatcher.process_message({
"task": "tests.data.methods.sleep_function",
"kwargs": {"seconds": 0.01},
"uuid": "t4",
})
await asyncio.sleep(0.01)
assert pool.workers.get_by_id(0).current_task["uuid"] == "t4"
await asyncio.wait_for(pool.events.work_cleared.wait(), timeout=1)

assert list(pool.workers.workers.keys()) == [1, 0]


@pytest.mark.asyncio
async def test_process_finished_with_removed_worker():
"""Test that process_finished handles KeyError gracefully when worker has been removed

This simulates the race condition where a worker finishes a task, but has been
removed from self.workers before process_finished is called.
"""
from unittest.mock import MagicMock, patch
from dispatcherd.service.pool import WorkerData, PoolWorker

# Create a minimal WorkerData instance
worker_data = WorkerData()

# Create a mock worker
mock_process = MagicMock()
worker = PoolWorker(worker_id=0, process=mock_process)
worker.current_task = {"uuid": "test-uuid", "task": "test.task"}
worker.finished_count = 0

# Add worker then remove it (simulating the race condition)
worker_data.add_worker(worker)
worker_data.remove_by_id(0)
assert 0 not in worker_data

# Mock the logger to verify the warning is logged
with patch('dispatcherd.service.pool.logger') as mock_logger:
# Call move_to_end on the removed worker - should not raise KeyError
worker_data.move_to_end(0)

# Verify the warning was logged
mock_logger.warning.assert_called_once()
warning_call = mock_logger.warning.call_args[0][0]
assert "Attempted to move worker_id=0 to end, but worker was already removed" in warning_call