diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py
index 4be7112fa017d4..537ee08e055d98 100644
--- a/Lib/asyncio/__init__.py
+++ b/Lib/asyncio/__init__.py
@@ -9,6 +9,7 @@
 from .coroutines import *
 from .events import *
 from .exceptions import *
+from .executor import *
 from .futures import *
 from .graph import *
 from .locks import *
@@ -27,6 +28,7 @@
     coroutines.__all__ +
     events.__all__ +
     exceptions.__all__ +
+    executor.__all__ +
     futures.__all__ +
     graph.__all__ +
     locks.__all__ +
diff --git a/Lib/asyncio/executor.py b/Lib/asyncio/executor.py
new file mode 100644
index 00000000000000..5b48758e18b058
--- /dev/null
+++ b/Lib/asyncio/executor.py
@@ -0,0 +1,226 @@
+import time
+from collections.abc import AsyncIterable, Awaitable, Iterable
+from dataclasses import dataclass
+from typing import Any, Protocol
+
+from . import timeouts
+from .exceptions import CancelledError
+from .futures import Future
+from .locks import Event
+from .queues import Queue, QueueShutDown
+from .tasks import FIRST_COMPLETED, Task, create_task, gather, wait
+
+__all__ = (
+    "Executor",
+)
+
+
+class _WorkFunction[**P, R](Protocol):
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
+        ...
+
+
+@dataclass(frozen=True, slots=True)
+class _WorkItem[**P, R]:
+    fn: _WorkFunction[P, R]
+    args: tuple[Any, ...]
+    kwargs: dict[str, Any]
+    future: Future[R]
+
+
+async def _azip(*iterables: Iterable | AsyncIterable) -> AsyncIterable[tuple]:
+    def _as_async_iterable[T](
+        iterable: Iterable[T] | AsyncIterable[T],
+    ) -> AsyncIterable[T]:
+        async def _to_async_iterable(
+            iterable: Iterable[T],
+        ) -> AsyncIterable[T]:
+            for item in iterable:
+                yield item
+
+        if isinstance(iterable, AsyncIterable):
+            return iterable
+        return _to_async_iterable(iterable)
+
+    async_iterables = [_as_async_iterable(iterable) for iterable in iterables]
+    iterators = [aiter(async_iterable) for async_iterable in async_iterables]
+    while True:
+        try:
+            items = [await anext(iterator) for iterator in iterators]
+            yield tuple(items)
+        except StopAsyncIteration:
+            break
+
+
+async def _consume_cancelled_future(future):
+    try:
+        await future
+    except CancelledError:
+        pass
+
+
+class Executor[**P, R]:
+    _input_queue: Queue[_WorkItem[P, R]]
+    _workers: list[Task]
+    _feeders: set[Task]
+    _shutdown: bool = False
+
+    def __init__(self, max_workers: int) -> None:
+        if max_workers <= 0:
+            raise ValueError("max_workers must be greater than 0")
+
+        self._input_queue = Queue(max_workers)
+        self._workers = [
+            create_task(self._worker())
+            for _ in range(max_workers)
+        ]
+        self._feeders = set()
+
+    async def submit(
+        self,
+        fn: _WorkFunction[P, R],
+        /,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Future[R]:
+        if self._shutdown:
+            raise RuntimeError("Cannot schedule new tasks after shutdown")
+
+        future = Future()
+        work_item = _WorkItem(fn, args, kwargs, future)
+        await self._input_queue.put(work_item)
+
+        return future
+
+    async def map(
+        self,
+        fn: _WorkFunction[P, R],
+        *iterables: Iterable | AsyncIterable,
+        timeout: float | None = None,
+    ) -> AsyncIterable[R]:
+        if self._shutdown:
+            raise RuntimeError("Cannot schedule new tasks after shutdown")
+
+        end_time = None if timeout is None else time.monotonic() + timeout
+
+        inputs_stream = _azip(*iterables)
+        submitted_tasks = Queue[Future[R]]()
+        tasks_in_flight_limit = len(self._workers) + self._input_queue.maxsize
+        resume_feeding = Event()
+
+        feeder_task = create_task(self._feeder(
+            inputs_stream,
+            fn,
+            submitted_tasks,
+            tasks_in_flight_limit,
+            resume_feeding,
+        ))
+        self._feeders.add(feeder_task)
+        feeder_task.add_done_callback(self._feeders.remove)
+
+        try:
+            while True:
+                task = await submitted_tasks.get()
+
+                remaining_time = (
+                    None if end_time is None else end_time - time.monotonic()
+                )
+                if remaining_time is not None and remaining_time <= 0:
+                    raise TimeoutError()
+
+                # Only the await is guarded by the timeout; yielding while
+                # the timeout is armed would let it fire into consumer code.
+                async with timeouts.timeout(remaining_time):
+                    result = await task
+                yield result
+                resume_feeding.set()
+        except QueueShutDown:
+            # The executor was shut down while map was running.
+            pass
+        finally:
+            feeder_task.cancel()
+            await _consume_cancelled_future(feeder_task)
+
+            finalization_tasks = []
+            while submitted_tasks.qsize() > 0:
+                task = submitted_tasks.get_nowait()
+                task.cancel()
+                finalization_tasks.append(task)
+            for task in finalization_tasks:
+                await _consume_cancelled_future(task)
+
+    async def shutdown(self, wait=True, *, cancel_futures=False) -> None:
+        if self._shutdown:
+            return
+        self._shutdown = True
+
+        if cancel_futures:
+            finalization_tasks = []
+            while not self._input_queue.empty():
+                work_item = self._input_queue.get_nowait()
+                work_item.future.cancel()
+                finalization_tasks.append(work_item.future)
+            for task in finalization_tasks:
+                await _consume_cancelled_future(task)
+
+        self._input_queue.shutdown()
+
+        if wait:
+            await gather(*self._workers)
+
+    async def _worker(self) -> None:
+        while True:
+            try:
+                work_item = await self._input_queue.get()
+                item_future = work_item.future
+
+                try:
+                    if item_future.cancelled():
+                        continue
+
+                    task = create_task(work_item.fn(
+                        *work_item.args,
+                        **work_item.kwargs,
+                    ))
+                    await wait([task, item_future], return_when=FIRST_COMPLETED)
+                    if not item_future.cancelled():
+                        item_future.set_result(task.result())
+                    else:
+                        task.cancel()
+                except BaseException as exception:
+                    if not item_future.cancelled():
+                        item_future.set_exception(exception)
+                finally:
+                    self._input_queue.task_done()
+            except QueueShutDown:  # The executor has been shut down.
+                break
+
+    async def _feeder[I](
+        self,
+        inputs_stream: AsyncIterable[I],
+        fn: _WorkFunction[P, R],
+        submitted_tasks: Queue[Future[R]],
+        tasks_in_flight_limit: int,
+        resume_feeding: Event,
+    ) -> None:
+        try:
+            async for args in inputs_stream:
+                if self._shutdown:
+                    break
+                future = await self.submit(fn, *args)  # type: ignore
+                await submitted_tasks.put(future)
+
+                if submitted_tasks.qsize() >= tasks_in_flight_limit:
+                    await resume_feeding.wait()
+                    resume_feeding.clear()
+        except QueueShutDown:
+            # The executor was shut down while feeder waited to submit a
+            # task.
+            pass
+        finally:
+            submitted_tasks.shutdown()
+
+    async def __aenter__(self) -> "Executor":
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool:
+        await self.shutdown(wait=True)
+        return False
diff --git a/Lib/test/test_asyncio/test_executor.py b/Lib/test/test_asyncio/test_executor.py
new file mode 100644
index 00000000000000..0456f5b69e6d85
--- /dev/null
+++ b/Lib/test/test_asyncio/test_executor.py
@@ -0,0 +1,209 @@
+import asyncio
+import itertools
+import unittest
+from asyncio import Executor
+
+
+class ExecutorSubmitTests(unittest.IsolatedAsyncioTestCase):
+
+    async def test_submit_sleeping_function(self):
+        async with asyncio.timeout(1), Executor(max_workers=2) as executor:
+            async def async_fn(x):
+                await asyncio.sleep(0.1)
+                return x * 2
+
+            future = await executor.submit(async_fn, 5)
+            result = await future
+            self.assertEqual(result, 10)
+
+    async def test_submit_function_raises_exception(self):
+        async with asyncio.timeout(1), Executor(max_workers=2) as executor:
+            async def async_fn(x):
+                raise ValueError("Test exception")
+
+            future = await executor.submit(async_fn, 5)
+            with self.assertRaises(ValueError):
+                await future
+
+    async def test_submit_cancel_task(self):
+        async with asyncio.timeout(2), Executor(max_workers=2) as executor:
+            was_not_cancelled = False
+
+            async def async_fn(x):
+                nonlocal was_not_cancelled
+                await asyncio.sleep(1)
+                was_not_cancelled = True
+                return x * 2
+
+            future = await executor.submit(async_fn, 5)
+            await asyncio.sleep(0.1)
+            future.cancel()
+            with self.assertRaises(asyncio.CancelledError):
+                await future
+            self.assertFalse(was_not_cancelled)
+
+
+class ExecutorMapTests(unittest.IsolatedAsyncioTestCase):
+
+    async def test_map_sleeping_function(self):
+        async with asyncio.timeout(1), Executor(max_workers=2) as executor:
+            async def async_fn(x):
+                await asyncio.sleep(0.1)
+                return x * 2
+
+            results = [
+                result
+                async for result in executor.map(async_fn, range(5))
+            ]
+            self.assertEqual(results, [0, 2, 4, 6, 8])
+
+    async def test_map_function_raises_exception(self):
+        async with asyncio.timeout(1), Executor(max_workers=2) as executor:
+            async def async_fn(x):
+                if x == 3:
+                    raise ValueError("Test exception")
+                return x * 2
+
+            with self.assertRaises(ValueError):
+                _ = [
+                    result
+                    async for result in executor.map(async_fn, range(5))
+                ]
+
+    async def test_map_cancel_task(self):
+        async with asyncio.timeout(3), Executor(max_workers=2) as executor:
+            was_not_cancelled = False
+
+            async def async_fn(x):
+                nonlocal was_not_cancelled
+                await asyncio.sleep(1)
+                if x == 3:
+                    was_not_cancelled = True
+                return x * 2
+
+            async for _ in executor.map(async_fn, range(5)):
+                # There are 2 workers, therefore the first 2 tasks will be
+                # completed when we reach the break statement, at which point
+                # the third task would have started.
+                break
+
+            await asyncio.sleep(0.1)  # Make sure the 3rd task is running.
+            await executor.shutdown(cancel_futures=True)
+            self.assertFalse(was_not_cancelled)
+
+
+class ExecutorStressTests(unittest.IsolatedAsyncioTestCase):
+
+    async def test_map_with_large_number_of_tasks(self):
+        async with asyncio.timeout(4), Executor(max_workers=64) as executor:
+            async def async_fn(x):
+                return x * 2
+
+            results = [
+                result
+                async for result in executor.map(async_fn, range(1000))
+            ]
+            self.assertEqual(results, [x * 2 for x in range(1000)])
+
+    async def test_map_with_infinite_iterable(self):
+        async with asyncio.timeout(1), Executor(max_workers=2) as executor:
+            async def async_fn(x):
+                return x * 2
+
+            results = []
+            async for result in executor.map(async_fn, itertools.count()):
+                results.append(result)
+                if len(results) >= 10:
+                    break
+            self.assertEqual(results, [x * 2 for x in range(10)])
+
+
+class ExecutorEdgeCasesTests(unittest.IsolatedAsyncioTestCase):
+
+    async def test_map_timeout(self):
+        async with asyncio.timeout(2), Executor(max_workers=2) as executor:
+            async def async_fn(x):
+                await asyncio.sleep(1)
+                return x * 2
+
+            with self.assertRaises(asyncio.TimeoutError):
+                _ = [
+                    result
+                    async for result in executor.map(
+                        async_fn,
+                        range(5),
+                        timeout=0.5,
+                    )
+                ]
+
+    async def test_shutdown_while_tasks_running(self):
+        async with asyncio.timeout(3), Executor(max_workers=2) as executor:
+            async def async_fn(x):
+                await asyncio.sleep(1)
+                return x * 2
+
+            futures = [await executor.submit(async_fn, i) for i in range(5)]
+            # Since we used submit instead of map, and the input queue is 2,
+            # we have 2 finished tasks, 2 running tasks, and 1 pending task.
+            # Therefore, the graceful shutdown will let the 2 running tasks
+            # finish, and cancel the pending task.
+            await executor.shutdown(cancel_futures=True)
+            for future in futures[:4]:
+                await future
+            with self.assertRaises(asyncio.CancelledError):
+                await futures[-1]
+
+    async def test_resource_cleanup(self):
+        async with asyncio.timeout(1), Executor(max_workers=2) as executor:
+            async def async_fn(x):
+                await asyncio.sleep(0.1)
+                return x * 2
+
+            futures = [await executor.submit(async_fn, i) for i in range(5)]
+            await executor.shutdown()
+            for future in futures:
+                self.assertTrue(future.done())
+
+    async def test_reject_submit_on_shutdown_executor(self):
+        executor = Executor(max_workers=2)
+        await executor.shutdown()
+        with self.assertRaises(RuntimeError):
+            await executor.submit(lambda x: x, 5)
+
+    async def test_reject_map_on_shutdown_executor(self):
+        executor = Executor(max_workers=2)
+        await executor.shutdown()
+        with self.assertRaises(RuntimeError):
+            async for _ in executor.map(lambda x: x, range(5)):
+                pass
+
+    async def test_submitted_task_cancellation_after_leaving_context(self):
+        # This test also verifies that all created async generators are
+        # terminated gracefully, which requires gracefully terminating the
+        # feeders.
+        async with asyncio.timeout(1), Executor(max_workers=1) as executor:
+            async def async_fn(x):
+                await asyncio.sleep(0.1)
+                return x * 2
+
+            results = []
+            results_stream = executor.map(async_fn, range(3))
+            results = [await anext(results_stream)]
+        # Exiting the async context manager calls shutdown(wait=True,
+        # cancel_futures=False). With max_workers=1, the first task will be
+        # running, and the second task will be pending in queue. The third task
+        # will not be submitted yet. Therefore, the first and second tasks will
+        # be completed, and the third task will be cancelled (cancel_futures
+        # cancels already scheduled tasks).
+
+        await asyncio.sleep(0.1)
+
+        results += [
+            result
+            async for result in results_stream
+        ]
+        self.assertEqual(results, [0, 2])
+
+
+if __name__ == '__main__':
+    unittest.main()