diff --git a/aiolibs_executor/_executor.py b/aiolibs_executor/_executor.py index d422cc8..394c8c6 100644 --- a/aiolibs_executor/_executor.py +++ b/aiolibs_executor/_executor.py @@ -5,12 +5,14 @@ from asyncio import ( AbstractEventLoop, CancelledError, + Condition, Future, Queue, QueueShutDown, Task, gather, get_running_loop, + sleep, ) from collections.abc import ( AsyncIterable, @@ -19,6 +21,7 @@ Coroutine, Iterable, ) +from contextlib import AbstractAsyncContextManager from types import TracebackType from typing import Any, Self, final, overload from warnings import catch_warnings @@ -33,6 +36,7 @@ def __init__( num_workers: int = 0, *, max_pending: int = 0, + max_throughput: int = 0, task_name_prefix: str = "", ) -> None: if num_workers == 0: @@ -41,11 +45,15 @@ def __init__( raise ValueError("num_workers must be greater than 0") if max_pending < 0: raise ValueError("max_pending must be non-negative number") + if max_throughput < 0: + raise ValueError("max_throughput must be non-negative number") self._num_workers = num_workers + self._max_throughput = max_throughput self._task_name_prefix = ( task_name_prefix or f"Executor-{Executor._counter()}" ) self._loop: AbstractEventLoop | None = None + self._rate_limiter: _RateLimiter | None = None self._shutdown = False self._work_items: Queue[_WorkItem[Any]] = Queue(max_pending) # tasks are much cheaper than threads or processes, @@ -72,8 +80,8 @@ def submit_nowait[R]( *, context: contextvars.Context | None = None, ) -> Future[R]: - loop = self._lazy_init() - work_item = _WorkItem(coro, loop, context) + loop, rate_limiter = self._lazy_init() + work_item = _WorkItem(coro, loop, rate_limiter, context) self._work_items.put_nowait(work_item) return work_item.future @@ -84,8 +92,8 @@ async def submit[R]( *, context: contextvars.Context | None = None, ) -> Future[R]: - loop = self._lazy_init() - work_item = _WorkItem(coro, loop, context) + loop, rate_limiter = self._lazy_init() + work_item = _WorkItem(coro, loop, 
rate_limiter, context) await self._work_items.put(work_item) return work_item.future @@ -98,6 +106,7 @@ def map[R, T1]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + @overload def map[R, T1, T2]( self, @@ -108,6 +117,7 @@ def map[R, T1, T2]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + @overload def map[R, T1, T2, T3]( self, @@ -119,6 +129,7 @@ def map[R, T1, T2, T3]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + @overload def map[R, T1, T2, T3, T4]( self, @@ -131,6 +142,7 @@ def map[R, T1, T2, T3, T4]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + @overload def map[R, T1, T2, T3, T4, T5]( self, @@ -153,10 +165,10 @@ async def map[R]( *iterables: Iterable[Any], context: contextvars.Context | None = None, ) -> AsyncIterator[R]: - loop = self._lazy_init() + loop, rate_limiter = self._lazy_init() work_items: list[_WorkItem[R]] = [] for args in zip(iterable, *iterables, strict=False): - work_item = _WorkItem(fn(*args), loop, context) + work_item = _WorkItem(fn(*args), loop, rate_limiter, context) await self._work_items.put(work_item) work_items.append(work_item) async for ret in self._process_items(work_items): @@ -171,6 +183,7 @@ def amap[R, T1]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + @overload def amap[R, T1, T2]( self, @@ -181,6 +194,7 @@ def amap[R, T1, T2]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + @overload def amap[R, T1, T2, T3]( self, @@ -192,6 +206,7 @@ def amap[R, T1, T2, T3]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + @overload def amap[R, T1, T2, T3, T4]( self, @@ -204,6 +219,7 @@ def amap[R, T1, T2, T3, T4]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... 
+ @overload def amap[R, T1, T2, T3, T4, T5]( self, @@ -217,6 +233,7 @@ def amap[R, T1, T2, T3, T4, T5]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + async def amap[R]( self, fn: Callable[..., Coroutine[Any, Any, R]], @@ -225,13 +242,13 @@ async def amap[R]( *iterables: AsyncIterable[Any], context: contextvars.Context | None = None, ) -> AsyncIterator[R]: - loop = self._lazy_init() + loop, rate_limiter = self._lazy_init() work_items: list[_WorkItem[R]] = [] its = [aiter(iterable)] + [aiter(ait) for ait in iterables] while True: try: args = [await anext(it) for it in its] - work_item = _WorkItem(fn(*args), loop, context) + work_item = _WorkItem(fn(*args), loop, rate_limiter, context) await self._work_items.put(work_item) work_items.append(work_item) except StopAsyncIteration: @@ -278,37 +295,43 @@ async def shutdown( finally: del excs - def _lazy_init(self) -> AbstractEventLoop: + def _lazy_init(self) -> tuple[AbstractEventLoop, "_RateLimiter"]: if self._shutdown: raise RuntimeError("cannot schedule new futures after shutdown") if self._loop is not None: + if self._rate_limiter is None: + raise RuntimeError(f"{self!r} failed to fully initialize") try: loop = get_running_loop() except RuntimeError: # do nothing and reuse previously stored self._loop # to allow .submit_nowait() call from non-asyncio code - return self._loop + return self._loop, self._rate_limiter else: # the loop check technique is borrowed from asyncio.locks. 
if loop is not self._loop: raise RuntimeError( f"{self!r} is bound to a different event loop" ) - return loop + return loop, self._rate_limiter else: loop = get_running_loop() with _global_lock: # double-checked locking has a very low chance to have # self._loop assigned by another thread; - # test suite doen't cover this case + # test suite doesn't cover this case if self._loop is None: # pragma: no branch self._loop = loop + self._rate_limiter = _RateLimiter( + loop, + self._max_throughput, + ) for i in range(self._num_workers): task_name = self._task_name_prefix + f"_{i}" self._tasks.append( loop.create_task(self._work(task_name), name=task_name) ) - return loop + return loop, self._rate_limiter async def _process_items[R]( self, work_items: list["_WorkItem[R]"] @@ -336,10 +359,57 @@ async def _work(self, prefix: str) -> None: _global_lock = threading.Lock() +class _RateLimiter(AbstractAsyncContextManager["_RateLimiter"]): + def __init__( + self, + event_loop: AbstractEventLoop, + max_throughput: int = 0, + time_window: float = 1.0, + ) -> None: + if max_throughput < 0: + raise ValueError("max_throughput must be non-negative number") + if time_window <= 0: + raise ValueError("time_window must be positive number") + super().__init__() + self._event_loop = event_loop + self._throttling_lock = Condition() + self._max_throughput = max_throughput + self._time_window = time_window + self._tracked_tasks = 0 + + async def __aenter__(self) -> Self: + async with self._throttling_lock: + await self._throttling_lock.wait_for(self._allow_entry) + self._tracked_tasks += 1 + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self._event_loop.create_task(self._remove_tracked_task()) + + def _allow_entry(self) -> bool: + return ( + self._max_throughput == 0 + or self._tracked_tasks < self._max_throughput + ) + + async def _remove_tracked_task(self) -> None: + await 
sleep(self._time_window) + async with self._throttling_lock: + self._tracked_tasks -= 1 + self._throttling_lock.notify_all() + + @dataclasses.dataclass class _WorkItem[R]: coro: Coroutine[Any, Any, R] loop: AbstractEventLoop + rate_limiter: _RateLimiter context: contextvars.Context | None task: Task[R] | None = None @@ -358,20 +428,21 @@ async def execute(self, prefix: str) -> None: # Some custom coroutines and mocks could not have __qualname__, # don't add a suffix in this case. pass - self.task = task = self.loop.create_task( - self.coro, context=self.context, name=name - ) - fut.add_done_callback(self.done_callback) - try: - ret = await task - except CancelledError: - fut.cancel() - except BaseException as ex: - if not fut.done(): - fut.set_exception(ex) - else: - if not fut.done(): - fut.set_result(ret) + async with self.rate_limiter: + self.task = task = self.loop.create_task( + self.coro, context=self.context, name=name + ) + fut.add_done_callback(self.done_callback) + try: + ret = await task + except CancelledError: + fut.cancel() + except BaseException as ex: + if not fut.done(): + fut.set_exception(ex) + else: + if not fut.done(): + fut.set_result(ret) def cancel(self) -> None: fut = self.future diff --git a/tests/test_executor.py b/tests/test_executor.py index da81175..b3add43 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,10 +1,12 @@ import asyncio import unittest +import warnings from collections.abc import AsyncIterator from contextvars import ContextVar, copy_context from typing import Any from aiolibs_executor import Executor +from aiolibs_executor._executor import _RateLimiter class BaseTestCase(unittest.IsolatedAsyncioTestCase): @@ -13,11 +15,13 @@ def make_executor( num_workers: int = 0, *, max_pending: int = 0, + max_throughput: int = 0, task_name_prefix: str = "", ) -> Executor: executor = Executor( num_workers=num_workers, max_pending=max_pending, + max_throughput=max_throughput, task_name_prefix=task_name_prefix, ) 
class TestThroughput(BaseTestCase):
    """Timing-based checks of the executor's ``max_throughput`` option."""

    async def test_no_throughput_limit(self) -> None:
        """Without a limit, 512 no-op tasks run faster than 128 per second."""
        task_count = 512

        executor = self.make_executor(num_workers=128)

        async def f() -> None:
            pass

        clock = asyncio.get_running_loop().time
        start_time = clock()
        tasks = [await executor.submit(f()) for _ in range(task_count)]
        await asyncio.gather(*tasks)
        elapsed = clock() - start_time
        self.assertGreater(task_count / elapsed, 128)

    async def test_throughput(self) -> None:
        """A second task is held back once the single slot has been used."""
        executor = self.make_executor(1, max_throughput=1)

        async def f() -> None:
            pass

        task = await executor.submit(f())
        await task

        # The second task must still be waiting after 10 ms.
        # asyncio.timeout() raises the builtin TimeoutError.
        with self.assertRaises(TimeoutError):
            async with asyncio.timeout(0.01):
                task = await executor.submit(f())
                await task

    async def test_fast_tasks_throttling(self) -> None:
        """Three instant tasks at ``max_throughput=2`` run at 2..3 tasks/s."""
        executor = self.make_executor(128, max_throughput=2)

        async def f(_: int) -> None:
            pass

        clock = asyncio.get_running_loop().time
        start_time = clock()
        async for _ in executor.map(f, range(3)):
            pass
        elapsed = clock() - start_time
        throughput = 3 / elapsed
        self.assertGreaterEqual(throughput, 2)
        self.assertLess(throughput, 3)

    async def test_slow_tasks_throttling(self) -> None:
        """With a generous limit the worker count bounds the throughput."""
        executor = self.make_executor(2, max_throughput=128)

        async def f(_: int) -> None:
            await asyncio.sleep(0.1)

        clock = asyncio.get_running_loop().time
        start_time = clock()
        async for _ in executor.map(f, range(10)):
            pass
        elapsed = clock() - start_time
        throughput = 10 / elapsed
        # Almost 20 tasks should complete per second (2 workers, 0.1 sec per
        # task)
        self.assertGreater(throughput, 18)
        self.assertLess(throughput, 20)