diff --git a/src/frequenz/core/asyncio/__init__.py b/src/frequenz/core/asyncio/__init__.py new file mode 100644 index 0000000..ec80e21 --- /dev/null +++ b/src/frequenz/core/asyncio/__init__.py @@ -0,0 +1,33 @@ +# License: MIT +# Copyright © 2024 Frequenz Energy-as-a-Service GmbH + +"""General purpose async tools. + +This module provides general purpose async tools that can be used to simplify the +development of asyncio-based applications. + +The module provides the following classes and functions: + +- [cancel_and_await][frequenz.core.asyncio.cancel_and_await]: A function that cancels a + task and waits for it to finish, handling `CancelledError` exceptions. +- [PersistentTaskGroup][frequenz.core.asyncio.PersistentTaskGroup]: An alternative to + [`asyncio.TaskGroup`][] to manage tasks that run until explicitly stopped. +- [Service][frequenz.core.asyncio.Service]: An interface for services running in the + background. +- [ServiceBase][frequenz.core.asyncio.ServiceBase]: A base class for implementing + services running in the background. +- [TaskCreator][frequenz.core.asyncio.TaskCreator]: A protocol for creating tasks. +""" + +from ._service import Service, ServiceBase +from ._task_group import PersistentTaskGroup +from ._util import TaskCreator, TaskReturnT, cancel_and_await + +__all__ = [ + "PersistentTaskGroup", + "Service", + "ServiceBase", + "TaskCreator", + "TaskReturnT", + "cancel_and_await", +] diff --git a/src/frequenz/core/asyncio.py b/src/frequenz/core/asyncio/_service.py similarity index 65% rename from src/frequenz/core/asyncio.py rename to src/frequenz/core/asyncio/_service.py index f144344..503b715 100644 --- a/src/frequenz/core/asyncio.py +++ b/src/frequenz/core/asyncio/_service.py @@ -1,89 +1,20 @@ # License: MIT # Copyright © 2024 Frequenz Energy-as-a-Service GmbH -"""General purpose async tools. - -This module provides general purpose async tools that can be used to simplify the -development of asyncio-based applications. - -The module provides the following classes and functions: - -- [cancel_and_await][frequenz.core.asyncio.cancel_and_await]: A function that cancels a - task and waits for it to finish, handling `CancelledError` exceptions. -- [Service][frequenz.core.asyncio.Service]: An interface for services running in the - background. -- [ServiceBase][frequenz.core.asyncio.ServiceBase]: A base class for implementing - services running in the background. -- [TaskCreator][frequenz.core.asyncio.TaskCreator]: A protocol for creating tasks. -""" +"""Module implementing the `Service` and `ServiceBase` classes.""" import abc import asyncio import collections.abc import contextvars -import logging from types import TracebackType -from typing import Any, Protocol, Self, TypeVar, runtime_checkable +from typing import Any, Self from typing_extensions import override -_logger = logging.getLogger(__name__) - - -TaskReturnT = TypeVar("TaskReturnT") -"""The type of the return value of a task.""" - - -@runtime_checkable -class TaskCreator(Protocol): - """A protocol for creating tasks. - - Built-in asyncio functions and classes implementing this protocol: - - - [`asyncio`][] - - [`asyncio.AbstractEventLoop`][] (returned by [`asyncio.get_event_loop`][] for - example) - - [`asyncio.TaskGroup`][] - """ - - def create_task( - self, - coro: collections.abc.Coroutine[Any, Any, TaskReturnT], - *, - name: str | None = None, - context: contextvars.Context | None = None, - ) -> asyncio.Task[TaskReturnT]: - """Create a task. - - Args: - coro: The coroutine to be executed. - name: The name of the task. - context: The context to be used for the task. - - Returns: - The new task. - """ - ... # pylint: disable=unnecessary-ellipsis - - -async def cancel_and_await(task: asyncio.Task[Any]) -> None: - """Cancel a task and wait for it to finish. - - Exits immediately if the task is already done. - - The `CancelledError` is suppressed, but any other exception will be propagated. - - Args: - task: The task to be cancelled and waited for. - """ - if task.done(): - return - task.cancel() - try: - await task - except asyncio.CancelledError: - pass +from ._task_group import PersistentTaskGroup +from ._util import TaskCreator, TaskReturnT class Service(abc.ABC): @@ -126,30 +57,14 @@ def start(self) -> None: def unique_id(self) -> str: """The unique ID of this service.""" - @property - @abc.abstractmethod - def tasks(self) -> collections.abc.Set[asyncio.Task[Any]]: - """The set of running tasks spawned by this service. - - Users typically should not modify the tasks in the returned set and only use - them for informational purposes. - - Danger: - Changing the returned tasks may lead to unexpected behavior, don't do it - unless the class explicitly documents it is safe to do so. - """ - @property @abc.abstractmethod def is_running(self) -> bool: - """Whether this service is running. - - A service is considered running when at least one task is running. - """ + """Whether this service is running.""" @abc.abstractmethod def cancel(self, msg: str | None = None) -> None: - """Cancel all running tasks spawned by this service. + """Cancel this service. Args: msg: The message to be passed to the tasks being cancelled. @@ -159,8 +74,7 @@ def cancel(self, msg: str | None = None) -> None: async def stop(self, msg: str | None = None) -> None: # noqa: DOC502 """Stop this service. - This method cancels all running tasks spawned by this service and waits for them - to finish. + This method cancels the service and waits for it to finish. Args: msg: The message to be passed to the tasks being cancelled. @@ -186,7 +100,7 @@ async def __aexit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> None: + ) -> bool | None: """Exit an async context. Stop this service. @@ -195,6 +109,9 @@ async def __aexit__( exc_type: The type of the exception raised, if any. exc_val: The exception raised, if any. exc_tb: The traceback of the exception raised, if any. + + Returns: + Whether the exception was handled. """ @abc.abstractmethod @@ -226,22 +143,19 @@ class ServiceBase(Service, abc.ABC): [`stop()`][frequenz.core.asyncio.ServiceBase.stop] method, as the base implementation does not collect any results and re-raises all exceptions. - Example: + Example: Simple single-task example ```python import datetime import asyncio + from typing_extensions import override class Clock(ServiceBase): def __init__(self, resolution_s: float, *, unique_id: str | None = None) -> None: super().__init__(unique_id=unique_id) self._resolution_s = resolution_s - def start(self) -> None: - # Managed tasks are automatically saved, so there is no need to hold a - # reference to them if you don't need to further interact with them. - self.create_task(self._tick()) - - async def _tick(self) -> None: + @override + async def main(self) -> None: while True: await asyncio.sleep(self._resolution_s) print(datetime.datetime.now()) @@ -259,6 +173,49 @@ async def main() -> None: asyncio.run(main()) ``` + + Example: Multi-tasks example + ```python + import asyncio + import datetime + from typing_extensions import override + + class MultiTaskService(ServiceBase): + + async def _print_every(self, *, seconds: float) -> None: + while True: + await asyncio.sleep(seconds) + print(datetime.datetime.now()) + + async def _fail_after(self, *, seconds: float) -> None: + await asyncio.sleep(seconds) + raise ValueError("I failed") + + @override + async def main(self) -> None: + self.create_task(self._print_every(seconds=1), name="print_1") + self.create_task(self._print_every(seconds=11), name="print_11") + failing = self.create_task(self._fail_after(seconds=5), name=f"fail_5") + + async for task in self.task_group.as_completed(): + assert task.done() # For demonstration purposes only + try: + task.result() + except ValueError as error: + if failing == task: + failing = self.create_task( + self._fail_after(seconds=5), name=f"fail_5" + ) + else: + raise + + async def main() -> None: + async with MultiTaskService(): + await asyncio.sleep(11) + + asyncio.run(main()) + ``` + """ def __init__( @@ -278,13 +235,10 @@ def __init__( # [2:] is used to remove the '0x' prefix from the hex representation of the id, # as it doesn't add any uniqueness to the string. self._unique_id: str = hex(id(self))[2:] if unique_id is None else unique_id - self._tasks: set[asyncio.Task[Any]] = set() - self._task_creator: TaskCreator = task_creator - - @override - @abc.abstractmethod - def start(self) -> None: - """Start this service.""" + self._main_task: asyncio.Task[None] | None = None + self._task_group: PersistentTaskGroup = PersistentTaskGroup( + unique_id=self._unique_id, task_creator=task_creator + ) @property @override @@ -293,18 +247,22 @@ def unique_id(self) -> str: return self._unique_id @property - @override - def tasks(self) -> collections.abc.Set[asyncio.Task[Any]]: - """The set of running tasks spawned by this service. + def task_group(self) -> PersistentTaskGroup: + """The task group managing the tasks of this service.""" + return self._task_group - Users typically should not modify the tasks in the returned set and only use - them for informational purposes. + @abc.abstractmethod + async def main(self) -> None: + """Execute the service logic.""" - Danger: - Changing the returned tasks may lead to unexpected behavior, don't do it - unless the class explicitly documents it is safe to do so. - """ - return self._tasks + @override + def start(self) -> None: + """Start this service.""" + if self.is_running: + return + self._main_task = self._task_group.task_creator.create_task( + self.main(), name=str(self) + ) @property @override @@ -313,7 +271,7 @@ def is_running(self) -> bool: A service is considered running when at least one task is running. """ - return any(not task.done() for task in self._tasks) + return self._main_task is not None and not self._main_task.done() def create_task( self, @@ -328,8 +286,8 @@ def create_task( A reference to the task will be held by the service, so there is no need to save the task object. - Tasks can be retrieved via the - [`tasks`][frequenz.core.asyncio.ServiceBase.tasks] property. + Tasks are created using the + [`task_group`][frequenz.core.asyncio.ServiceBase.task_group]. Managed tasks always have a `name` including information about the service itself. If you need to retrieve the final name of the task you can always do so @@ -354,24 +312,9 @@ def create_task( """ if not name: name = hex(id(coro))[2:] - task = self._task_creator.create_task( - coro, name=f"{self}:{name}", context=context + return self._task_group.create_task( + coro, name=f"{self}:{name}", context=context, log_exception=log_exception ) - self._tasks.add(task) - task.add_done_callback(self._tasks.discard) - - if log_exception: - - def _log_exception(task: asyncio.Task[TaskReturnT]) -> None: - try: - task.result() - except asyncio.CancelledError: - pass - except BaseException: # pylint: disable=broad-except - _logger.exception("%s: Task %r raised an exception", self, task) - - task.add_done_callback(_log_exception) - return task @override def cancel(self, msg: str | None = None) -> None: @@ -380,8 +323,9 @@ def cancel(self, msg: str | None = None) -> None: Args: msg: The message to be passed to the tasks being cancelled. """ - for task in self._tasks: - task.cancel(msg) + if self._main_task is not None: + self._main_task.cancel(msg) + self._task_group.cancel(msg) @override async def stop(self, msg: str | None = None) -> None: @@ -397,8 +341,6 @@ async def stop(self, msg: str | None = None) -> None: BaseExceptionGroup: If any of the tasks spawned by this service raised an exception. """ - if not self._tasks: - return self.cancel(msg) try: await self @@ -429,7 +371,7 @@ async def __aexit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> None: + ) -> bool | None: """Exit an async context. Stop this service. @@ -438,8 +380,12 @@ async def __aexit__( exc_type: The type of the exception raised, if any. exc_val: The exception raised, if any. exc_tb: The traceback of the exception raised, if any. + + Returns: + Whether the exception was handled. """ await self.stop() + return None async def _wait(self) -> None: """Wait for this service to finish. @@ -451,28 +397,21 @@ async def _wait(self) -> None: exception (`CancelError` is not considered an error and not returned in the exception group). """ - # We need to account for tasks that were created between when we started - # awaiting and we finished awaiting. - while self._tasks: - done, pending = await asyncio.wait(self._tasks) - assert not pending - - # We remove the done tasks, but there might be new ones created after we - # started waiting. - self._tasks = self._tasks - done - - exceptions: list[BaseException] = [] - for task in done: - try: - # This will raise a CancelledError if the task was cancelled or any - # other exception if the task raised one. - _ = task.result() - except BaseException as error: # pylint: disable=broad-except - exceptions.append(error) - if exceptions: - raise BaseExceptionGroup( - f"Error while stopping service {self}", exceptions - ) + exceptions: list[BaseException] = [] + + if self._main_task is not None: + try: + await self._main_task + except BaseException as error: # pylint: disable=broad-except + exceptions.append(error) + + try: + await self._task_group + except BaseExceptionGroup as exc_group: + exceptions.append(exc_group) + + if exceptions: + raise BaseExceptionGroup(f"Error while stopping {self}", exceptions) @override def __await__(self) -> collections.abc.Generator[None, None, None]: @@ -490,7 +429,7 @@ def __del__(self) -> None: Cancel all running tasks spawned by this service. """ - self.cancel("{self!r} was deleted") + self.cancel(f"{self!r} was deleted") def __repr__(self) -> str: """Return a string representation of this instance. @@ -498,7 +437,13 @@ def __repr__(self) -> str: Returns: A string representation of this instance. """ - return f"{type(self).__name__}(unique_id={self._unique_id!r}, tasks={self._tasks!r})" + details = "main" + if not self.is_running: + details += " not" + details += " running" + if self._task_group.is_running: + details += f", {len(self._task_group.tasks)} extra tasks" + return f"{type(self).__name__}<{self._unique_id} {details}>" def __str__(self) -> str: """Return a string representation of this instance. @@ -506,4 +451,4 @@ def __str__(self) -> str: Returns: A string representation of this instance. """ - return f"{type(self).__name__}[{self._unique_id}]" + return f"{type(self).__name__}:{self._unique_id}" diff --git a/src/frequenz/core/asyncio/_task_group.py b/src/frequenz/core/asyncio/_task_group.py new file mode 100644 index 0000000..10398e8 --- /dev/null +++ b/src/frequenz/core/asyncio/_task_group.py @@ -0,0 +1,392 @@ +# License: MIT +# Copyright © 2024 Frequenz Energy-as-a-Service GmbH + +"""Module implementing the `PersistentTaskGroup` class.""" + + +import asyncio +import contextvars +import datetime +import logging +from collections.abc import AsyncIterator, Coroutine, Generator, Set +from types import TracebackType +from typing import Any, Self + +from ._util import TaskCreator, TaskReturnT + +_logger = logging.getLogger(__name__) + + +class PersistentTaskGroup: + """A group of tasks that should run until explicitly stopped. + + [`asyncio.TaskGroup`][] is a very convenient construct when using parallelization + for doing calculations for example, where the results for all the tasks need to be + merged together to produce a final result. In this case if one of the tasks fails, + it makes sense to cancel the others and abort as soon as possible, as any further + calculations would be thrown away. + + This class is intended to help managing a group of tasks that should persist even if + other tasks in the group fail, usually by either only discarding the failed task or + by restarting it somehow. + + This class is also typically used as a context manager, but in this case when the + context manager is exited, the tasks are not only awaited, they are first cancelled, + so all the background tasks are stopped. If any task was ended due to an unhandled + exception, the exception will be re-raised when the context manager exits as + [`BaseExceptionGroup`][]. + + As with [`asyncio.TaskGroup`][], the tasks should be created using the + [`create_task()`][frequenz.core.asyncio.PersistentTaskGroup.create_task] method. + + To monitor the subtasks and handle exceptions or early termination, + a [`as_completed()`][frequenz.core.asyncio.PersistentTaskGroup.as_completed] method + is provided, similar to [`asyncio.as_completed`][] but not quite the same. Using + this method is the only way to acknowledge tasks failures, so they are not raised + when the service is `await`ed or when the context manager is exited. + + Example: + This program will run forever, printing the current time now and then and + restarting the failing task each time it crashes. + + ```python + import asyncio + import datetime + + async def print_every(*, seconds: float) -> None: + while True: + await asyncio.sleep(seconds) + print(datetime.datetime.now()) + + async def fail_after(*, seconds: float) -> None: + await asyncio.sleep(seconds) + raise ValueError("I failed") + + async def main() -> None: + + async with PersistentTaskGroup() as group: + group.create_task(print_every(seconds=1), name="print_1") + group.create_task(print_every(seconds=11), name="print_11") + failing = group.create_task(fail_after(seconds=5), name=f"fail_5") + + async for task in group.as_completed(): + assert task.done() # For demonstration purposes only + try: + task.result() + except ValueError as error: + if failing == task: + failing = group.create_task(fail_after(seconds=5), name=f"fail_5") + else: + raise + + asyncio.run(main()) + ``` + """ + + def __init__( + self, *, unique_id: str | None = None, task_creator: TaskCreator = asyncio + ) -> None: + """Initialize this instance. + + Args: + unique_id: The string to uniquely identify this instance. If `None`, + a string based on `hex(id(self))` will be used. This is used in + `__repr__` and `__str__` methods, mainly for debugging purposes, to + identify a particular instance of a persistent task group. + task_creator: The object that will be used to create tasks. Usually one of: + the [`asyncio`]() module, an [`asyncio.AbstractEventLoop`]() or + an [`asyncio.TaskGroup`](). + """ + # [2:] is used to remove the '0x' prefix from the hex representation of the id, + # as it doesn't add any uniqueness to the string. + self._unique_id: str = hex(id(self))[2:] if unique_id is None else unique_id + """The unique ID of this instance.""" + + self._task_creator: TaskCreator = task_creator + """The object that will be used to create tasks.""" + + self._running: set[asyncio.Task[Any]] = set() + """The set of tasks that are still running. + + Tasks are removed from this set automatically when they finish using the + Task.add_done_callback method. + """ + + self._waiting_ack: set[asyncio.Task[Any]] = set() + """The set of tasks that have finished but waiting for the user's ACK. + + Tasks are added to this set automatically when they finish using the + Task.add_done_callback method. + """ + + @property + def unique_id(self) -> str: + """The unique ID of this instance.""" + return self._unique_id + + @property + def tasks(self) -> Set[asyncio.Task[Any]]: + """The set of tasks managed by this group. + + Users typically should not modify the tasks in the returned set and only use + them for informational purposes. + + Both running tasks and tasks pending for acknowledgment are included in the + returned set. + + Danger: + Changing the returned tasks may lead to unexpected behavior, don't do it + unless the class explicitly documents it is safe to do so. + """ + return self._running | self._waiting_ack + + @property + def task_creator(self) -> TaskCreator: + """The object that will be used to create tasks.""" + return self._task_creator + + @property + def is_running(self) -> bool: + """Whether this task group is running. + + A task group is considered running when at least one task is running. + """ + return bool(self._running) + + def create_task( + self, + coro: Coroutine[Any, Any, TaskReturnT], + *, + name: str | None = None, + context: contextvars.Context | None = None, + log_exception: bool = True, + ) -> asyncio.Task[TaskReturnT]: + """Start a managed task. + + A reference to the task will be held by the task group, so there is no need to + save the task object. + + Tasks can be retrieved via the + [`tasks`][frequenz.core.asyncio.PersistentTaskGroup.tasks] property. + + Managed tasks always have a `name` including information about the task group + itself. If you need to retrieve the final name of the task you can always do so + by calling [`.get_name()`][asyncio.Task.get_name] on the returned task. + + Tasks created this way will also be automatically cancelled when calling + [`cancel()`][frequenz.core.asyncio.ServiceBase.cancel] or + [`stop()`][frequenz.core.asyncio.ServiceBase.stop], or when the service is used + as a async context manager. + + To inform that a finished task was properly handled, the method + [`as_completed()`][frequenz.core.asyncio.PersistentTaskGroup.as_completed] + should be used. + + Args: + coro: The coroutine to be managed. + name: The name of the task. Names will always have the form + `f"{self}:{name}"`. If `None` or empty, the default name will be + `hex(id(coro))[2:]`. If you need the final name of the task, it can + always be retrieved + context: The context to be used for the task. + log_exception: Whether to log exceptions raised by the task. + + Returns: + The new task. + """ + if not name: + name = hex(id(coro))[2:] + task = self._task_creator.create_task( + coro, name=f"{self}:{name}", context=context + ) + self._running.add(task) + task.add_done_callback(self._running.discard) + task.add_done_callback(self._waiting_ack.add) + + if log_exception: + + def _log_exception(task: asyncio.Task[TaskReturnT]) -> None: + try: + task.result() + except asyncio.CancelledError: + pass + except BaseException: # pylint: disable=broad-except + _logger.exception( + "Task %s raised an unhandled exception", task.get_name() + ) + + task.add_done_callback(_log_exception) + return task + + def cancel(self, msg: str | None = None) -> None: + """Cancel all running tasks spawned by this group. + + Args: + msg: The message to be passed to the tasks being cancelled. + """ + for task in self._running: + task.cancel(msg) + + async def stop(self, msg: str | None = None) -> None: + """Stop this task group. + + This method cancels all running tasks spawned by this group and waits for them + to finish. + + Args: + msg: The message to be passed to the tasks being cancelled. + + Raises: + BaseExceptionGroup: If any of the tasks spawned by this group raised an + exception. + """ + self.cancel(msg) + try: + await self + except BaseExceptionGroup as exc_group: + # We want to ignore CancelledError here as we explicitly cancelled all the + # tasks. + _, rest = exc_group.split(asyncio.CancelledError) + if rest is not None: + # We are filtering out from an exception group, we really don't want to + # add the exceptions we just filtered by adding a from clause here. + raise rest # pylint: disable=raise-missing-from + + async def as_completed( + self, *, timeout: float | datetime.timedelta | None = None + ) -> AsyncIterator[asyncio.Task[Any]]: + """Iterate over running tasks yielding as they complete. + + Stops iterating when there are no more running tasks and all done tasks have + been acknowledged, or if the timeout is reached. + + Note: + If an exception is raised while yielding a task, the task will be considered + not handled and will be yielded again until it is handled without raising + any exceptions. + + Args: + timeout: The maximum time to wait for the next task to complete. If `None`, + the function will wait indefinitely. + + Yields: + The tasks as they complete. + """ + while True: + while task := next(iter(self._waiting_ack), None): + yield task + # We discard instead of removing in case someone else already ACKed + # the task. + self._waiting_ack.discard(task) + + if not self._running: + break + + done, _ = await asyncio.wait( + self._running, + return_when=asyncio.FIRST_COMPLETED, + timeout=( + timeout.total_seconds() + if isinstance(timeout, datetime.timedelta) + else timeout + ), + ) + + if not done: # wait timed out + break + + # We don't need to add done tasks to _waiting_ack, as they are added there + # automatically via add_done_callback(). + + async def __aenter__(self) -> Self: + """Enter an async context. + + Returns: + This instance. + """ + return self + + async def __aexit__( # noqa: DOC502 + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + """Exit an async context. + + Stop this instance. + + Args: + exc_type: The type of the exception raised, if any. + exc_val: The exception raised, if any. + exc_tb: The traceback of the exception raised, if any. + + Returns: + Whether the exception was handled. + + Raises: + BaseExceptionGroup: If any of the tasks spawned by this group raised an + exception. + """ + await self.stop() + return None + + async def _wait(self) -> None: + """Wait for this instance to finish. + + Wait until all the group tasks are finished. + + Raises: + BaseExceptionGroup: If any of the tasks spawned by this group raised an + exception. + """ + exceptions: list[BaseException] = [] + + async for task in self.as_completed(): + try: + await task + except BaseException as error: # pylint: disable=broad-except + exceptions.append(error) + + if exceptions: + raise BaseExceptionGroup(f"Error while stopping {self}", exceptions) + + def __await__(self) -> Generator[None, None, None]: # noqa: DOC502 + """Await for all tasks managed by this group to finish. + + Returns: + An implementation-specific generator for the awaitable. + + Raises: + BaseExceptionGroup: If any of the tasks spawned by this group raised an + exception. + """ + return self._wait().__await__() + + def __del__(self) -> None: + """Destroy this instance. + + Cancel all running tasks spawned by this group. + """ + self.cancel("{self!r} was deleted") + + def __repr__(self) -> str: + """Return a string representation of this instance. + + Returns: + A string representation of this instance. + """ + details = "" + if self._running: + details += f" running={len(self._running)}" + if self._waiting_ack: + details += f" waiting_ack={len(self._waiting_ack)}" + return f"{type(self).__name__}<{self.unique_id}{details}>" + + def __str__(self) -> str: + """Return a string representation of this instance. + + Returns: + A string representation of this instance. + """ + return f"{type(self).__name__}:{self._unique_id}" diff --git a/src/frequenz/core/asyncio/_util.py b/src/frequenz/core/asyncio/_util.py new file mode 100644 index 0000000..a01d834 --- /dev/null +++ b/src/frequenz/core/asyncio/_util.py @@ -0,0 +1,64 @@ +# License: MIT +# Copyright © 2024 Frequenz Energy-as-a-Service GmbH + +"""General purpose async utilities.""" + + +import asyncio +import collections.abc +import contextvars +from typing import Any, Protocol, TypeVar, runtime_checkable + +TaskReturnT = TypeVar("TaskReturnT") +"""The type of the return value of a task.""" + + +@runtime_checkable +class TaskCreator(Protocol): + """A protocol for creating tasks. + + Built-in asyncio functions and classes implementing this protocol: + + - [`asyncio`][] + - [`asyncio.AbstractEventLoop`][] (returned by [`asyncio.get_event_loop`][] for + example) + - [`asyncio.TaskGroup`][] + """ + + def create_task( + self, + coro: collections.abc.Coroutine[Any, Any, TaskReturnT], + *, + name: str | None = None, + context: contextvars.Context | None = None, + ) -> asyncio.Task[TaskReturnT]: + """Create a task. + + Args: + coro: The coroutine to be executed. + name: The name of the task. + context: The context to be used for the task. + + Returns: + The new task. + """ + ... # pylint: disable=unnecessary-ellipsis + + +async def cancel_and_await(task: asyncio.Task[Any]) -> None: + """Cancel a task and wait for it to finish. + + Exits immediately if the task is already done. + + The `CancelledError` is suppressed, but any other exception will be propagated. + + Args: + task: The task to be cancelled and waited for. + """ + if task.done(): + return + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/tests/test_asyncio.py b/tests/asyncio/test_service.py similarity index 77% rename from tests/test_asyncio.py rename to tests/asyncio/test_service.py index 6d7c6cd..cbbd399 100644 --- a/tests/test_asyncio.py +++ b/tests/asyncio/test_service.py @@ -1,15 +1,16 @@ # License: MIT # Copyright © 2022 Frequenz Energy-as-a-Service GmbH -"""Tests for the asyncio module.""" +"""Tests for the asyncio service module.""" import asyncio from typing import Literal, assert_never import async_solipsism import pytest +from typing_extensions import override -from frequenz.core.asyncio import ServiceBase, TaskCreator +from frequenz.core.asyncio import ServiceBase # This method replaces the event loop for all tests in the file. @@ -34,28 +35,24 @@ def __init__( self._sleep = sleep self._exc = exc - def start(self) -> None: - """Start this service.""" - - async def nop() -> None: - if self._sleep is not None: - await asyncio.sleep(self._sleep) - if self._exc is not None: - raise self._exc - - self._tasks.add(asyncio.create_task(nop(), name="nop")) + @override + async def main(self) -> None: + """Run this service.""" + if self._sleep is not None: + await asyncio.sleep(self._sleep) + if self._exc is not None: + raise self._exc async def test_construction_defaults() -> None: """Test the construction of a service with default arguments.""" fake_service = FakeService() assert fake_service.unique_id == hex(id(fake_service))[2:] - assert fake_service.tasks == set() + assert fake_service.task_group.tasks == set() assert fake_service.is_running is False - assert str(fake_service) == f"FakeService[{fake_service.unique_id}]" + assert str(fake_service) == f"FakeService:{fake_service.unique_id}" assert ( - repr(fake_service) - == f"FakeService(unique_id={fake_service.unique_id!r}, tasks=set())" + repr(fake_service) == f"FakeService<{fake_service.unique_id} main not running>" ) @@ -63,7 +60,7 @@ async def test_construction_custom() -> None: """Test the construction of a service with a custom unique ID.""" fake_service = FakeService(unique_id="test") assert fake_service.unique_id == "test" - assert fake_service.tasks == set() + assert fake_service.task_group.tasks == set() assert fake_service.is_running is False @@ -150,18 +147,3 @@ async def test_async_context_manager() -> None: assert fake_service.is_running is True assert fake_service.is_running is False - - -def test_task_creator_asyncio() -> None: - """Test that the asyncio module is a TaskCreator.""" - assert isinstance(asyncio, TaskCreator) - - -async def test_task_creator_loop() -> None: - """Test that the asyncio event loop is a TaskCreator.""" - assert isinstance(asyncio.get_event_loop(), TaskCreator) - - -def test_task_creator_task_group() -> None: - """Test that the asyncio task group is a TaskCreator.""" - assert isinstance(asyncio.TaskGroup(), TaskCreator) diff --git a/tests/asyncio/test_task_group.py b/tests/asyncio/test_task_group.py new file mode 100644 index 0000000..3b3590c --- /dev/null +++ b/tests/asyncio/test_task_group.py @@ -0,0 +1,387 @@ +# License: MIT +# Copyright © 2024 Frequenz Energy-as-a-Service GmbH + +"""Tests for PersistentTaskGroup.""" + +import asyncio + +import async_solipsism +import pytest + +from frequenz.core.asyncio import PersistentTaskGroup, TaskCreator + + +# This method replaces the event loop for all tests in the file. +@pytest.fixture +def event_loop_policy() -> async_solipsism.EventLoopPolicy: + """Return an event loop policy that uses the async solipsism event loop.""" + return async_solipsism.EventLoopPolicy() + + +async def test_construction_defaults() -> None: + """Test the construction of a group with default arguments.""" + group = PersistentTaskGroup() + assert group.unique_id == hex(id(group))[2:] + assert group.tasks == set() + assert group.is_running is False + assert str(group) == f"PersistentTaskGroup:{group.unique_id}" + assert repr(group) == f"PersistentTaskGroup<{group.unique_id}>" + + +async def test_construction_custom() -> None: + """Test the construction of a group with a custom unique ID.""" + group = PersistentTaskGroup(unique_id="test") + assert group.unique_id == "test" + assert group.tasks == set() + assert group.is_running is False + assert str(group) == "PersistentTaskGroup:test" + assert repr(group) == "PersistentTaskGroup" + + +async def test_task_name() -> None: + """Test a group with some task can be awaited when finishing successfully.""" + group = PersistentTaskGroup(unique_id="test") + + task = group.create_task(asyncio.sleep(0), name="sleep_1") + + assert group.tasks == {task} + assert group.is_running is True + assert str(group) == "PersistentTaskGroup:test" + assert repr(group) == f"PersistentTaskGroup<{group.unique_id} running=1>" + assert task.get_name() == "PersistentTaskGroup:test:sleep_1" + await task + + +async def test_cancel() -> None: + """Test a group cancel all tasks when cancel is called.""" + group = PersistentTaskGroup(unique_id="test") + + task = group.create_task(asyncio.sleep(0), name="sleep_1") + group.cancel() + + await asyncio.sleep(1) # Make sure the task is cancelled + + assert group.is_running is False + assert task.cancelled() + + +async def test_as_completed_with_timeout() -> None: + """Test tasks in a group can be iterated as the complete.""" + group = PersistentTaskGroup(unique_id="test") + + group.create_task(asyncio.sleep(1), name="sleep_1") + + async with asyncio.timeout(1): # Make sure this doesn't hang + async for _ in group.as_completed(timeout=0.5): + assert False, "Should not have any task completed" + + assert group.is_running is True + assert len(group.tasks) == 1 + + +async def test_as_completed() -> None: + """Test tasks in a group can be iterated as they complete.""" + group = PersistentTaskGroup(unique_id="test") + + expected_exception = RuntimeError("Boom!") + + async def _boom_at_2() -> None: + await asyncio.sleep(2) + raise expected_exception + + async def _cancel_at_3() -> None: + await asyncio.sleep(3) + self = asyncio.current_task() + assert self is not None + self.cancel() + await asyncio.sleep(10) + + task_sleep_1 = group.create_task(asyncio.sleep(1), name="sleep_1") + task_boom_at_2 = group.create_task(_boom_at_2(), name="boom_at_2") + task_cancel_at_3 = group.create_task(_cancel_at_3(), name="cancel_at_3") + + assert len(group.tasks) == 3 + + async with asyncio.timeout(4): # Make sure this doesn't hang + order = iter([task_sleep_1, task_boom_at_2, task_cancel_at_3]) + async for task in group.as_completed(): + expected_task = next(order) + assert task is expected_task + if task is task_sleep_1: + assert task.result() is None + elif task is task_boom_at_2: + assert task.exception() is expected_exception + elif task is task_cancel_at_3: + assert task.cancelled() is True + + assert group.is_running is False + assert next(order, None) is None + + +async def test_repr() -> None: + """Test the representation of a group.""" + group = PersistentTaskGroup(unique_id="test") + + tasks: set[asyncio.Task[None]] = set() + tasks.add(group.create_task(asyncio.sleep(0), name="sleep_1")) + + assert repr(group) == f"PersistentTaskGroup<{group.unique_id} running=1>" + + tasks.add(group.create_task(asyncio.sleep(0), name="sleep_2")) + tasks.add(group.create_task(asyncio.sleep(2), name="sleep_3")) + + assert repr(group) == f"PersistentTaskGroup<{group.unique_id} running=3>" + + await asyncio.sleep(1) # Make sure 2 tasks are done + + assert ( + repr(group) == f"PersistentTaskGroup<{group.unique_id} running=1 waiting_ack=2>" + ) + as_completed_iter = group.as_completed() + task = await anext(as_completed_iter, None) + assert task is not None + assert ( + repr(group) == f"PersistentTaskGroup<{group.unique_id} running=1 waiting_ack=2>" + ) + + task = await anext(as_completed_iter, None) + assert task is not None + assert ( + repr(group) == f"PersistentTaskGroup<{group.unique_id} running=1 waiting_ack=1>" + ) + + task = await anext(as_completed_iter, None) + assert task is not None + assert repr(group) == f"PersistentTaskGroup<{group.unique_id} waiting_ack=1>" + + task = await anext(as_completed_iter, None) + assert task is None + assert repr(group) == f"PersistentTaskGroup<{group.unique_id}>" + + await asyncio.gather(*tasks) + + +async def test_await_success() -> None: + """Test a group with some task can be awaited when finishing successfully.""" + group = PersistentTaskGroup(unique_id="test") + + # Is a no-op if the group is not running + await group.stop() + assert group.is_running is False + + task = group.create_task(asyncio.sleep(0), name="sleep_1") + assert group.is_running is True + + # Should stop immediately + async with asyncio.timeout(1): + await group + + assert group.is_running is False + assert task.done() + assert not task.cancelled() + assert task.exception() is None + + +async def test_await_error() -> None: + """Test a group with some task can be awaited when finishing with an error.""" + group = PersistentTaskGroup(unique_id="test") + + expected_exception = RuntimeError("Boom!") + + async def _boom() -> None: + raise expected_exception + + task = group.create_task(_boom(), name="boom") + assert group.is_running is True + + # Should stop immediately + async with asyncio.timeout(1): + with pytest.raises(BaseExceptionGroup) as exc_info: + await group + assert exc_info.value.args == ( + "Error while stopping PersistentTaskGroup:test", + [expected_exception], + ) + + assert group.is_running is False + assert task.done() + assert not task.cancelled() + assert task.exception() is expected_exception + + +async def test_await_cancelled() -> None: + """Test a group with some task can be awaited when cancelled.""" + group = PersistentTaskGroup(unique_id="test") + + task = group.create_task(asyncio.sleep(1), name="sleep_1") + assert group.is_running is True + cancelled = task.cancel("bye bye") + assert cancelled is True + + # Should stop immediately + async with asyncio.timeout(1): + with pytest.raises(BaseExceptionGroup) as exc_info: + await group + assert exc_info.value.args[0] == "Error while stopping PersistentTaskGroup:test" + exceptions = exc_info.value.exceptions + assert len(exceptions) == 1 + assert isinstance(exceptions[0], asyncio.CancelledError) + + assert group.is_running is False + assert task.cancelled() + + +async def test_stop_success() -> None: + """Test a group with some task can be stopped when finishing successfully.""" + group = PersistentTaskGroup(unique_id="test") + + task = group.create_task(asyncio.sleep(2), name="sleep_1") + assert group.is_running is True + + await asyncio.sleep(1) + assert group.is_running is True + + await group.stop() + assert group.is_running is False + + assert task.cancelled() + + await group.stop() + assert group.is_running is False + + +async def test_stop_error() -> None: + """Test a group with some task can be stopped when finishing with an error.""" + group = PersistentTaskGroup(unique_id="test") + + expected_exception = RuntimeError("Boom!") + + async def _boom() -> None: + raise expected_exception + + task = group.create_task(_boom(), name="boom") + assert group.is_running is True + + await asyncio.sleep(1) + assert group.is_running is False + + with pytest.raises(BaseExceptionGroup) as exc_info: + await group.stop() + assert exc_info.value.args == ( + "Error while stopping PersistentTaskGroup:test", + [expected_exception], + ) + + assert group.is_running is False + assert task.done() + assert not task.cancelled() + assert task.exception() is expected_exception + + await group.stop() + assert group.is_running is False + + +async def test_stop_cancelled() -> None: + """Test a group with some task can be stopped when cancelled.""" + group = PersistentTaskGroup(unique_id="test") + + task = group.create_task(asyncio.sleep(1), name="sleep_1") + assert group.is_running is True + + cancelled = task.cancel("bye bye") + assert cancelled is True + + # If we give it some time, then the task will be cancelled and the group will be + # stopped + await asyncio.sleep(0.5) + assert group.is_running is False + + await group.stop() + + assert group.is_running is False + assert task.cancelled() + + await group.stop() + assert group.is_running is False + + +async def test_async_context_manager_success() -> None: + """Test a group works as an async context manager when finishing successfully.""" + async with PersistentTaskGroup(unique_id="test") as group: + assert group.is_running is False + + task = group.create_task(asyncio.sleep(1), name="sleep_1") + + assert group.is_running is True + assert task.done() is False + assert task.cancelled() is False + + await asyncio.sleep(2) + + assert group.is_running is False + assert task.done() + assert not task.cancelled() + assert task.exception() is None + + assert group.is_running is False + assert task.done() + assert not task.cancelled() + assert task.exception() is None + + +async def test_async_context_manager_error() -> None: + """Test a group works as an async context manager when finishing with an error.""" + expected_exception = RuntimeError("Boom!") + + async def _boom() -> None: + raise expected_exception + + async_with_block_finished = False + group: PersistentTaskGroup | None = None + task: asyncio.Task[None] | None = None + + with pytest.raises(BaseExceptionGroup) as exc_info: + async with PersistentTaskGroup(unique_id="test") as group: + task = group.create_task(_boom(), name="boom") + + assert group.is_running is True + assert task.done() is False + assert task.cancelled() is False + + await asyncio.sleep(1) + + assert group.is_running is False + assert task.done() + assert not task.cancelled() + assert task.exception() is expected_exception + async_with_block_finished = True + + assert exc_info.value.args == ( + "Error while stopping PersistentTaskGroup:test", + [expected_exception], + ) + assert async_with_block_finished is True + assert group is not None + assert group.is_running is False + assert task is not None + assert task.done() + assert not task.cancelled() + assert task.exception() is expected_exception + + +async def test_async_context_manager_cancelled() -> None: + """Test a group works as an async context manager when cancelled.""" + async with PersistentTaskGroup(unique_id="test") as group: + task = group.create_task(asyncio.sleep(1), name="sleep_1") + assert group.is_running is True + + cancelled = task.cancel("bye bye") + assert cancelled is True + + assert group.is_running is False + assert task.cancelled() + + +def test_is_task_creator() -> None: + """Test that a persistent task group is a TaskCreator.""" + assert isinstance(PersistentTaskGroup(), TaskCreator) diff --git a/tests/asyncio/test_util.py b/tests/asyncio/test_util.py new file mode 100644 index 0000000..2439669 --- /dev/null +++ b/tests/asyncio/test_util.py @@ -0,0 +1,23 @@ +# License: MIT +# Copyright © 2022 Frequenz Energy-as-a-Service GmbH + +"""Tests for the asyncio util module.""" + +import asyncio + +from frequenz.core.asyncio import TaskCreator + + +def test_task_creator_asyncio() -> None: + """Test that the asyncio module is a TaskCreator.""" + assert isinstance(asyncio, TaskCreator) + + +async def test_task_creator_loop() -> None: + """Test that the asyncio event loop is a TaskCreator.""" + assert isinstance(asyncio.get_event_loop(), TaskCreator) + + +def test_task_creator_task_group() -> None: + """Test that the asyncio task group is a TaskCreator.""" + assert isinstance(asyncio.TaskGroup(), TaskCreator)