Skip to content

Commit 7c104c7

Browse files
committed
Add test for worker ordering
1 parent 399c14b commit 7c104c7

File tree

3 files changed

+77
-1
lines changed

3 files changed

+77
-1
lines changed

dispatcherd/protocols.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ def __iter__(self) -> Iterator[PoolWorker]: ...
210210

211211
def get_by_id(self, worker_id: int) -> PoolWorker: ...
212212

213+
def move_to_end(self, worker_id: int) -> None: ...
214+
213215

214216
class SharedAsyncObjects:
215217
exit_event: asyncio.Event

dispatcherd/service/pool.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import signal
66
import time
77
from typing import Any, Iterator, Literal, Optional
8+
from collections import OrderedDict
89

910
from ..processors.blocker import Blocker
1011
from ..processors.queuer import Queuer
@@ -173,7 +174,7 @@ def __init__(self) -> None:
173174

174175
class WorkerData(WorkerDataProtocol):
175176
def __init__(self) -> None:
176-
self.workers: dict[int, PoolWorker] = {}
177+
self.workers: OrderedDict[int, PoolWorker] = OrderedDict()
177178
self.management_lock = asyncio.Lock()
178179

179180
def __iter__(self) -> Iterator[PoolWorker]:
@@ -194,6 +195,9 @@ def get_by_id(self, worker_id: int) -> PoolWorker:
194195
def remove_by_id(self, worker_id: int) -> None:
195196
del self.workers[worker_id]
196197

198+
def move_to_end(self, worker_id: int) -> None:
199+
self.workers.move_to_end(worker_id)
200+
197201

198202
class WorkerPool(WorkerPoolProtocol):
199203
def __init__(
@@ -559,6 +563,7 @@ async def process_finished(self, worker: PoolWorker, message: dict) -> None:
559563
else:
560564
self.finished_count += 1
561565
worker.mark_finished_task()
566+
self.workers.move_to_end(worker.worker_id)
562567

563568
if not self.queuer.queued_messages and all(worker.current_task is None for worker in self.workers):
564569
self.events.work_cleared.set()
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import asyncio
2+
from typing import AsyncIterator
3+
4+
import pytest
5+
import pytest_asyncio
6+
7+
from dispatcherd.protocols import DispatcherMain
8+
from dispatcherd.testing.asyncio import adispatcher_service
9+
10+
11+
@pytest.fixture(scope='session')
12+
def order_config():
13+
return {
14+
"version": 2,
15+
"service": {
16+
"pool_kwargs": {"min_workers": 2, "max_workers": 2},
17+
"main_kwargs": {"node_id": "order-test"},
18+
},
19+
}
20+
21+
22+
@pytest_asyncio.fixture
23+
async def aorder_dispatcher(order_config) -> AsyncIterator[DispatcherMain]:
24+
async with adispatcher_service(order_config) as dispatcher:
25+
yield dispatcher
26+
27+
28+
@pytest.mark.asyncio
29+
async def test_workers_reorder_and_dispatch_longest_idle(aorder_dispatcher):
30+
pool = aorder_dispatcher.pool
31+
assert list(pool.workers.workers.keys()) == [0, 1]
32+
33+
pool.events.work_cleared.clear()
34+
await aorder_dispatcher.process_message({
35+
"task": "tests.data.methods.sleep_function",
36+
"kwargs": {"seconds": 0.1},
37+
"uuid": "t1",
38+
})
39+
await aorder_dispatcher.process_message({
40+
"task": "tests.data.methods.sleep_function",
41+
"kwargs": {"seconds": 0.05},
42+
"uuid": "t2",
43+
})
44+
await asyncio.wait_for(pool.events.work_cleared.wait(), timeout=1)
45+
46+
assert list(pool.workers.workers.keys()) == [1, 0]
47+
48+
pool.events.work_cleared.clear()
49+
await aorder_dispatcher.process_message({
50+
"task": "tests.data.methods.sleep_function",
51+
"kwargs": {"seconds": 0.01},
52+
"uuid": "t3",
53+
})
54+
await asyncio.sleep(0.01)
55+
assert pool.workers.get_by_id(1).current_task["uuid"] == "t3"
56+
assert pool.workers.get_by_id(0).current_task is None
57+
await asyncio.wait_for(pool.events.work_cleared.wait(), timeout=1)
58+
59+
pool.events.work_cleared.clear()
60+
await aorder_dispatcher.process_message({
61+
"task": "tests.data.methods.sleep_function",
62+
"kwargs": {"seconds": 0.01},
63+
"uuid": "t4",
64+
})
65+
await asyncio.sleep(0.01)
66+
assert pool.workers.get_by_id(0).current_task["uuid"] == "t4"
67+
await asyncio.wait_for(pool.events.work_cleared.wait(), timeout=1)
68+
69+
assert list(pool.workers.workers.keys()) == [1, 0]

0 commit comments

Comments
 (0)