diff --git a/asgiref/current_thread_executor.py b/asgiref/current_thread_executor.py
index 67a7926f..1934e623 100644
--- a/asgiref/current_thread_executor.py
+++ b/asgiref/current_thread_executor.py
@@ -1,8 +1,8 @@
-import queue
 import sys
 import threading
+from collections import deque
 from concurrent.futures import Executor, Future
-from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, TypeVar
 
 if sys.version_info >= (3, 10):
     from typing import ParamSpec
@@ -53,10 +53,12 @@ class CurrentThreadExecutor(Executor):
     the thread they came from.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, old_executor: "CurrentThreadExecutor | None") -> None:
         self._work_thread = threading.current_thread()
-        self._work_queue: queue.Queue[Union[_WorkItem, "Future[Any]"]] = queue.Queue()
-        self._broken = False
+        self._work_ready = threading.Condition(threading.Lock())
+        self._work_items = deque[_WorkItem]()  # synchronized by _work_ready
+        self._broken = False  # synchronized by _work_ready
+        self._old_executor = old_executor
 
     def run_until_future(self, future: "Future[Any]") -> None:
         """
@@ -68,20 +70,25 @@ def run_until_future(self, future: "Future[Any]") -> None:
             raise RuntimeError(
                 "You cannot run CurrentThreadExecutor from a different thread"
             )
-        future.add_done_callback(self._work_queue.put)
-        # Keep getting and running work items until we get the future we're waiting for
-        # back via the future's done callback.
-        try:
-            while True:
+
+        def done(future: "Future[Any]") -> None:
+            with self._work_ready:
+                self._broken = True
+                self._work_ready.notify()
+
+        future.add_done_callback(done)
+        # Keep getting and running work items until the future we're waiting for
+        # is done and the queue is empty.
+        while True:
+            with self._work_ready:
+                while not self._work_items and not self._broken:
+                    self._work_ready.wait()
+                if not self._work_items:
+                    break
                 # Get a work item and run it
-                work_item = self._work_queue.get()
-                if work_item is future:
-                    return
-                assert isinstance(work_item, _WorkItem)
-                work_item.run()
-                del work_item
-        finally:
-            self._broken = True
+                work_item = self._work_items.popleft()
+            work_item.run()
+            del work_item
 
     def _submit(
         self,
@@ -94,13 +101,23 @@ def _submit(
             raise RuntimeError(
                 "You cannot submit onto CurrentThreadExecutor from its own thread"
            )
-        # Check they're not too late or the executor errored
-        if self._broken:
-            raise RuntimeError("CurrentThreadExecutor already quit or is broken")
-        # Add to work queue
         f: "Future[_R]" = Future()
         work_item = _WorkItem(f, fn, *args, **kwargs)
-        self._work_queue.put(work_item)
+
+        # Walk up the CurrentThreadExecutor stack to find the closest one still
+        # running
+        executor = self
+        while True:
+            with executor._work_ready:
+                if not executor._broken:
+                    # Add to work queue
+                    executor._work_items.append(work_item)
+                    executor._work_ready.notify()
+                    break
+            if executor._old_executor is None:
+                raise RuntimeError("CurrentThreadExecutor already quit or is broken")
+            executor = executor._old_executor
+
         # Return the future
         return f
 
diff --git a/asgiref/sync.py b/asgiref/sync.py
index 377075d1..4c3a48b9 100644
--- a/asgiref/sync.py
+++ b/asgiref/sync.py
@@ -196,7 +196,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
         # need one for every sync frame, even if there's one above us in the
         # same thread.
         old_executor = getattr(self.executors, "current", None)
-        current_executor = CurrentThreadExecutor()
+        current_executor = CurrentThreadExecutor(old_executor)
         self.executors.current = current_executor
 
         # Wrapping context in list so it can be reassigned from within
diff --git a/tests/test_sync.py b/tests/test_sync.py
index 0c67308c..6226770c 100644
--- a/tests/test_sync.py
+++ b/tests/test_sync.py
@@ -1208,3 +1208,76 @@ def test_function(**kwargs: Any) -> None:
 
     # SyncToAsync.__call__.loop.run_in_executor has a param named `task_context`.
     await test_function(task_context=1)
+
+
+def test_nested_task() -> None:
+    async def inner() -> asyncio.Task[None]:
+        return asyncio.create_task(sync_to_async(print)("inner"))
+
+    async def main() -> None:
+        task = await sync_to_async(async_to_sync(inner))()
+        await task
+
+    async_to_sync(main)()
+
+
+def test_nested_task_later() -> None:
+    def later(fut: asyncio.Future[asyncio.Task[None]]) -> None:
+        task = asyncio.create_task(sync_to_async(print)("later"))
+        fut.set_result(task)
+
+    async def inner() -> asyncio.Future[asyncio.Task[None]]:
+        loop = asyncio.get_running_loop()
+        fut = loop.create_future()
+        loop.call_later(0.1, later, fut)
+        return fut
+
+    async def main() -> None:
+        fut = await sync_to_async(async_to_sync(inner))()
+        task = await fut
+        await task
+
+    async_to_sync(main)()
+
+
+def test_double_nested_task() -> None:
+    async def inner() -> asyncio.Task[None]:
+        return asyncio.create_task(sync_to_async(print)("inner"))
+
+    async def outer() -> asyncio.Task[asyncio.Task[None]]:
+        return asyncio.create_task(sync_to_async(async_to_sync(inner))())
+
+    async def main() -> None:
+        outer_task = await sync_to_async(async_to_sync(outer))()
+        inner_task = await outer_task
+        await inner_task
+
+    async_to_sync(main)()
+
+
+# asyncio.Barrier is new in Python 3.11. Nest definition (rather than using
+# skipIf) to avoid mypy error.
+if sys.version_info >= (3, 11):
+
+    def test_two_nested_tasks_with_asyncio_run() -> None:
+        barrier = asyncio.Barrier(3)
+        event = threading.Event()
+
+        async def inner() -> None:
+            task = asyncio.create_task(sync_to_async(event.wait)())
+            await barrier.wait()
+            await task
+
+        async def outer() -> tuple[asyncio.Task[None], asyncio.Task[None]]:
+            task0 = asyncio.create_task(inner())
+            task1 = asyncio.create_task(inner())
+            await barrier.wait()
+            event.set()
+            return task0, task1
+
+        async def main() -> None:
+            task0, task1 = await sync_to_async(async_to_sync(outer))()
+            await task0
+            await task1
+
+        asyncio.run(main())
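
Note on the approach (not part of the patch above): the diff replaces CurrentThreadExecutor's queue.Queue with a deque guarded by a threading.Condition, and gives each executor a reference to the executor that was current before it. run_until_future() now drains the deque until the awaited future is done and the deque is empty, while _submit() walks up the chain of old executors so that work arriving after a frame has finished (for example a task spawned from a nested sync_to_async call) lands on the closest executor that is still running instead of raising. The sketch below is a simplified, hypothetical model of that hand-off pattern; MiniExecutor, run_until_finished and finish are illustrative names, not asgiref API.

import threading
from collections import deque
from typing import Callable, Optional


class MiniExecutor:
    def __init__(self, parent: "Optional[MiniExecutor]" = None) -> None:
        self._ready = threading.Condition()  # guards _items and _finished
        self._items: deque[Callable[[], None]] = deque()
        self._finished = False
        self._parent = parent

    def submit(self, fn: Callable[[], None]) -> None:
        # Walk up the chain until we find an executor that has not finished yet.
        executor = self
        while True:
            with executor._ready:
                if not executor._finished:
                    executor._items.append(fn)
                    executor._ready.notify()
                    return
            if executor._parent is None:
                raise RuntimeError("every executor in the chain has already quit")
            executor = executor._parent

    def run_until_finished(self) -> None:
        # Drain the queue; exit only once finish() was called *and* no work is left.
        while True:
            with self._ready:
                while not self._items and not self._finished:
                    self._ready.wait()
                if not self._items:
                    return
                fn = self._items.popleft()
            fn()  # run outside the lock, like _WorkItem.run() in the patch

    def finish(self) -> None:
        with self._ready:
            self._finished = True
            self._ready.notify()


if __name__ == "__main__":
    outer = MiniExecutor()
    inner = MiniExecutor(parent=outer)
    inner.finish()  # the inner frame has already quit...
    inner.submit(lambda: print("fell back to the outer executor"))  # ...so this falls through
    outer.finish()
    outer.run_until_finished()  # prints: fell back to the outer executor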