Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 98 additions & 27 deletions aiolibs_executor/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from asyncio import (
AbstractEventLoop,
CancelledError,
Condition,
Future,
Queue,
QueueShutDown,
Task,
gather,
get_running_loop,
sleep,
)
from collections.abc import (
AsyncIterable,
Expand All @@ -19,6 +21,7 @@
Coroutine,
Iterable,
)
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import Any, Self, final, overload
from warnings import catch_warnings
Expand All @@ -33,6 +36,7 @@ def __init__(
num_workers: int = 0,
*,
max_pending: int = 0,
max_throughput: int = 0,
task_name_prefix: str = "",
) -> None:
if num_workers == 0:
Expand All @@ -41,11 +45,15 @@ def __init__(
raise ValueError("num_workers must be greater than 0")
if max_pending < 0:
raise ValueError("max_pending must be non-negative number")
if max_throughput < 0:
raise ValueError("max_throughput must be non-negative number")
self._num_workers = num_workers
self._max_throughput = max_throughput
self._task_name_prefix = (
task_name_prefix or f"Executor-{Executor._counter()}"
)
self._loop: AbstractEventLoop | None = None
self._rate_limiter: _RateLimiter | None = None
self._shutdown = False
self._work_items: Queue[_WorkItem[Any]] = Queue(max_pending)
# tasks are much cheaper than threads or processes,
Expand All @@ -72,8 +80,8 @@ def submit_nowait[R](
*,
context: contextvars.Context | None = None,
) -> Future[R]:
loop = self._lazy_init()
work_item = _WorkItem(coro, loop, context)
loop, rate_limiter = self._lazy_init()
work_item = _WorkItem(coro, loop, rate_limiter, context)
self._work_items.put_nowait(work_item)
return work_item.future

Expand All @@ -84,8 +92,8 @@ async def submit[R](
*,
context: contextvars.Context | None = None,
) -> Future[R]:
loop = self._lazy_init()
work_item = _WorkItem(coro, loop, context)
loop, rate_limiter = self._lazy_init()
work_item = _WorkItem(coro, loop, rate_limiter, context)
await self._work_items.put(work_item)
return work_item.future

Expand All @@ -98,6 +106,7 @@ def map[R, T1](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def map[R, T1, T2](
self,
Expand All @@ -108,6 +117,7 @@ def map[R, T1, T2](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def map[R, T1, T2, T3](
self,
Expand All @@ -119,6 +129,7 @@ def map[R, T1, T2, T3](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def map[R, T1, T2, T3, T4](
self,
Expand All @@ -131,6 +142,7 @@ def map[R, T1, T2, T3, T4](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def map[R, T1, T2, T3, T4, T5](
self,
Expand All @@ -153,10 +165,10 @@ async def map[R](
*iterables: Iterable[Any],
context: contextvars.Context | None = None,
) -> AsyncIterator[R]:
loop = self._lazy_init()
loop, rate_limiter = self._lazy_init()
work_items: list[_WorkItem[R]] = []
for args in zip(iterable, *iterables, strict=False):
work_item = _WorkItem(fn(*args), loop, context)
work_item = _WorkItem(fn(*args), loop, rate_limiter, context)
await self._work_items.put(work_item)
work_items.append(work_item)
async for ret in self._process_items(work_items):
Expand All @@ -171,6 +183,7 @@ def amap[R, T1](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def amap[R, T1, T2](
self,
Expand All @@ -181,6 +194,7 @@ def amap[R, T1, T2](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def amap[R, T1, T2, T3](
self,
Expand All @@ -192,6 +206,7 @@ def amap[R, T1, T2, T3](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def amap[R, T1, T2, T3, T4](
self,
Expand All @@ -204,6 +219,7 @@ def amap[R, T1, T2, T3, T4](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def amap[R, T1, T2, T3, T4, T5](
self,
Expand All @@ -217,6 +233,7 @@ def amap[R, T1, T2, T3, T4, T5](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

async def amap[R](
self,
fn: Callable[..., Coroutine[Any, Any, R]],
Expand All @@ -225,13 +242,13 @@ async def amap[R](
*iterables: AsyncIterable[Any],
context: contextvars.Context | None = None,
) -> AsyncIterator[R]:
loop = self._lazy_init()
loop, rate_limiter = self._lazy_init()
work_items: list[_WorkItem[R]] = []
its = [aiter(iterable)] + [aiter(ait) for ait in iterables]
while True:
try:
args = [await anext(it) for it in its]
work_item = _WorkItem(fn(*args), loop, context)
work_item = _WorkItem(fn(*args), loop, rate_limiter, context)
await self._work_items.put(work_item)
work_items.append(work_item)
except StopAsyncIteration:
Expand Down Expand Up @@ -278,37 +295,43 @@ async def shutdown(
finally:
del excs

def _lazy_init(self) -> tuple[AbstractEventLoop, "_RateLimiter"]:
    """Bind the executor to an event loop on first use.

    The first successful call stores the running loop, creates the shared
    rate limiter and spawns ``num_workers`` worker tasks.  Later calls
    only validate that the caller still uses the same loop.

    Returns:
        The event loop and the rate limiter that new work items must use.

    Raises:
        RuntimeError: if the executor is already shut down, if it is
            bound to a different event loop, or if the very first call
            is made without a running event loop.
    """
    if self._shutdown:
        raise RuntimeError("cannot schedule new futures after shutdown")
    if self._loop is not None:
        if self._rate_limiter is None:
            raise RuntimeError(f"{self!r} failed to fully initialize")
        try:
            loop = get_running_loop()
        except RuntimeError:
            # do nothing and reuse previously stored self._loop
            # to allow .submit_nowait() call from non-asyncio code
            return self._loop, self._rate_limiter
        # the loop check technique is borrowed from asyncio.locks.
        if loop is not self._loop:
            raise RuntimeError(
                f"{self!r} is bound to a different event loop"
            )
        return loop, self._rate_limiter

    loop = get_running_loop()
    with _global_lock:
        # double-checked locking has a very low chance to have
        # self._loop assigned by another thread;
        # test suite doesn't cover this case
        if self._loop is None:  # pragma: no branch
            # Publish the rate limiter BEFORE the loop: the lock-free
            # fast path above treats a non-None self._loop as "fully
            # initialized", so a thread racing with this one must never
            # observe the loop without its rate limiter.
            self._rate_limiter = _RateLimiter(
                loop,
                self._max_throughput,
            )
            self._loop = loop
            for i in range(self._num_workers):
                task_name = self._task_name_prefix + f"_{i}"
                self._tasks.append(
                    loop.create_task(self._work(task_name), name=task_name)
                )
        rate_limiter = self._rate_limiter
    if rate_limiter is None:  # pragma: no cover - defensive
        raise RuntimeError(f"{self!r} failed to fully initialize")
    return loop, rate_limiter

async def _process_items[R](
self, work_items: list["_WorkItem[R]"]
Expand Down Expand Up @@ -336,10 +359,57 @@ async def _work(self, prefix: str) -> None:
# Module-level lock; `_lazy_init` holds it while storing the event loop,
# creating the rate limiter and spawning the worker tasks, so that this
# one-time setup is not performed twice by competing threads.
_global_lock = threading.Lock()


class _RateLimiter(AbstractAsyncContextManager["_RateLimiter"]):
def __init__(
self,
event_loop: AbstractEventLoop,
max_throughput: int = 0,
time_window: float = 1.0,
) -> None:
if max_throughput < 0:
raise ValueError("max_throughput must be non-negative number")
if time_window <= 0:
raise ValueError("time_window must be positive number")
super().__init__()
self._event_loop = event_loop
self._throttling_lock = Condition()
self._max_throughput = max_throughput
self._time_window = time_window
self._tracked_tasks = 0

async def __aenter__(self) -> Self:
async with self._throttling_lock:
await self._throttling_lock.wait_for(self._allow_entry)
self._tracked_tasks += 1

return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self._event_loop.create_task(self._remove_tracked_task())

def _allow_entry(self) -> bool:
return (
self._max_throughput == 0
or self._tracked_tasks < self._max_throughput
)

async def _remove_tracked_task(self) -> None:
await sleep(self._time_window)
async with self._throttling_lock:
self._tracked_tasks -= 1
self._throttling_lock.notify_all()


@dataclasses.dataclass
class _WorkItem[R]:
coro: Coroutine[Any, Any, R]
loop: AbstractEventLoop
rate_limiter: _RateLimiter
context: contextvars.Context | None
task: Task[R] | None = None

Expand All @@ -358,20 +428,21 @@ async def execute(self, prefix: str) -> None:
# Some custom coroutines and mocks may lack __qualname__;
# don't add a suffix in this case.
pass
self.task = task = self.loop.create_task(
self.coro, context=self.context, name=name
)
fut.add_done_callback(self.done_callback)
try:
ret = await task
except CancelledError:
fut.cancel()
except BaseException as ex:
if not fut.done():
fut.set_exception(ex)
else:
if not fut.done():
fut.set_result(ret)
async with self.rate_limiter:
self.task = task = self.loop.create_task(
self.coro, context=self.context, name=name
)
fut.add_done_callback(self.done_callback)
try:
ret = await task
except CancelledError:
fut.cancel()
except BaseException as ex:
if not fut.done():
fut.set_exception(ex)
else:
if not fut.done():
fut.set_result(ret)

def cancel(self) -> None:
fut = self.future
Expand Down
Loading