diff --git a/aiolibs_executor/_executor.py b/aiolibs_executor/_executor.py index d422cc8..abb63f0 100644 --- a/aiolibs_executor/_executor.py +++ b/aiolibs_executor/_executor.py @@ -1,13 +1,13 @@ +from __future__ import annotations + import contextvars -import dataclasses import itertools +import sys import threading from asyncio import ( AbstractEventLoop, CancelledError, Future, - Queue, - QueueShutDown, Task, gather, get_running_loop, @@ -20,10 +20,47 @@ Iterable, ) from types import TracebackType -from typing import Any, Self, final, overload +from typing import Any, Generic, TypeVar, final, overload from warnings import catch_warnings +# Use janus for now until aiologic is ready to implement QueueShutdown +# SEE: https://github.com/x42005e1f/aiologic/issues/7 + +from janus import Queue, QueueShutDown + +if sys.version_info < (3, 10): + from collections.abc import Awaitable, Sequence + from typing import Protocol + + _T = TypeVar("_T") + + class _SupportsAnext(Protocol[_T]): + async def __anext__(self) -> _T: + pass + + def anext(it: _SupportsAnext[_T]) -> Awaitable: + return it.__anext__() + + def aiter(it: Sequence[_T]) -> AsyncIterable[_T]: + return it.__aiter__() + + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + from typing_extensions import Self +else: + from typing import Self + + +R = TypeVar("R") +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") + + @final class Executor: _counter = itertools.count().__next__ @@ -65,7 +102,7 @@ async def __aexit__( ) -> None: await self.shutdown() - def submit_nowait[R]( + def submit_nowait( self, coro: Coroutine[Any, Any, R], /, @@ -74,10 +111,10 @@ def submit_nowait[R]( ) -> Future[R]: loop = self._lazy_init() work_item = _WorkItem(coro, loop, context) - self._work_items.put_nowait(work_item) + self._work_items.async_q.put_nowait(work_item) return work_item.future - async def submit[R]( + async def submit( self, coro: Coroutine[Any, Any, R], /, @@ -86,11 +123,11 @@ async def submit[R]( ) -> Future[R]: loop = self._lazy_init() work_item = _WorkItem(coro, loop, context) - await self._work_items.put(work_item) + await self._work_items.async_q.put(work_item) return work_item.future @overload - def map[R, T1]( + def map( self, fn: Callable[[T1], Coroutine[Any, Any, R]], it1: Iterable[T1], @@ -98,8 +135,9 @@ def map[R, T1]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + @overload - def map[R, T1, T2]( + def map( self, fn: Callable[[T1, T2], Coroutine[Any, Any, R]], it1: Iterable[T1], @@ -108,8 +146,9 @@ def map[R, T1, T2]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... + @overload - def map[R, T1, T2, T3]( + def map( self, fn: Callable[[T1, T2, T3], Coroutine[Any, Any, R]], it1: Iterable[T1], @@ -120,7 +159,7 @@ def map[R, T1, T2, T3]( context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... @overload - def map[R, T1, T2, T3, T4]( + def map( self, fn: Callable[[T1, T2, T3, T4], Coroutine[Any, Any, R]], it1: Iterable[T1], @@ -132,7 +171,7 @@ def map[R, T1, T2, T3, T4]( context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... @overload - def map[R, T1, T2, T3, T4, T5]( + def map( self, fn: Callable[[T1, T2, T3, T4, T5], Coroutine[Any, Any, R]], it1: Iterable[T1], @@ -145,7 +184,7 @@ def map[R, T1, T2, T3, T4, T5]( context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... - async def map[R]( + async def map( self, fn: Callable[..., Coroutine[Any, Any, R]], iterable: Iterable[Any], @@ -157,13 +196,13 @@ async def map[R]( work_items: list[_WorkItem[R]] = [] for args in zip(iterable, *iterables, strict=False): work_item = _WorkItem(fn(*args), loop, context) - await self._work_items.put(work_item) + await self._work_items.async_q.put(work_item) work_items.append(work_item) async for ret in self._process_items(work_items): yield ret @overload - def amap[R, T1]( + def amap( self, fn: Callable[[T1], Coroutine[Any, Any, R]], it1: AsyncIterable[T1], @@ -172,7 +211,7 @@ def amap[R, T1]( context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... @overload - def amap[R, T1, T2]( + def amap( self, fn: Callable[[T1, T2], Coroutine[Any, Any, R]], it1: AsyncIterable[T1], @@ -182,7 +221,7 @@ def amap[R, T1, T2]( context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... @overload - def amap[R, T1, T2, T3]( + def amap( self, fn: Callable[[T1, T2, T3], Coroutine[Any, Any, R]], it1: AsyncIterable[T1], @@ -193,7 +232,7 @@ def amap[R, T1, T2, T3]( context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... @overload - def amap[R, T1, T2, T3, T4]( + def amap( self, fn: Callable[[T1, T2, T3, T4], Coroutine[Any, Any, R]], it1: AsyncIterable[T1], @@ -205,7 +244,7 @@ def amap[R, T1, T2, T3, T4]( context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... @overload - def amap[R, T1, T2, T3, T4, T5]( + def amap( self, fn: Callable[[T1, T2, T3, T4, T5], Coroutine[Any, Any, R]], it1: AsyncIterable[T1], @@ -217,7 +256,7 @@ def amap[R, T1, T2, T3, T4, T5]( *, context: contextvars.Context | None = None, ) -> AsyncIterator[R]: ... - async def amap[R]( + async def amap( self, fn: Callable[..., Coroutine[Any, Any, R]], iterable: AsyncIterable[Any], @@ -232,7 +271,7 @@ async def amap[R]( try: args = [await anext(it) for it in its] work_item = _WorkItem(fn(*args), loop, context) - await self._work_items.put(work_item) + await self._work_items.async_q.put(work_item) work_items.append(work_item) except StopAsyncIteration: break @@ -250,13 +289,15 @@ async def shutdown( self._shutdown = True if self._loop is None: return + if cancel_futures: # Drain all work items from the queue, and then cancel their # associated futures. while not self._work_items.empty(): - self._work_items.get_nowait().cancel() + self._work_items.async_q.get_nowait().cancel() self._work_items.shutdown() + if not wait: for task in self._tasks: task.cancel() @@ -310,15 +351,34 @@ def _lazy_init(self) -> AbstractEventLoop: ) return loop - async def _process_items[R]( - self, work_items: list["_WorkItem[R]"] + async def _process_items( + self, work_items: list[_WorkItem[R]] ) -> AsyncIterator[R]: try: - # reverse to keep finishing order - work_items.reverse() - while work_items: - # Careful not to keep a reference to the popped future - yield await work_items.pop().future + # NOTE: Polling future objects can be a bad apporch + # callbacks need to be used in order to return items + # in finishing order + + remaining = len(work_items) + queue: Queue[Future[R]] = Queue() + + def on_done(fut: Future[R]) -> None: + nonlocal queue, remaining + queue.async_q.put_nowait(fut) + remaining -= 1 + + # No need to call for a copy, + # loop will call it later + for w in work_items: + w.future.add_done_callback(on_done) + + while remaining or not queue.async_q.empty(): + fut = await queue.async_q.get() + yield await fut + + # cleanup + work_items.clear() + except CancelledError: # The current task was cancelled, e.g. by timeout for work_item in work_items: @@ -328,7 +388,8 @@ async def _process_items[R]( async def _work(self, prefix: str) -> None: try: while True: - await (await self._work_items.get()).execute(prefix) + worker = await self._work_items.async_q.get() + await worker.execute(prefix) except QueueShutDown: pass @@ -336,14 +397,26 @@ async def _work(self, prefix: str) -> None: _global_lock = threading.Lock() -@dataclasses.dataclass -class _WorkItem[R]: - coro: Coroutine[Any, Any, R] - loop: AbstractEventLoop - context: contextvars.Context | None - task: Task[R] | None = None +class _WorkItem(Generic[R]): + __slots__ = ( + "coro", + "loop", + "context", + "task", + "future", + ) - def __post_init__(self) -> None: + def __init__( + self, + coro: Coroutine[Any, Any, R], + loop: AbstractEventLoop, + context: contextvars.Context | None, + task: Task[R] | None = None, + ) -> None: + self.coro = coro + self.loop = loop + self.context = context + self.task = task self.future: Future[R] = self.loop.create_future() async def execute(self, prefix: str) -> None: @@ -358,9 +431,16 @@ async def execute(self, prefix: str) -> None: # Some custom coroutines and mocks could not have __qualname__, # don't add a suffix in this case. pass - self.task = task = self.loop.create_task( - self.coro, context=self.context, name=name - ) + if sys.version_info >= (3, 11): + self.task = task = self.loop.create_task( # type: ignore[call-arg] + self.coro, context=self.context, name=name + ) + # XXX: older versions of Python can't leverage context variables + # Not handling it and letting the bad arguments run results in + # a deadlock! + else: + self.task = task = self.loop.create_task(self.coro, name=name) + fut.add_done_callback(self.done_callback) try: ret = await task @@ -374,12 +454,11 @@ async def execute(self, prefix: str) -> None: fut.set_result(ret) def cancel(self) -> None: - fut = self.future - fut.cancel() + self.future.cancel() self.cleanup() def cleanup(self) -> None: - with catch_warnings(action="ignore", category=RuntimeWarning): + with catch_warnings(action="ignore", category=RuntimeWarning): # type: ignore[call-overload] # Suppress RuntimeWarning: coroutine 'coro' was never awaited. # The warning is possible if .shutdown() was called # with cancel_futures=True and there are non-started coroutines diff --git a/pyproject.toml b/pyproject.toml index 6dd6e03..0d8e986 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,8 +5,10 @@ authors = [ {name = "Andrew Svetlov",email = "andrew.svetlov@gmail.com"} ] readme = "README.md" -requires-python = ">=3.13" +requires-python = ">=3.9" dependencies = [ + 'exceptiongroup==1.3.0; python_version<"3.11"', + 'janus==2.0.0' ] dynamic = ["version"] @@ -20,12 +22,12 @@ enable = true [tool.poetry] version = "0.0.0" + [tool.poetry.group.dev.dependencies] mypy = "^1.15.0" coverage = "^7.6.11" [tool.ruff] -target-version = "py313" line-length = 79 [tool.ruff.lint] diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..07f6e4f --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,6 @@ +async-timeout==5.0.1; python_version < '3.11' +exceptiongroup==1.3.0; python_version<"3.11" +janus==2.0.0 +pytest==8.4.1 +pytest-cov==6.2.1 +typing-extensions==4.14.1 diff --git a/tests/test_executor.py b/tests/test_executor.py index da81175..eea23d2 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,12 +1,33 @@ +from __future__ import annotations + import asyncio +import sys import unittest -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable from contextvars import ContextVar, copy_context -from typing import Any +from typing import Any, Callable from aiolibs_executor import Executor +if sys.version_info < (3, 11): + from async_timeout import timeout +else: + from asyncio import timeout + + +def skip_if_earlier_than_313( + reason: str = "Currently Breaks", +) -> Callable[..., Awaitable[None]]: + return unittest.skipIf(sys.version_info < (3, 13), reason) + + +def skip_if_earlier_than_311( + reason: str = "Currently Breaks", +) -> Callable[..., Awaitable[None]]: + return unittest.skipIf(sys.version_info < (3, 11), reason) + + class BaseTestCase(unittest.IsolatedAsyncioTestCase): def make_executor( self, @@ -26,185 +47,201 @@ def make_executor( class TestSubmit(BaseTestCase): async def test_submit_nowait(self) -> None: - executor = self.make_executor() + async with self.make_executor() as executor: - async def f( - *args: Any, **kwargs: Any - ) -> tuple[tuple[Any, ...], dict[str, Any]]: - await asyncio.sleep(0) - return args, kwargs + async def f( + *args: Any, **kwargs: Any + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + await asyncio.sleep(0) + return args, kwargs - self.assertEqual( - await executor.submit_nowait(f(1, a=2)), ((1,), {"a": 2}) - ) + self.assertEqual( + await executor.submit_nowait(f(1, a=2)), ((1,), {"a": 2}) + ) async def test_submit(self) -> None: - executor = self.make_executor() + async with self.make_executor() as executor: - async def f( - *args: Any, **kwargs: Any - ) -> tuple[tuple[Any, ...], dict[str, Any]]: - await asyncio.sleep(0) - return args, kwargs + async def f( + *args: Any, **kwargs: Any + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + await asyncio.sleep(0) + return args, kwargs - fut = await executor.submit(f(1, a=2)) - self.assertEqual(await fut, ((1,), {"a": 2})) + fut = await executor.submit(f(1, a=2)) + self.assertEqual(await fut, ((1,), {"a": 2})) async def test_map(self) -> None: - executor = self.make_executor() + async with self.make_executor() as executor: - async def f(a: int, b: int) -> int: - await asyncio.sleep(0) - return a + b + async def f(a: int, b: int) -> int: + await asyncio.sleep(0) + return a + b - arg = list(range(3)) - ret = [i async for i in executor.map(f, arg, arg)] - self.assertEqual(ret, [0, 2, 4]) + arg = list(range(3)) + ret = [i async for i in executor.map(f, arg, arg)] + self.assertEqual(ret, [0, 2, 4]) async def test_amap(self) -> None: - executor = self.make_executor() - - async def f(a: int, b: int) -> int: - await asyncio.sleep(0) - return a + b + async with self.make_executor() as executor: - async def inp() -> AsyncIterator[int]: - for i in range(1, 4): + async def f(a: int, b: int) -> int: await asyncio.sleep(0) - yield i + return a + b - ret = [i async for i in executor.amap(f, inp(), inp())] - self.assertEqual(ret, [2, 4, 6]) + async def inp() -> AsyncIterator[int]: + for i in range(1, 4): + await asyncio.sleep(0) + yield i - async def test_submit_nowait_default_context(self) -> None: - executor = self.make_executor() + ret = [i async for i in executor.amap(f, inp(), inp())] + self.assertEqual(ret, [2, 4, 6]) - c: ContextVar[int] = ContextVar("c") + @skip_if_earlier_than_311( + "Context Varaibles not supported on 3.10 or earlier" + ) + async def test_submit_nowait_default_context(self) -> None: + async with self.make_executor() as executor: + c: ContextVar[int] = ContextVar("c") - async def f(a: int) -> int: - await asyncio.sleep(0) - return a + c.get() + async def f(a: int) -> int: + await asyncio.sleep(0) + return a + c.get() - c.set(1) + c.set(1) - fut = executor.submit_nowait(f(1)) - self.assertEqual(await fut, 2) + fut = executor.submit_nowait(f(1)) + self.assertEqual(await fut, 2) + @skip_if_earlier_than_311( + "Context Varaibles not supported on 3.10 or earlier" + ) async def test_submit_nowait_with_context(self) -> None: - executor = self.make_executor() - - c: ContextVar[int] = ContextVar("c") + async with self.make_executor() as executor: + c: ContextVar[int] = ContextVar("c") - async def f(a: int) -> int: - await asyncio.sleep(0) - return a + c.get() + async def f(a: int) -> int: + await asyncio.sleep(0) + return a + c.get() - token = c.set(1) - context = copy_context() - c.reset(token) + token = c.set(1) + context = copy_context() + c.reset(token) - fut = executor.submit_nowait(f(1), context=context) - self.assertEqual(await fut, 2) + fut = executor.submit_nowait(f(1), context=context) + self.assertEqual(await fut, 2) + @skip_if_earlier_than_311( + "Context Varaibles not supported on 3.10 or earlier" + ) async def test_submit_default_context(self) -> None: - executor = self.make_executor() - - c: ContextVar[int] = ContextVar("c") + async with self.make_executor() as executor: + c: ContextVar[int] = ContextVar("c") - async def f(a: int) -> int: - await asyncio.sleep(0) - return a + c.get() + async def f(a: int) -> int: + await asyncio.sleep(0) + return a + c.get() - c.set(1) + c.set(1) - fut = await executor.submit(f(1)) - self.assertEqual(await fut, 2) + fut = await executor.submit(f(1)) + self.assertEqual(await fut, 2) + @skip_if_earlier_than_311( + reason="Contextvar tasks are not supported on 3.11 or older" + ) async def test_submit_with_context(self) -> None: - executor = self.make_executor() - - c: ContextVar[int] = ContextVar("c") + async with self.make_executor() as executor: + c: ContextVar[int] = ContextVar("c") - async def f(a: int) -> int: - await asyncio.sleep(0) - return a + c.get() + async def f(a: int) -> int: + await asyncio.sleep(0) + return a + c.get() - token = c.set(1) - context = copy_context() - c.reset(token) + token = c.set(1) + context = copy_context() + c.reset(token) - fut = await executor.submit(f(1), context=context) - self.assertEqual(await fut, 2) + fut = await executor.submit(f(1), context=context) + self.assertEqual(await fut, 2) + @skip_if_earlier_than_311( + "Context Varaibles not supported on 3.10 or earlier" + ) async def test_map_default_context(self) -> None: - executor = self.make_executor() - - c: ContextVar[int] = ContextVar("c") + async with self.make_executor() as executor: + c: ContextVar[int] = ContextVar("c") - async def f(a: int) -> int: - await asyncio.sleep(0) - return a + c.get() + async def f(a: int) -> int: + await asyncio.sleep(0) + return a + c.get() - c.set(1) + c.set(1) - ret = [i async for i in executor.map(f, range(3))] - self.assertEqual(ret, [1, 2, 3]) + ret = [i async for i in executor.map(f, range(3))] + self.assertEqual(ret, [1, 2, 3]) + @skip_if_earlier_than_311( + "Context Varaibles not supported on 3.10 or earlier" + ) async def test_map_with_context(self) -> None: - executor = self.make_executor() - - c: ContextVar[int] = ContextVar("c") + async with self.make_executor() as executor: + c: ContextVar[int] = ContextVar("c") - async def f(a: int) -> int: - await asyncio.sleep(0) - return a + c.get() + async def f(a: int) -> int: + await asyncio.sleep(0) + return a + c.get() - token = c.set(1) - context = copy_context() - c.reset(token) + token = c.set(1) + context = copy_context() + c.reset(token) - ret = [i async for i in executor.map(f, range(3), context=context)] - self.assertEqual(ret, [1, 2, 3]) + ret = [i async for i in executor.map(f, range(3), context=context)] + self.assertEqual(ret, [1, 2, 3]) + @skip_if_earlier_than_311( + reason="Context Varaibles not supported on 3.10 or earlier" + ) async def test_amap_default_context(self) -> None: - executor = self.make_executor() - - c: ContextVar[int] = ContextVar("c") + async with self.make_executor() as executor: + c: ContextVar[int] = ContextVar("c") - async def f(a: int) -> int: - await asyncio.sleep(0) - return a + c.get() + async def f(a: int) -> int: + await asyncio.sleep(0) + return a + c.get() - c.set(1) + c.set(1) - async def inp() -> AsyncIterator[int]: - for i in range(3): - await asyncio.sleep(0) - yield i + async def inp() -> AsyncIterator[int]: + for i in range(3): + await asyncio.sleep(0) + yield i - ret = [i async for i in executor.amap(f, inp())] - self.assertEqual(ret, [1, 2, 3]) + ret = [i async for i in executor.amap(f, inp())] + self.assertEqual(ret, [1, 2, 3]) + @skip_if_earlier_than_311( + reason="Context Varaibles not supported on 3.10 or earlier" + ) async def test_amap_with_context(self) -> None: - executor = self.make_executor() - - c: ContextVar[int] = ContextVar("c") + async with self.make_executor() as executor: + c: ContextVar[int] = ContextVar("c") - async def f(a: int) -> int: - await asyncio.sleep(0) - return a + c.get() + async def f(a: int) -> int: + await asyncio.sleep(0) + return a + c.get() - token = c.set(1) - context = copy_context() - c.reset(token) + token = c.set(1) + context = copy_context() + c.reset(token) - async def inp() -> AsyncIterator[int]: - for i in range(3): - await asyncio.sleep(0) - yield i + async def inp() -> AsyncIterator[int]: + for i in range(3): + await asyncio.sleep(0) + yield i - ret = [i async for i in executor.amap(f, inp(), context=context)] - self.assertEqual(ret, [1, 2, 3]) + ret = [i async for i in executor.amap(f, inp(), context=context)] + self.assertEqual(ret, [1, 2, 3]) async def test_context_manager(self) -> None: async def f(a: int) -> int: @@ -230,55 +267,68 @@ def test_invalid_max_pending(self) -> None: self.make_executor(max_pending=-1) async def test_double_lazy_init(self) -> None: - executor = self.make_executor() - loop = executor._lazy_init() - self.assertIs(loop, asyncio.get_running_loop()) - loop = executor._lazy_init() - self.assertIs(loop, asyncio.get_running_loop()) + async with self.make_executor() as executor: + loop = executor._lazy_init() + self.assertIs(loop, asyncio.get_running_loop()) + loop = executor._lazy_init() + self.assertIs(loop, asyncio.get_running_loop()) async def test_lazy_init_after_shutdown(self) -> None: - executor = self.make_executor() - await executor.shutdown() - with self.assertRaisesRegex( - RuntimeError, "cannot schedule new futures after shutdown" - ): - executor._lazy_init() - + async with self.make_executor() as executor: + await executor.shutdown() + with self.assertRaisesRegex( + RuntimeError, "cannot schedule new futures after shutdown" + ): + executor._lazy_init() + + @unittest.skipIf( + sys.version_info < (3, 13), + reason="Not supported on ealier versions than 3.13", + ) async def test_lazy_init_from_nonasyncio_if_inited(self) -> None: - executor = self.make_executor() - executor._lazy_init() + async with self.make_executor() as executor: + executor._lazy_init() - def f() -> asyncio.AbstractEventLoop: - return executor._lazy_init() + def f() -> asyncio.AbstractEventLoop: + return executor._lazy_init() - self.assertEqual( - await asyncio.to_thread(f), asyncio.get_running_loop() - ) + self.assertEqual( + await asyncio.to_thread(f), asyncio.get_running_loop() + ) + @unittest.skipIf( + sys.version_info < (3, 13), + reason="Not supported" + " on ealier versions than 3.13" + "Bugs on Earlier versions of python", + ) async def test_lazy_init_from_nonasyncio_if_not_inited(self) -> None: - executor = self.make_executor() + async with self.make_executor() as executor: - def f() -> None: - executor._lazy_init() + def f() -> None: + executor._lazy_init() - with self.assertRaisesRegex(RuntimeError, "no running event loop"): - await asyncio.to_thread(f) + with self.assertRaisesRegex(RuntimeError, "no running event loop"): + await asyncio.to_thread(f) + @unittest.skipIf( + sys.version_info < (3, 11), reason="don't have asyncio.Runner" + ) async def test_lazy_init_bound_to_different_loop(self) -> None: - executor = self.make_executor() - executor._lazy_init() - - async def g() -> None: + async with self.make_executor() as executor: executor._lazy_init() - def f() -> None: - with asyncio.Runner() as runner: - runner.run(g()) + async def g() -> None: + executor._lazy_init() - with self.assertRaisesRegex( - RuntimeError, "is bound to a different event loop" - ): - await asyncio.to_thread(f) + def f() -> None: + with asyncio.Runner() as runner: # type: ignore[attr-defined] + runner.run(g()) + + with self.assertRaisesRegex( + RuntimeError, "is bound to a different event loop" + ): + await asyncio.to_thread(f) class TestShutdown(BaseTestCase): @@ -291,6 +341,10 @@ async def test_shutdown_twice(self) -> None: await executor.shutdown() await executor.shutdown() + @unittest.skipIf( + sys.version_info < (3, 13), + reason="Not supported on ealier versions than 3.13", + ) async def test_shutdown_cancel_futures(self) -> None: executor = self.make_executor(1) started = asyncio.Event() @@ -333,17 +387,21 @@ async def f() -> None: self.assertTrue(fut.cancelled()) + @unittest.skipIf( + sys.version_info < (3, 13), + reason="Not supported on ealier versions than 3.13", + ) async def test_shutdown_wt_exception_from_worker(self) -> None: executor = self.make_executor() executor._lazy_init() # emulate unhandled error by putting bad data into the queue - await executor._work_items.put(None) # type: ignore[arg-type] + await executor._work_items.put(None) ok = False try: await executor.shutdown() - except* AttributeError: + except AttributeError: ok = True self.assertTrue(ok) @@ -364,151 +422,176 @@ async def f() -> None: raise fut = await executor.submit(f()) + assert executor._tasks await started.wait() fut.cancel() await cancelled.wait() + @unittest.skipIf( + sys.version_info < (3, 13), + reason="Not supported on ealier versions than 3.13", + ) async def test_cancelling_map_cancels_tasks(self) -> None: - executor = self.make_executor() - cancelled = set() - ev = asyncio.Event() - - async def f(i: int) -> None: - try: - await asyncio.sleep(60) - except asyncio.CancelledError: - cancelled.add(i) - if len(cancelled) == 5: - ev.set() - raise - - with self.assertRaises(TimeoutError): - async with asyncio.timeout(0.01): - async for _ in executor.map(f, range(5)): - pass + async with self.make_executor() as executor: + cancelled = set() + ev = asyncio.Event() + + async def f(i: int) -> None: + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + cancelled.add(i) + if len(cancelled) == 5: + ev.set() + raise + + with self.assertRaises(TimeoutError): + async with timeout(0.01): + async for _ in executor.map(f, range(5)): + pass - await ev.wait() - self.assertEqual(cancelled, {0, 1, 2, 3, 4}) + await ev.wait() + self.assertEqual(cancelled, {0, 1, 2, 3, 4}) class TestExceptions(BaseTestCase): + @unittest.skipIf( + sys.version_info < (3, 13), + reason="Not supported on ealier versions than 3.13", + ) async def test_dont_execute_with_done_future(self) -> None: - executor = self.make_executor(1) - started = asyncio.Event() + async with self.make_executor(1) as executor: + started = asyncio.Event() - async def f(num: int, ev: asyncio.Event) -> int: - started.set() - await ev.wait() - return num + async def f(num: int, ev: asyncio.Event) -> int: + started.set() + await ev.wait() + return num - ev1 = asyncio.Event() - # executing - fut1 = await executor.submit(f(1, ev1)) + ev1 = asyncio.Event() + # executing + fut1 = await executor.submit(f(1, ev1)) - ev2 = asyncio.Event() - # pending - fut2 = await executor.submit(f(2, ev2)) + ev2 = asyncio.Event() + # pending + fut2 = await executor.submit(f(2, ev2)) - # wait to put submitted request into a worker - await started.wait() + # wait to put submitted request into a worker + await started.wait() - # Setting the result is strange, user should never do it. - # But the executor should not crash at least - fut2.set_result(10) + # Setting the result is strange, user should never do it. + # But the executor should not crash at least + fut2.set_result(10) - ev1.set() - ev2.set() + ev1.set() + ev2.set() - self.assertEqual(await fut1, 1) - self.assertEqual(await fut2, 10) + self.assertEqual(await fut1, 1) + self.assertEqual(await fut2, 10) + @unittest.skipIf( + sys.version_info <= (3, 13), + "Currently chokes ealier versions of python", + ) async def test_dont_override_done_future(self) -> None: - executor = self.make_executor() - started = asyncio.Event() + async with self.make_executor() as executor: + started = asyncio.Event() - async def f(num: int, ev: asyncio.Event) -> int: - started.set() - await ev.wait() - return num + async def f(num: int, ev: asyncio.Event) -> int: + started.set() + await ev.wait() + return num - ev = asyncio.Event() - fut = await executor.submit(f(1, ev)) + ev = asyncio.Event() + fut = await executor.submit(f(1, ev)) - # wait to put submitted request into a worker - await started.wait() + # wait to put submitted request into a worker + await started.wait() - # Setting the result is strange, user should never do it. - # But the executor should not crash at least - fut.set_result(10) - ev.set() + # Setting the result is strange, user should never do it. + # But the executor should not crash at least + fut.set_result(10) + ev.set() - self.assertEqual(await fut, 10) + self.assertEqual(await fut, 10) async def test_coro_raises_exception(self) -> None: - executor = self.make_executor() + async with self.make_executor() as executor: - async def f() -> None: - raise Exception("test exception") + async def f() -> None: + raise Exception("test exception") - fut = await executor.submit(f()) + fut = await executor.submit(f()) - with self.assertRaisesRegex(Exception, "test exception"): - await fut + with self.assertRaisesRegex(Exception, "test exception"): + await fut async def test_dont_override_exception_in_future(self) -> None: - executor = self.make_executor() - started = asyncio.Event() + async with self.make_executor() as executor: + started = asyncio.Event() - async def f(ev: asyncio.Event) -> int: - started.set() - await ev.wait() - raise Exception("test exception") + async def f(ev: asyncio.Event) -> int: + started.set() + await ev.wait() + raise Exception("test exception") - ev = asyncio.Event() - fut = await executor.submit(f(ev)) + ev = asyncio.Event() + fut = await executor.submit(f(ev)) - # wait to put submitted request into a worker - await started.wait() + # wait to put submitted request into a worker + await started.wait() - # Setting the result is strange, user should never do it. - # But the executor should not crash at least - fut.set_exception(Exception("override")) - ev.set() + # Setting the result is strange, user should never do it. + # But the executor should not crash at least + fut.set_exception(Exception("override")) + ev.set() - with self.assertRaisesRegex(Exception, "override"): - await fut + with self.assertRaisesRegex(Exception, "override"): + await fut class TestTaskNames(BaseTestCase): async def test_worker_name(self) -> None: - executor = self.make_executor() - executor._lazy_init() - self.assertRegex( - executor._tasks[0].get_name(), r"Executor-(\d+)_(\d+)" - ) - + async with self.make_executor() as executor: + executor._lazy_init() + self.assertRegex( + executor._tasks[0].get_name(), r"Executor-(\d+)_(\d+)" + ) + + @unittest.skipIf( + sys.version_info < (3, 13), + reason="Not supported" + " on ealier versions than 3.13 Regex doesn't format correclty" + " in earlier versions of python" + " [TestTaskNames.test_submit_name..f]", + ) async def test_submit_name(self) -> None: - executor = self.make_executor() + async with self.make_executor() as executor: - async def f() -> str: - task = asyncio.current_task() - assert task is not None - return task.get_name() + async def f() -> str: + task = asyncio.current_task() + assert task is not None + return task.get_name() - ret = await (await executor.submit(f())) - self.assertRegex(ret, rf"Executor-(\d+)_(\d+)\[{f.__qualname__}\]") + ret = await (await executor.submit(f())) + self.assertRegex(ret, rf"Executor-(\d+)_(\d+)\[{f.__qualname__}\]") + @unittest.skipIf( + sys.version_info < (3, 13), + reason=" Regex doesn't format correclty in earlier" + " versions of python [TestTaskNames.test_custom_name..f\\]", + ) async def test_custom_name(self) -> None: - executor = self.make_executor(task_name_prefix="custom") + async with self.make_executor(task_name_prefix="custom") as executor: - async def f() -> str: - task = asyncio.current_task() - assert task is not None - return task.get_name() + async def f() -> str: + task = asyncio.current_task() + assert task is not None + return task.get_name() - ret = await (await executor.submit(f())) - self.assertRegex(ret, rf"custom_(\d+)\[{f.__qualname__}\]") + ret = await (await executor.submit(f())) + self.assertRegex(ret, rf"custom_(\d+)\[{f.__qualname__}\]") if __name__ == "__main__":