88import asyncio
99import collections.abc
1010import contextvars
11- import logging
1211from types import TracebackType
1312from typing import Any, Self
1413
1514from typing_extensions import override
1615
16+ from ._task_group import PersistentTaskGroup
1717from ._util import TaskCreator, TaskReturnT
1818
19- _logger = logging.getLogger(__name__)
20-
2119
2220class Service(abc.ABC):
2321 """A service running in the background.
@@ -62,14 +60,11 @@ def unique_id(self) -> str:
6260 @property
6361 @abc.abstractmethod
6462 def is_running(self) -> bool:
65- """Whether this service is running.
66-
67- A service is considered running when at least one task is running.
68- """
63+ """Whether this service is running."""
6964
7065 @abc.abstractmethod
7166 def cancel(self, msg: str | None = None) -> None:
72- """Cancel all running tasks spawned by this service.
67+ """Cancel this service.
7368
7469 Args:
7570 msg: The message to be passed to the tasks being cancelled.
@@ -79,8 +74,7 @@ def cancel(self, msg: str | None = None) -> None:
7974 async def stop(self, msg: str | None = None) -> None: # noqa: DOC502
8075 """Stop this service.
8176
82- This method cancels all running tasks spawned by this service and waits for them
83- to finish.
77+ This method cancels the service and waits for it to finish.
8478
8579 Args:
8680 msg: The message to be passed to the tasks being cancelled.
@@ -149,22 +143,19 @@ class ServiceBase(Service, abc.ABC):
149143 [`stop()`][frequenz.core.asyncio.ServiceBase.stop] method, as the base
150144 implementation does not collect any results and re-raises all exceptions.
151145
152- Example:
146+ Example: Simple single-task example
153147 ```python
154148 import datetime
155149 import asyncio
150+ from typing_extensions import override
156151
157152 class Clock(ServiceBase):
158153 def __init__(self, resolution_s: float, *, unique_id: str | None = None) -> None:
159154 super().__init__(unique_id=unique_id)
160155 self._resolution_s = resolution_s
161156
162- def start(self) -> None:
163- # Managed tasks are automatically saved, so there is no need to hold a
164- # reference to them if you don't need to further interact with them.
165- self.create_task(self._tick())
166-
167- async def _tick(self) -> None:
157+ @override
158+ async def main(self) -> None:
168159 while True:
169160 await asyncio.sleep(self._resolution_s)
170161 print(datetime.datetime.now())
@@ -182,6 +173,49 @@ async def main() -> None:
182173
183174 asyncio.run(main())
184175 ```
176+
177+ Example: Multi-tasks example
178+ ```python
179+ import asyncio
180+ import datetime
181+ from typing_extensions import override
182+
183+ class MultiTaskService(ServiceBase):
184+
185+ async def _print_every(self, *, seconds: float) -> None:
186+ while True:
187+ await asyncio.sleep(seconds)
188+ print(datetime.datetime.now())
189+
190+ async def _fail_after(self, *, seconds: float) -> None:
191+ await asyncio.sleep(seconds)
192+ raise ValueError("I failed")
193+
194+ @override
195+ async def main(self) -> None:
196+ self.create_task(self._print_every(seconds=1), name="print_1")
197+ self.create_task(self._print_every(seconds=11), name="print_11")
198+ failing = self.create_task(self._fail_after(seconds=5), name=f"fail_5")
199+
200+ async for task in self.task_group.as_completed():
201+ assert task.done() # For demonstration purposes only
202+ try:
203+ task.result()
204+ except ValueError as error:
205+ if failing == task:
206+ failing = self.create_task(
207+ self._fail_after(seconds=5), name=f"fail_5"
208+ )
209+ else:
210+ raise
211+
212+ async def main() -> None:
213+ async with MultiTaskService():
214+ await asyncio.sleep(11)
215+
216+ asyncio.run(main())
217+ ```
218+
185219 """
186220
187221 def __init__(
@@ -201,13 +235,10 @@ def __init__(
201235 # [2:] is used to remove the '0x' prefix from the hex representation of the id,
202236 # as it doesn't add any uniqueness to the string.
203237 self._unique_id: str = hex(id(self))[2:] if unique_id is None else unique_id
204- self._tasks: set[asyncio.Task[Any]] = set()
205- self._task_creator: TaskCreator = task_creator
206-
207- @override
208- @abc.abstractmethod
209- def start(self) -> None:
210- """Start this service."""
238+ self._main_task: asyncio.Task[None] | None = None
239+ self._task_group: PersistentTaskGroup = PersistentTaskGroup(
240+ unique_id=self._unique_id, task_creator=task_creator
241+ )
211242
212243 @property
213244 @override
@@ -216,9 +247,22 @@ def unique_id(self) -> str:
216247 return self._unique_id
217248
218249 @property
219- def tasks(self) -> collections.abc.Set[asyncio.Task[Any]]:
220- """The set of running tasks spawned by this service."""
221- return self._tasks
250+ def task_group(self) -> PersistentTaskGroup:
251+ """The task group managing the tasks of this service."""
252+ return self._task_group
253+
254+ @abc.abstractmethod
255+ async def main(self) -> None:
256+ """Execute the service logic."""
257+
258+ @override
259+ def start(self) -> None:
260+ """Start this service."""
261+ if self.is_running:
262+ return
263+ self._main_task = self._task_group.task_creator.create_task(
264+ self.main(), name=str(self)
265+ )
222266
223267 @property
224268 @override
@@ -227,7 +271,7 @@ def is_running(self) -> bool:
227271
228272 A service is considered running when at least one task is running.
229273 """
230- return any(not task.done() for task in self._tasks )
274+ return self._main_task is not None and not self._main_task.done( )
231275
232276 def create_task(
233277 self,
@@ -242,8 +286,8 @@ def create_task(
242286 A reference to the task will be held by the service, so there is no need to save
243287 the task object.
244288
245- Tasks can be retrieved via the
246- [`tasks `][frequenz.core.asyncio.ServiceBase.tasks] property .
289+ Tasks are created using the
290+ [`task_group `][frequenz.core.asyncio.ServiceBase.task_group] .
247291
248292 Managed tasks always have a `name` including information about the service
249293 itself. If you need to retrieve the final name of the task you can always do so
@@ -268,24 +312,9 @@ def create_task(
268312 """
269313 if not name:
270314 name = hex(id(coro))[2:]
271- task = self._task_creator .create_task(
272- coro, name=f"{self}:{name}", context=context
315+ return self._task_group .create_task(
316+ coro, name=f"{self}:{name}", context=context, log_exception=log_exception
273317 )
274- self._tasks.add(task)
275- task.add_done_callback(self._tasks.discard)
276-
277- if log_exception:
278-
279- def _log_exception(task: asyncio.Task[TaskReturnT]) -> None:
280- try:
281- task.result()
282- except asyncio.CancelledError:
283- pass
284- except BaseException: # pylint: disable=broad-except
285- _logger.exception("%s: Task %r raised an exception", self, task)
286-
287- task.add_done_callback(_log_exception)
288- return task
289318
290319 @override
291320 def cancel(self, msg: str | None = None) -> None:
@@ -294,8 +323,9 @@ def cancel(self, msg: str | None = None) -> None:
294323 Args:
295324 msg: The message to be passed to the tasks being cancelled.
296325 """
297- for task in self._tasks:
298- task.cancel(msg)
326+ if self._main_task is not None:
327+ self._main_task.cancel(msg)
328+ self._task_group.cancel(msg)
299329
300330 @override
301331 async def stop(self, msg: str | None = None) -> None:
@@ -311,8 +341,6 @@ async def stop(self, msg: str | None = None) -> None:
311341 BaseExceptionGroup: If any of the tasks spawned by this service raised an
312342 exception.
313343 """
314- if not self._tasks:
315- return
316344 self.cancel(msg)
317345 try:
318346 await self
@@ -369,28 +397,21 @@ async def _wait(self) -> None:
369397 exception (`CancelError` is not considered an error and not returned in
370398 the exception group).
371399 """
372- # We need to account for tasks that were created between when we started
373- # awaiting and we finished awaiting.
374- while self._tasks:
375- done, pending = await asyncio.wait(self._tasks)
376- assert not pending
377-
378- # We remove the done tasks, but there might be new ones created after we
379- # started waiting.
380- self._tasks = self._tasks - done
381-
382- exceptions: list[BaseException] = []
383- for task in done:
384- try:
385- # This will raise a CancelledError if the task was cancelled or any
386- # other exception if the task raised one.
387- _ = task.result()
388- except BaseException as error: # pylint: disable=broad-except
389- exceptions.append(error)
390- if exceptions:
391- raise BaseExceptionGroup(
392- f"Error while stopping service {self}", exceptions
393- )
400+ exceptions: list[BaseException] = []
401+
402+ if self._main_task is not None:
403+ try:
404+ await self._main_task
405+ except BaseException as error: # pylint: disable=broad-except
406+ exceptions.append(error)
407+
408+ try:
409+ await self._task_group
410+ except BaseExceptionGroup as exc_group:
411+ exceptions.append(exc_group)
412+
413+ if exceptions:
414+ raise BaseExceptionGroup(f"Error while stopping {self}", exceptions)
394415
395416 @override
396417 def __await__(self) -> collections.abc.Generator[None, None, None]:
@@ -416,7 +437,13 @@ def __repr__(self) -> str:
416437 Returns:
417438 A string representation of this instance.
418439 """
419- return f"{type(self).__name__}<{self._unique_id} tasks={self._tasks!r}>"
440+ details = "main"
441+ if not self.is_running:
442+ details += " not"
443+ details += " running"
444+ if self._task_group.is_running:
445+ details += f", {len(self._task_group.tasks)} extra tasks"
446+ return f"{type(self).__name__}<{self._unique_id} {details}>"
420447
421448 def __str__(self) -> str:
422449 """Return a string representation of this instance.
0 commit comments