class Marker(enum.Enum):
    """Sentinel values that can be placed in the worker's test queue."""

    # Ends the worker's run loop when it is read from the queue.
    SHUTDOWN = 0


class TestQueue:
    """A simple queue that can be inspected and modified while the lock is held via the ``lock()`` method.

    Items are kept in a ``collections.deque``. An event mirrors whether the
    deque currently holds anything, which lets ``get()`` block without
    keeping the lock while it waits.
    """

    # Despite the ``Test`` prefix this is not a test class: keep pytest from
    # trying to collect it (and warning about ``__init__``) when the class is
    # imported into a test module.
    __test__ = False

    Item = Union[int, Literal[Marker.SHUTDOWN]]

    def __init__(self, execmodel: "execnet.gateway_base.ExecModel") -> None:
        """Create an empty queue using the lock and event types of ``execmodel``."""
        self._items: collections.deque[TestQueue.Item] = collections.deque()
        # Re-entrant, so that ``put()``/``replace()`` may be called while the
        # caller already holds the lock through ``lock()``.
        self._lock = execmodel.RLock()  # type: ignore[no-untyped-call]
        self._has_items_event = execmodel.Event()

    def get(self) -> Item:
        """Remove and return the oldest item, blocking while the queue is empty."""
        while True:
            with self.lock() as locked_items:
                if locked_items:
                    return locked_items.popleft()

            # Wait outside of the lock so that other threads can add items.
            self._has_items_event.wait()

    def put(self, item: Item) -> None:
        """Append ``item`` to the end of the queue."""
        with self.lock() as locked_items:
            locked_items.append(item)

    def replace(self, iterable: Iterable[Item]) -> None:
        """Replace the whole content of the queue with the items of ``iterable``.

        The underlying deque is updated in place, so a reference obtained
        from an enclosing ``lock()`` block keeps pointing at the live queue.
        """
        with self.lock() as locked_items:
            # Materialize first: ``iterable`` may be a generator that reads
            # from the very deque that is about to be cleared.
            new_items = list(iterable)
            locked_items.clear()
            locked_items.extend(new_items)

    @contextlib.contextmanager
    def lock(self) -> Generator[collections.deque[Item], None, None]:
        """Hold the queue lock and yield the underlying deque.

        The deque may be read and modified freely inside the ``with`` block.
        On exit (including on error) the has-items event is set or cleared to
        match whether the queue is non-empty.
        """
        with self._lock:
            try:
                yield self._items
            finally:
                if self._items:
                    self._has_items_event.set()
                else:
                    self._has_items_event.clear()
def steal(self, indices: Sequence[int]) -> None:
    """
    Remove tests from the queue.

    Removes either all requested tests, or none, if some of these tests
    are not in the queue (for example, if they were processed already).
    An ``unscheduled`` event is always sent afterwards; its ``indices``
    lists the removed tests in queue order, or is empty if nothing was
    removed.

    :param indices: indices of the tests to remove.
    """
    requested_set = set(indices)

    # Inspect and rewrite the queue under a single lock acquisition, so the
    # check and the removal cannot be interleaved with the run loop.
    with self.torun.lock() as locked_queue:
        stolen = [item for item in locked_queue if item in requested_set]

        # Stealing only if all requested tests are still pending
        if len(stolen) == len(requested_set):
            self.torun.replace(
                item for item in locked_queue if item not in requested_set
            )
        else:
            stolen = []

    self.sendevent("unscheduled", indices=stolen)