Skip to content
160 changes: 121 additions & 39 deletions aiolibs_executor/_executor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import contextvars
import dataclasses
import itertools
import sys
import threading
from asyncio import (
AbstractEventLoop,
CancelledError,
Future,
Queue,
QueueShutDown,
Task,
gather,
get_running_loop,
Expand All @@ -20,10 +20,50 @@
Iterable,
)
from types import TracebackType
from typing import Any, Self, final, overload
from typing import Any, Generic, TypeVar, final, overload
from warnings import catch_warnings


# Queue.shutdown() / QueueShutDown only exist since Python 3.13; on older
# interpreters pull them from the backport package so worker-shutdown
# semantics are identical across supported versions.
if sys.version_info < (3, 13):
    from backports.asyncio.queues import (  # type: ignore[import-untyped]
        Queue,
        QueueShutDown,
    )
else:
    from asyncio.queues import Queue, QueueShutDown

if sys.version_info < (3, 10):
    # The aiter()/anext() builtins were added in Python 3.10; provide
    # minimal stand-ins with the same call semantics for older versions.
    from collections.abc import AsyncIterable, AsyncIterator, Awaitable
    from typing import Protocol

    _T = TypeVar("_T")

    class _SupportsAnext(Protocol[_T]):
        """Structural type for objects usable with anext()."""

        async def __anext__(self) -> _T:
            pass

    def anext(it: _SupportsAnext[_T]) -> Awaitable[_T]:
        """Return an awaitable that resolves to the next item of *it*."""
        return it.__anext__()

    def aiter(it: AsyncIterable[_T]) -> AsyncIterator[_T]:
        """Return an async iterator for the async iterable *it*.

        Mirrors the builtin: the argument must support ``__aiter__``
        (the original annotation of ``Sequence`` was wrong — sequences
        have no ``__aiter__``).
        """
        return it.__aiter__()


# typing.Self was added in 3.11 and BaseExceptionGroup became a builtin in
# 3.11; on older interpreters take them from the typing_extensions and
# exceptiongroup backport packages.
if sys.version_info < (3, 11):
    from exceptiongroup import BaseExceptionGroup
    from typing_extensions import Self
else:
    from typing import Self


# Type variables shared by the typed overloads below: R is the coroutine
# result type; T1..T5 are the per-iterable argument types accepted by the
# map()/amap() overloads (one TypeVar per supported iterable).
R = TypeVar("R")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")


@final
class Executor:
_counter = itertools.count().__next__
Expand Down Expand Up @@ -65,7 +105,7 @@ async def __aexit__(
) -> None:
await self.shutdown()

def submit_nowait[R](
def submit_nowait(
self,
coro: Coroutine[Any, Any, R],
/,
Expand All @@ -77,7 +117,7 @@ def submit_nowait[R](
self._work_items.put_nowait(work_item)
return work_item.future

async def submit[R](
async def submit(
self,
coro: Coroutine[Any, Any, R],
/,
Expand All @@ -90,16 +130,17 @@ async def submit[R](
return work_item.future

@overload
def map[R, T1](
def map(
self,
fn: Callable[[T1], Coroutine[Any, Any, R]],
it1: Iterable[T1],
/,
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def map[R, T1, T2](
def map(
self,
fn: Callable[[T1, T2], Coroutine[Any, Any, R]],
it1: Iterable[T1],
Expand All @@ -108,8 +149,9 @@ def map[R, T1, T2](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

@overload
def map[R, T1, T2, T3](
def map(
self,
fn: Callable[[T1, T2, T3], Coroutine[Any, Any, R]],
it1: Iterable[T1],
Expand All @@ -120,7 +162,7 @@ def map[R, T1, T2, T3](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def map[R, T1, T2, T3, T4](
def map(
self,
fn: Callable[[T1, T2, T3, T4], Coroutine[Any, Any, R]],
it1: Iterable[T1],
Expand All @@ -132,7 +174,7 @@ def map[R, T1, T2, T3, T4](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def map[R, T1, T2, T3, T4, T5](
def map(
self,
fn: Callable[[T1, T2, T3, T4, T5], Coroutine[Any, Any, R]],
it1: Iterable[T1],
Expand All @@ -145,7 +187,7 @@ def map[R, T1, T2, T3, T4, T5](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...

async def map[R](
async def map(
self,
fn: Callable[..., Coroutine[Any, Any, R]],
iterable: Iterable[Any],
Expand All @@ -163,7 +205,7 @@ async def map[R](
yield ret

@overload
def amap[R, T1](
def amap(
self,
fn: Callable[[T1], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -172,7 +214,7 @@ def amap[R, T1](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def amap[R, T1, T2](
def amap(
self,
fn: Callable[[T1, T2], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -182,7 +224,7 @@ def amap[R, T1, T2](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def amap[R, T1, T2, T3](
def amap(
self,
fn: Callable[[T1, T2, T3], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -193,7 +235,7 @@ def amap[R, T1, T2, T3](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def amap[R, T1, T2, T3, T4](
def amap(
self,
fn: Callable[[T1, T2, T3, T4], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -205,7 +247,7 @@ def amap[R, T1, T2, T3, T4](
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
@overload
def amap[R, T1, T2, T3, T4, T5](
def amap(
self,
fn: Callable[[T1, T2, T3, T4, T5], Coroutine[Any, Any, R]],
it1: AsyncIterable[T1],
Expand All @@ -217,7 +259,7 @@ def amap[R, T1, T2, T3, T4, T5](
*,
context: contextvars.Context | None = None,
) -> AsyncIterator[R]: ...
async def amap[R](
async def amap(
self,
fn: Callable[..., Coroutine[Any, Any, R]],
iterable: AsyncIterable[Any],
Expand Down Expand Up @@ -250,13 +292,15 @@ async def shutdown(
self._shutdown = True
if self._loop is None:
return

if cancel_futures:
# Drain all work items from the queue, and then cancel their
# associated futures.
while not self._work_items.empty():
self._work_items.get_nowait().cancel()

self._work_items.shutdown()

if not wait:
for task in self._tasks:
task.cancel()
Expand Down Expand Up @@ -310,15 +354,34 @@ def _lazy_init(self) -> AbstractEventLoop:
)
return loop

async def _process_items[R](
self, work_items: list["_WorkItem[R]"]
async def _process_items(
self, work_items: list[_WorkItem[R]]
) -> AsyncIterator[R]:
try:
# reverse to keep finishing order
work_items.reverse()
while work_items:
# Careful not to keep a reference to the popped future
yield await work_items.pop().future
# NOTE: Polling future objects would be a bad approach;
# done-callbacks are used instead so that items are yielded
# in completion order.

remaining = len(work_items)
queue: Queue[Future[R]] = Queue()

def on_done(fut: Future[R]) -> None:
nonlocal queue, remaining
queue.put_nowait(fut)
remaining -= 1

# No need to call for a copy,
# loop will call it later
for w in work_items:
w.future.add_done_callback(on_done)

while remaining or not queue.empty():
fut = await queue.get()
yield await fut

# cleanup
work_items.clear()

except CancelledError:
# The current task was cancelled, e.g. by timeout
for work_item in work_items:
Expand All @@ -328,22 +391,35 @@ async def _process_items[R](
async def _work(self, prefix: str) -> None:
    """Worker loop: execute queued work items until the queue shuts down.

    Runs as one of the executor's pooled tasks; *prefix* is passed to each
    work item so its task can be given a readable name. Exits cleanly when
    the work queue is shut down.
    """
    # NOTE(review): the pasted diff contained both the old one-line body
    # and its replacement; only the replacement is kept here.
    try:
        while True:
            work_item = await self._work_items.get()
            await work_item.execute(prefix)
    except QueueShutDown:
        pass


_global_lock = threading.Lock()


@dataclasses.dataclass
class _WorkItem[R]:
coro: Coroutine[Any, Any, R]
loop: AbstractEventLoop
context: contextvars.Context | None
task: Task[R] | None = None
class _WorkItem(Generic[R]):
__slots__ = (
"coro",
"loop",
"context",
"task",
"future",
)

def __post_init__(self) -> None:
def __init__(
    self,
    coro: Coroutine[Any, Any, R],
    loop: AbstractEventLoop,
    context: contextvars.Context | None,
    task: Task[R] | None = None,
) -> None:
    """Bundle a coroutine with the loop and context it will run under.

    The caller-visible ``future`` is created eagerly so it can be handed
    back (and cancelled) before the coroutine is ever scheduled as a task.
    """
    self.coro = coro  # coroutine to execute; may never be started
    self.loop = loop  # loop that owns both the task and the future
    self.context = context  # Context to run the task in, or None for default
    self.task = task  # set once execute() schedules the coroutine
    self.future: Future[R] = self.loop.create_future()

async def execute(self, prefix: str) -> None:
Expand All @@ -358,9 +434,16 @@ async def execute(self, prefix: str) -> None:
# Some custom coroutines and mocks could not have __qualname__,
# don't add a suffix in this case.
pass
self.task = task = self.loop.create_task(
self.coro, context=self.context, name=name
)
if sys.version_info >= (3, 11):
self.task = task = self.loop.create_task( # type: ignore[call-arg]
self.coro, context=self.context, name=name
)
# XXX: older versions of Python can't leverage context variables
# Not handling it and letting the bad arguments run results in
# a deadlock!
else:
self.task = task = self.loop.create_task(self.coro, name=name)

fut.add_done_callback(self.done_callback)
try:
ret = await task
Expand All @@ -374,12 +457,11 @@ async def execute(self, prefix: str) -> None:
fut.set_result(ret)

def cancel(self) -> None:
    """Cancel the caller-visible future and dispose of the coroutine."""
    # NOTE(review): the pasted diff interleaved the removed two-line form
    # with its one-line replacement; only the replacement is kept here.
    self.future.cancel()
    self.cleanup()

def cleanup(self) -> None:
with catch_warnings(action="ignore", category=RuntimeWarning):
with catch_warnings(action="ignore", category=RuntimeWarning): # type: ignore[call-overload]
# Suppress RuntimeWarning: coroutine 'coro' was never awaited.
# The warning is possible if .shutdown() was called
# with cancel_futures=True and there are non-started coroutines
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ authors = [
{name = "Andrew Svetlov",email = "[email protected]"}
]
readme = "README.md"
requires-python = ">=3.13"
requires-python = ">=3.9"
dependencies = [
'exceptiongroup==1.3.0; python_version<"3.11"',
'backports-asyncio-queues==0.1.2; python_version<"3.13"'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Core packaging metadata mustn't use pins. Lower bounds are acceptable, and occasional known broken version exclusions.
Also, backports-asyncio-queues doesn't have any provenance metadata, so we should be careful. I think I've found a potential source for it on GitHub. It must be verified, and the project persuaded to at least start using trusted publishing and to provide other verifiable links to the source. The patch below also makes the specs sparse, which is more readable.

Suggested change
'exceptiongroup==1.3.0; python_version<"3.11"',
'backports-asyncio-queues==0.1.2; python_version<"3.13"'
'exceptiongroup; python_version < "3.11"',
'backports-asyncio-queues; python_version < "3.13"'

]
dynamic = ["version"]

Expand All @@ -20,12 +22,12 @@ enable = true

[tool.poetry]
version = "0.0.0"

[tool.poetry.group.dev.dependencies]
mypy = "^1.15.0"
coverage = "^7.6.11"

[tool.ruff]
target-version = "py313"
line-length = 79

[tool.ruff.lint]
Expand Down
Binary file added requirements-dev.txt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, fix the file encoding to be UTF-8.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Binary file not shown.
Loading