diff --git a/src/frequenz/core/asyncio.py b/src/frequenz/core/asyncio.py index 3b964b8..f144344 100644 --- a/src/frequenz/core/asyncio.py +++ b/src/frequenz/core/asyncio.py @@ -10,16 +10,61 @@ - [cancel_and_await][frequenz.core.asyncio.cancel_and_await]: A function that cancels a task and waits for it to finish, handling `CancelledError` exceptions. -- [BackgroundService][frequenz.core.asyncio.BackgroundService]: A base class for - implementing background services that can be started and stopped. +- [Service][frequenz.core.asyncio.Service]: An interface for services running in the + background. +- [ServiceBase][frequenz.core.asyncio.ServiceBase]: A base class for implementing + services running in the background. +- [TaskCreator][frequenz.core.asyncio.TaskCreator]: A protocol for creating tasks. """ import abc import asyncio import collections.abc +import contextvars +import logging from types import TracebackType -from typing import Any, Self +from typing import Any, Protocol, Self, TypeVar, runtime_checkable + +from typing_extensions import override + +_logger = logging.getLogger(__name__) + + +TaskReturnT = TypeVar("TaskReturnT") +"""The type of the return value of a task.""" + + +@runtime_checkable +class TaskCreator(Protocol): + """A protocol for creating tasks. + + Built-in asyncio functions and classes implementing this protocol: + + - [`asyncio`][] + - [`asyncio.AbstractEventLoop`][] (returned by [`asyncio.get_event_loop`][] for + example) + - [`asyncio.TaskGroup`][] + """ + + def create_task( + self, + coro: collections.abc.Coroutine[Any, Any, TaskReturnT], + *, + name: str | None = None, + context: contextvars.Context | None = None, + ) -> asyncio.Task[TaskReturnT]: + """Create a task. + + Args: + coro: The coroutine to be executed. + name: The name of the task. + context: The context to be used for the task. + + Returns: + The new task. + """ + ... # pylint: disable=unnecessary-ellipsis async def cancel_and_await(task: asyncio.Task[Any]) -> None: @@ -41,44 +86,160 @@ async def cancel_and_await(task: asyncio.Task[Any]) -> None: pass -class BackgroundService(abc.ABC): - """A background service that can be started and stopped. +class Service(abc.ABC): + """A service running in the background. + + A service swpawns one of more background tasks and can be + [started][frequenz.core.asyncio.Service.start] and + [stopped][frequenz.core.asyncio.Service.stop] and can work as an async context + manager to provide deterministic cleanup. + + Warning: + As services manage [`asyncio.Task`][] objects, a reference to a running service + must be held for as long as the service is expected to be running. Otherwise, its + tasks will be cancelled and the service will stop. For more information, please + refer to the [Python `asyncio` + documentation](https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task). + + Example: + ```python + async def as_context_manager(service: Service) -> None: + async with service: + assert service.is_running + await asyncio.sleep(5) + assert not service.is_running + + async def manual_start_stop(service: Service) -> None: + # Use only if necessary, as cleanup is more complicated + service.start() + await asyncio.sleep(5) + await service.stop() + ``` + """ + + @abc.abstractmethod + def start(self) -> None: + """Start this service.""" + + @property + @abc.abstractmethod + def unique_id(self) -> str: + """The unique ID of this service.""" + + @property + @abc.abstractmethod + def tasks(self) -> collections.abc.Set[asyncio.Task[Any]]: + """The set of running tasks spawned by this service. + + Users typically should not modify the tasks in the returned set and only use + them for informational purposes. + + Danger: + Changing the returned tasks may lead to unexpected behavior, don't do it + unless the class explicitly documents it is safe to do so. + """ + + @property + @abc.abstractmethod + def is_running(self) -> bool: + """Whether this service is running. + + A service is considered running when at least one task is running. + """ + + @abc.abstractmethod + def cancel(self, msg: str | None = None) -> None: + """Cancel all running tasks spawned by this service. + + Args: + msg: The message to be passed to the tasks being cancelled. + """ + + @abc.abstractmethod + async def stop(self, msg: str | None = None) -> None: # noqa: DOC502 + """Stop this service. + + This method cancels all running tasks spawned by this service and waits for them + to finish. + + Args: + msg: The message to be passed to the tasks being cancelled. + + Raises: + BaseExceptionGroup: If any of the tasks spawned by this service raised an + exception. + """ + + @abc.abstractmethod + async def __aenter__(self) -> Self: + """Enter an async context. + + Start this service. + + Returns: + This service. + """ + + @abc.abstractmethod + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit an async context. + + Stop this service. + + Args: + exc_type: The type of the exception raised, if any. + exc_val: The exception raised, if any. + exc_tb: The traceback of the exception raised, if any. + """ + + @abc.abstractmethod + def __await__(self) -> collections.abc.Generator[None, None, None]: # noqa: DOC502 + """Wait for this service to finish. - A background service is a service that runs in the background spawning one or more - tasks. The service can be [started][frequenz.core.asyncio.BackgroundService.start] - and [stopped][frequenz.core.asyncio.BackgroundService.stop] and can work as an - async context manager to provide deterministic cleanup. + Wait until all the service tasks are finished. - To implement a background service, subclasses must implement the - [`start()`][frequenz.core.asyncio.BackgroundService.start] method, which should - start the background tasks needed by the service, and add them to the `_tasks` - protected attribute. + Returns: + An implementation-specific generator for the awaitable. + + Raises: + BaseExceptionGroup: If any of the tasks spawned by this service raised an + exception (`CancelError` is not considered an error and not returned in + the exception group). + """ + + +class ServiceBase(Service, abc.ABC): + """A base class for implementing a service running in the background. + + To implement a service, subclasses must implement the + [`start()`][frequenz.core.asyncio.ServiceBase.start] method, which should start the + background tasks needed by the service using the + [`create_task()`][frequenz.core.asyncio.ServiceBase.create_task] method. If you need to collect results or handle exceptions of the tasks when stopping the service, then you need to also override the - [`stop()`][frequenz.core.asyncio.BackgroundService.stop] method, as the base + [`stop()`][frequenz.core.asyncio.ServiceBase.stop] method, as the base implementation does not collect any results and re-raises all exceptions. - !!! warning - - As background services manage [`asyncio.Task`][] objects, a reference to them - must be held for as long as the background service is expected to be running, - otherwise its tasks will be cancelled and the service will stop. For more - information, please refer to the [Python `asyncio` - documentation](https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task). - Example: ```python import datetime import asyncio - class Clock(BackgroundService): + class Clock(ServiceBase): def __init__(self, resolution_s: float, *, unique_id: str | None = None) -> None: super().__init__(unique_id=unique_id) self._resolution_s = resolution_s def start(self) -> None: - self._tasks.add(asyncio.create_task(self._tick())) + # Managed tasks are automatically saved, so there is no need to hold a + # reference to them if you don't need to further interact with them. + self.create_task(self._tick()) async def _tick(self) -> None: while True: @@ -100,63 +261,121 @@ async def main() -> None: ``` """ - def __init__(self, *, unique_id: str | None = None) -> None: - """Initialize this BackgroundService. + def __init__( + self, *, unique_id: str | None = None, task_creator: TaskCreator = asyncio + ) -> None: + """Initialize this Service. Args: - unique_id: The string to uniquely identify this background service instance. + unique_id: The string to uniquely identify this service instance. If `None`, a string based on `hex(id(self))` will be used. This is used in `__repr__` and `__str__` methods, mainly for debugging - purposes, to identify a particular instance of a background service. + purposes, to identify a particular instance of a service. + task_creator: The object that will be used to create tasks. Usually one of: + the [`asyncio`]() module, an [`asyncio.AbstractEventLoop`]() or + an [`asyncio.TaskGroup`](). """ # [2:] is used to remove the '0x' prefix from the hex representation of the id, # as it doesn't add any uniqueness to the string. self._unique_id: str = hex(id(self))[2:] if unique_id is None else unique_id self._tasks: set[asyncio.Task[Any]] = set() + self._task_creator: TaskCreator = task_creator + @override @abc.abstractmethod def start(self) -> None: - """Start this background service.""" + """Start this service.""" @property + @override def unique_id(self) -> str: - """The unique ID of this background service. - - Returns: - The unique ID of this background service. - """ + """The unique ID of this service.""" return self._unique_id @property + @override def tasks(self) -> collections.abc.Set[asyncio.Task[Any]]: - """Return the set of running tasks spawned by this background service. + """The set of running tasks spawned by this service. Users typically should not modify the tasks in the returned set and only use them for informational purposes. - !!! danger - + Danger: Changing the returned tasks may lead to unexpected behavior, don't do it unless the class explicitly documents it is safe to do so. - - Returns: - The set of running tasks spawned by this background service. """ return self._tasks @property + @override def is_running(self) -> bool: - """Return whether this background service is running. + """Whether this service is running. A service is considered running when at least one task is running. + """ + return any(not task.done() for task in self._tasks) + + def create_task( + self, + coro: collections.abc.Coroutine[Any, Any, TaskReturnT], + *, + name: str | None = None, + context: contextvars.Context | None = None, + log_exception: bool = True, + ) -> asyncio.Task[TaskReturnT]: + """Start a managed task. + + A reference to the task will be held by the service, so there is no need to save + the task object. + + Tasks can be retrieved via the + [`tasks`][frequenz.core.asyncio.ServiceBase.tasks] property. + + Managed tasks always have a `name` including information about the service + itself. If you need to retrieve the final name of the task you can always do so + by calling [`.get_name()`][asyncio.Task.get_name] on the returned task. + + Tasks created this way will also be automatically cancelled when calling + [`cancel()`][frequenz.core.asyncio.ServiceBase.cancel] or + [`stop()`][frequenz.core.asyncio.ServiceBase.stop], or when the service is used + as a async context manager. + + Args: + coro: The coroutine to be managed. + name: The name of the task. Names will always have the form + `f"{self}:{name}"`. If `None` or empty, the default name will be + `hex(id(coro))[2:]`. If you need the final name of the task, it can + always be retrieved + context: The context to be used for the task. + log_exception: Whether to log exceptions raised by the task. Returns: - Whether this background service is running. + The new task. """ - return any(not task.done() for task in self._tasks) + if not name: + name = hex(id(coro))[2:] + task = self._task_creator.create_task( + coro, name=f"{self}:{name}", context=context + ) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + + if log_exception: + + def _log_exception(task: asyncio.Task[TaskReturnT]) -> None: + try: + task.result() + except asyncio.CancelledError: + pass + except BaseException: # pylint: disable=broad-except + _logger.exception("%s: Task %r raised an exception", self, task) + + task.add_done_callback(_log_exception) + return task + @override def cancel(self, msg: str | None = None) -> None: - """Cancel all running tasks spawned by this background service. + """Cancel all running tasks spawned by this service. Args: msg: The message to be passed to the tasks being cancelled. @@ -164,8 +383,9 @@ def cancel(self, msg: str | None = None) -> None: for task in self._tasks: task.cancel(msg) + @override async def stop(self, msg: str | None = None) -> None: - """Stop this background service. + """Stop this service. This method cancels all running tasks spawned by this service and waits for them to finish. @@ -181,7 +401,7 @@ async def stop(self, msg: str | None = None) -> None: return self.cancel(msg) try: - await self.wait() + await self except BaseExceptionGroup as exc_group: # We want to ignore CancelledError here as we explicitly cancelled all the # tasks. @@ -191,17 +411,19 @@ async def stop(self, msg: str | None = None) -> None: # add the exceptions we just filtered by adding a from clause here. raise rest # pylint: disable=raise-missing-from + @override async def __aenter__(self) -> Self: """Enter an async context. - Start this background service. + Start this service. Returns: - This background service. + This service. """ self.start() return self + @override async def __aexit__( self, exc_type: type[BaseException] | None, @@ -210,7 +432,7 @@ async def __aexit__( ) -> None: """Exit an async context. - Stop this background service. + Stop this service. Args: exc_type: The type of the exception raised, if any. @@ -219,10 +441,10 @@ async def __aexit__( """ await self.stop() - async def wait(self) -> None: - """Wait this background service to finish. + async def _wait(self) -> None: + """Wait for this service to finish. - Wait until all background service tasks are finished. + Wait until all the service tasks are finished. Raises: BaseExceptionGroup: If any of the tasks spawned by this service raised an @@ -249,23 +471,24 @@ async def wait(self) -> None: exceptions.append(error) if exceptions: raise BaseExceptionGroup( - f"Error while stopping background service {self}", exceptions + f"Error while stopping service {self}", exceptions ) + @override def __await__(self) -> collections.abc.Generator[None, None, None]: - """Await this background service. + """Await this service. - An awaited background service will wait for all its tasks to finish. + An awaited service will wait for all its tasks to finish. Returns: An implementation-specific generator for the awaitable. """ - return self.wait().__await__() + return self._wait().__await__() def __del__(self) -> None: """Destroy this instance. - Cancel all running tasks spawned by this background service. + Cancel all running tasks spawned by this service. """ self.cancel("{self!r} was deleted") diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 730a987..6d7c6cd 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -9,7 +9,7 @@ import async_solipsism import pytest -from frequenz.core.asyncio import BackgroundService +from frequenz.core.asyncio import ServiceBase, TaskCreator # This method replaces the event loop for all tests in the file. @@ -19,8 +19,8 @@ def event_loop_policy() -> async_solipsism.EventLoopPolicy: return async_solipsism.EventLoopPolicy() -class FakeService(BackgroundService): - """A background service that does nothing.""" +class FakeService(ServiceBase): + """A service that does nothing.""" def __init__( self, @@ -47,7 +47,7 @@ async def nop() -> None: async def test_construction_defaults() -> None: - """Test the construction of a background service with default arguments.""" + """Test the construction of a service with default arguments.""" fake_service = FakeService() assert fake_service.unique_id == hex(id(fake_service))[2:] assert fake_service.tasks == set() @@ -60,7 +60,7 @@ async def test_construction_defaults() -> None: async def test_construction_custom() -> None: - """Test the construction of a background service with a custom unique ID.""" + """Test the construction of a service with a custom unique ID.""" fake_service = FakeService(unique_id="test") assert fake_service.unique_id == "test" assert fake_service.tasks == set() @@ -68,7 +68,7 @@ async def test_construction_custom() -> None: async def test_start_await() -> None: - """Test a background service starts and can be awaited.""" + """Test a service starts and can be awaited.""" fake_service = FakeService(unique_id="test") assert fake_service.unique_id == "test" assert fake_service.is_running is False @@ -88,7 +88,7 @@ async def test_start_await() -> None: async def test_start_stop() -> None: - """Test a background service starts and stops correctly.""" + """Test a service starts and stops correctly.""" fake_service = FakeService(unique_id="test", sleep=2.0) assert fake_service.unique_id == "test" assert fake_service.is_running is False @@ -110,11 +110,11 @@ async def test_start_stop() -> None: assert fake_service.is_running is False -@pytest.mark.parametrize("method", ["await", "wait", "stop"]) +@pytest.mark.parametrize("method", ["await", "stop"]) async def test_start_and_crash( - method: Literal["await"] | Literal["wait"] | Literal["stop"], + method: Literal["await"] | Literal["stop"], ) -> None: - """Test a background service reports when crashing.""" + """Test a service reports when crashing.""" exc = RuntimeError("error") fake_service = FakeService(unique_id="test", exc=exc) assert fake_service.unique_id == "test" @@ -125,8 +125,6 @@ async def test_start_and_crash( match method: case "await": await fake_service - case "wait": - await fake_service.wait() case "stop": # Give the service some time to run and crash, otherwise stop() will # cancel it before it has a chance to crash @@ -143,7 +141,7 @@ async def test_start_and_crash( async def test_async_context_manager() -> None: - """Test a background service works as an async context manager.""" + """Test a service works as an async context manager.""" async with FakeService(unique_id="test", sleep=1.0) as fake_service: assert fake_service.is_running is True # Is a no-op if the service is running @@ -152,3 +150,18 @@ async def test_async_context_manager() -> None: assert fake_service.is_running is True assert fake_service.is_running is False + + +def test_task_creator_asyncio() -> None: + """Test that the asyncio module is a TaskCreator.""" + assert isinstance(asyncio, TaskCreator) + + +async def test_task_creator_loop() -> None: + """Test that the asyncio event loop is a TaskCreator.""" + assert isinstance(asyncio.get_event_loop(), TaskCreator) + + +def test_task_creator_task_group() -> None: + """Test that the asyncio task group is a TaskCreator.""" + assert isinstance(asyncio.TaskGroup(), TaskCreator)