Skip to content
167 changes: 123 additions & 44 deletions aiolibs_executor/_executor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import contextvars
import dataclasses
import itertools
import sys
import threading
from asyncio import (
AbstractEventLoop,
CancelledError,
Future,
Queue,
QueueShutDown,
Task,
gather,
get_running_loop,
Expand All @@ -20,10 +20,47 @@
Iterable,
)
from types import TracebackType
from typing import Any, Self, final, overload
from typing import Any, Generic, TypeVar, final, overload
from warnings import catch_warnings


# Use janus for now until aiologic is ready to implement QueueShutdown
# SEE: https://github.com/x42005e1f/aiologic/issues/7

from janus import Queue, QueueShutDown

if sys.version_info < (3, 10):
from collections.abc import Awaitable, Sequence
from typing import Protocol

_T = TypeVar("_T")

class _SupportsAnext(Protocol[_T]):
async def __anext__(self) -> _T:
pass

def anext(it: _SupportsAnext[_T]) -> Awaitable:
return it.__anext__()

def aiter(it: Sequence[_T]) -> AsyncIterable[_T]:
return it.__aiter__()


if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
from typing_extensions import Self
else:
from typing import Self


R = TypeVar("R")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")


@final
class Executor:
_counter = itertools.count().__next__
Expand Down Expand Up @@ -65,7 +102,7 @@ async def __aexit__(
) -> None:
await self.shutdown()

def submit_nowait[R](
def submit_nowait(
self,
coro: Coroutine[Any, Any, R],
/,
Expand All @@ -74,10 +111,10 @@ def submit_nowait[R](
) -> Future[R]:
loop = self._lazy_init()
work_item = _WorkItem(coro, loop, context)
self._work_items.put_nowait(work_item)
self._work_items.async_q.put_nowait(work_item)
return work_item.future

async def submit[R](
async def submit(
self,
coro: Coroutine[Any, Any, R],
/,
Expand All @@ -86,20 +123,21 @@ async def submit[R](
) -> Future[R]:
loop = self._lazy_init()
work_item = _WorkItem(coro, loop, context)
await self._work_items.put(work_item)
await self._work_items.async_q.put(work_item)
return work_item.future

@overload
def map[R, T1](
def map(
self,
fn: Callable[[T1], Coroutine[Any, Any, R]],
it1: Iterable[T1],
/,
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def map[R, T1, T2](
def map(
self,
fn: Callable[[T1, T2], Coroutine[Any, Any, R]],
it1: Iterable[T1],
Expand All @@ -108,8 +146,9 @@ def map[R, T1, T2](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def map[R, T1, T2, T3](
def map(
self,
fn: Callable[[T1, T2, T3], Coroutine[Any, Any, R]],
it1: Iterable[T1],
Expand All @@ -120,7 +159,7 @@ def map[R, T1, T2, T3](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def map[R, T1, T2, T3, T4](
def map(
self,
fn: Callable[[T1, T2, T3, T4], Coroutine[Any, Any, R]],
it1: Iterable[T1],
Expand All @@ -132,7 +171,7 @@ def map[R, T1, T2, T3, T4](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def map[R, T1, T2, T3, T4, T5](
def map(
self,
fn: Callable[[T1, T2, T3, T4, T5], Coroutine[Any, Any, R]],
it1: Iterable[T1],
Expand All @@ -145,7 +184,7 @@ def map[R, T1, T2, T3, T4, T5](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

async def map[R](
async def map(
self,
fn: Callable[..., Coroutine[Any, Any, R]],
iterable: Iterable[Any],
Expand All @@ -157,13 +196,13 @@ async def map[R](
work_items: list[_WorkItem[R]] = []
for args in zip(iterable, *iterables, strict=False):
work_item = _WorkItem(fn(*args), loop, context)
await self._work_items.put(work_item)
await self._work_items.async_q.put(work_item)
work_items.append(work_item)
async for ret in self._process_items(work_items):
yield ret

@overload
def amap[R, T1](
def amap(
self,
fn: Callable[[T1], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -172,7 +211,7 @@ def amap[R, T1](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def amap[R, T1, T2](
def amap(
self,
fn: Callable[[T1, T2], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -182,7 +221,7 @@ def amap[R, T1, T2](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def amap[R, T1, T2, T3](
def amap(
self,
fn: Callable[[T1, T2, T3], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -193,7 +232,7 @@ def amap[R, T1, T2, T3](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def amap[R, T1, T2, T3, T4](
def amap(
self,
fn: Callable[[T1, T2, T3, T4], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -205,7 +244,7 @@ def amap[R, T1, T2, T3, T4](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def amap[R, T1, T2, T3, T4, T5](
def amap(
self,
fn: Callable[[T1, T2, T3, T4, T5], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -217,7 +256,7 @@ def amap[R, T1, T2, T3, T4, T5](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
async def amap[R](
async def amap(
self,
fn: Callable[..., Coroutine[Any, Any, R]],
iterable: AsyncIterable[Any],
Expand All @@ -232,7 +271,7 @@ async def amap[R](
try:
args = [await anext(it) for it in its]
work_item = _WorkItem(fn(*args), loop, context)
await self._work_items.put(work_item)
await self._work_items.async_q.put(work_item)
work_items.append(work_item)
except StopAsyncIteration:
break
Expand All @@ -250,13 +289,15 @@ async def shutdown(
self._shutdown = True
if self._loop is None:
return

if cancel_futures:
# Drain all work items from the queue, and then cancel their
# associated futures.
while not self._work_items.empty():
self._work_items.get_nowait().cancel()
self._work_items.async_q.get_nowait().cancel()

self._work_items.shutdown()

if not wait:
for task in self._tasks:
task.cancel()
Expand Down Expand Up @@ -310,15 +351,34 @@ def _lazy_init(self) -> AbstractEventLoop:
)
return loop

async def _process_items[R](
self, work_items: list["_WorkItem[R]"]
async def _process_items(
self, work_items: list[_WorkItem[R]]
) -> AsyncIterator[R]:
try:
# reverse to keep finishing order
work_items.reverse()
while work_items:
# Careful not to keep a reference to the popped future
yield await work_items.pop().future
# NOTE: Polling future objects can be a bad apporch
# callbacks need to be used in order to return items
# in finishing order

remaining = len(work_items)
queue: Queue[Future[R]] = Queue()

def on_done(fut: Future[R]) -> None:
nonlocal queue, remaining
queue.async_q.put_nowait(fut)
remaining -= 1

# No need to call for a copy,
# loop will call it later
for w in work_items:
w.future.add_done_callback(on_done)

while remaining or not queue.async_q.empty():
fut = await queue.async_q.get()
yield await fut

# cleanup
work_items.clear()

except CancelledError:
# The current task was cancelled, e.g. by timeout
for work_item in work_items:
Expand All @@ -328,22 +388,35 @@ async def _process_items[R](
async def _work(self, prefix: str) -> None:
try:
while True:
await (await self._work_items.get()).execute(prefix)
worker = await self._work_items.async_q.get()
await worker.execute(prefix)
except QueueShutDown:
pass


_global_lock = threading.Lock()


@dataclasses.dataclass
class _WorkItem[R]:
coro: Coroutine[Any, Any, R]
loop: AbstractEventLoop
context: contextvars.Context | None
task: Task[R] | None = None
class _WorkItem(Generic[R]):
__slots__ = (
"coro",
"loop",
"context",
"task",
"future",
)

def __post_init__(self) -> None:
def __init__(
self,
coro: Coroutine[Any, Any, R],
loop: AbstractEventLoop,
context: contextvars.Context | None,
task: Task[R] | None = None,
) -> None:
self.coro = coro
self.loop = loop
self.context = context
self.task = task
self.future: Future[R] = self.loop.create_future()

async def execute(self, prefix: str) -> None:
Expand All @@ -358,9 +431,16 @@ async def execute(self, prefix: str) -> None:
# Some custom coroutines and mocks could not have __qualname__,
# don't add a suffix in this case.
pass
self.task = task = self.loop.create_task(
self.coro, context=self.context, name=name
)
if sys.version_info >= (3, 11):
self.task = task = self.loop.create_task( # type: ignore[call-arg]
self.coro, context=self.context, name=name
)
# XXX: older versions of Python can't leverage context variables
# Not handling it and letting the bad arguments run results in
# a deadlock!
else:
self.task = task = self.loop.create_task(self.coro, name=name)

fut.add_done_callback(self.done_callback)
try:
ret = await task
Expand All @@ -374,12 +454,11 @@ async def execute(self, prefix: str) -> None:
fut.set_result(ret)

def cancel(self) -> None:
fut = self.future
fut.cancel()
self.future.cancel()
self.cleanup()

def cleanup(self) -> None:
with catch_warnings(action="ignore", category=RuntimeWarning):
with catch_warnings(action="ignore", category=RuntimeWarning): # type: ignore[call-overload]
# Suppress RuntimeWarning: coroutine 'coro' was never awaited.
# The warning is possible if .shutdown() was called
# with cancel_futures=True and there are non-started coroutines
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ authors = [
{name = "Andrew Svetlov",email = "[email protected]"}
]
readme = "README.md"
requires-python = ">=3.13"
requires-python = ">=3.9"
dependencies = [
'exceptiongroup==1.3.0; python_version<"3.11"',
'janus==2.0.0'
]
dynamic = ["version"]

Expand All @@ -20,12 +22,12 @@ enable = true

[tool.poetry]
version = "0.0.0"

[tool.poetry.group.dev.dependencies]
mypy = "^1.15.0"
coverage = "^7.6.11"

[tool.ruff]
target-version = "py313"
line-length = 79

[tool.ruff.lint]
Expand Down
6 changes: 6 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
async-timeout==5.0.1; python_version < '3.11'
exceptiongroup==1.3.0; python_version<"3.11"
janus==2.0.0
pytest==8.4.1
pytest-cov==6.2.1
typing-extensions==4.14.1
Loading