diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index ec52ee4c..819ceca9 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -16,6 +16,7 @@ from collections.abc import ( AsyncIterator, Awaitable, + Coroutine as CoroutineT, Generator, Iterable, Iterator, @@ -276,6 +277,69 @@ def _fixture_synchronizer( AsyncGenFixtureYieldType = TypeVar("AsyncGenFixtureYieldType") +def _create_task_in_context( + coro: CoroutineT[Any, Any, Any], + loop: AbstractEventLoop, + context: contextvars.Context, +) -> asyncio.Task[Any]: + if sys.version_info >= (3, 11): + return loop.create_task(coro, context=context) + + from backports.asyncio.runner._patch import _patch_object + from backports.asyncio.runner.tasks import Task + + with ( + _patch_object(asyncio.tasks, asyncio.tasks.Task.__name__, Task), + _patch_object(contextvars, contextvars.copy_context.__name__, lambda: context), + ): + return loop.create_task(coro) + + +class _FixtureRunner: + def __init__(self, loop: AbstractEventLoop, context: contextvars.Context) -> None: + self.loop = loop + self.queue: asyncio.Queue[tuple[Awaitable[Any], asyncio.Future[Any]] | None] = ( + asyncio.Queue() + ) + self._context = context + self._task = None + + async def _worker(self) -> None: + while True: + item = await self.queue.get() + if item is None: + break + coro, future = item + try: + retval = await coro + future.set_result(retval) + except Exception as exc: + future.set_exception(exc) + + def run(self, func): + return self.loop.run_until_complete(self._run(func)) + + async def _run(self, func): + if self._task is None: + self._task = _create_task_in_context( + self._worker(), loop=self.loop, context=self._context + ) + + coro = func() + future = self.loop.create_future() + self.queue.put_nowait((coro, future)) + return await future + + async def _stop(self): + self.queue.put_nowait(None) + if self._task is not None: + await self._task + self._task = None + + def stop(self) -> None: + self.loop.run_until_complete(self._stop()) + + def _wrap_asyncgen_fixture( fixture_function: Callable[ AsyncGenFixtureParams, AsyncGeneratorType[AsyncGenFixtureYieldType, Any] @@ -295,7 +359,8 @@ async def setup(): return res context = contextvars.copy_context() - result = runner.run(setup(), context=context) + fixture_runner = _FixtureRunner(loop=runner.get_loop(), context=context) + result = fixture_runner.run(setup) reset_contextvars = _apply_contextvar_changes(context) @@ -312,7 +377,8 @@ async def async_finalizer() -> None: msg += "Yield only once." raise ValueError(msg) - runner.run(async_finalizer(), context=context) + fixture_runner.run(async_finalizer) + fixture_runner.stop() if reset_contextvars is not None: reset_contextvars() diff --git a/tests/async_fixtures/test_async_fixtures.py b/tests/async_fixtures/test_async_fixtures.py index 16478539..b8ccf190 100644 --- a/tests/async_fixtures/test_async_fixtures.py +++ b/tests/async_fixtures/test_async_fixtures.py @@ -37,3 +37,14 @@ async def async_fixture_method(self): @pytest.mark.asyncio async def test_async_fixture_method(self): assert self.is_same_instance + + +@pytest.fixture() +async def setup_and_teardown_tasks(): + task = asyncio.current_task() + yield + assert task is asyncio.current_task() + + +async def test_setup_and_teardown_tasks(setup_and_teardown_tasks): + pass