Skip to content

Commit bd468e7

Browse files
committed
Use a PersistentTaskGroup to implement ServiceBase
The service base class now delegates all sub-task management to a `PersistentTaskGroup`, and has a special task called `main()` that drives the service. Writing single-task services should be much easier now, as only a `main()` method needs to be implemented. For more complex, multi-task, services, the internal task group can be used to monitor sub-tasks, for example by using `task_group.as_completed()`. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent fbd6735 commit bd468e7

File tree

2 files changed

+114
-87
lines changed

2 files changed

+114
-87
lines changed

src/frequenz/core/asyncio/_service.py

Lines changed: 101 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,14 @@
88
import asyncio
99
import collections.abc
1010
import contextvars
11-
import logging
1211
from types import TracebackType
1312
from typing import Any, Self
1413

1514
from typing_extensions import override
1615

16+
from ._task_group import PersistentTaskGroup
1717
from ._util import TaskCreator, TaskReturnT
1818

19-
_logger = logging.getLogger(__name__)
20-
2119

2220
class Service(abc.ABC):
2321
"""A service running in the background.
@@ -62,14 +60,11 @@ def unique_id(self) -> str:
6260
@property
6361
@abc.abstractmethod
6462
def is_running(self) -> bool:
65-
"""Whether this service is running.
66-
67-
A service is considered running when at least one task is running.
68-
"""
63+
"""Whether this service is running."""
6964

7065
@abc.abstractmethod
7166
def cancel(self, msg: str | None = None) -> None:
72-
"""Cancel all running tasks spawned by this service.
67+
"""Cancel this service.
7368

7469
Args:
7570
msg: The message to be passed to the tasks being cancelled.
@@ -79,8 +74,7 @@ def cancel(self, msg: str | None = None) -> None:
7974
async def stop(self, msg: str | None = None) -> None: # noqa: DOC502
8075
"""Stop this service.
8176

82-
This method cancels all running tasks spawned by this service and waits for them
83-
to finish.
77+
This method cancels the service and waits for it to finish.
8478

8579
Args:
8680
msg: The message to be passed to the tasks being cancelled.
@@ -149,22 +143,19 @@ class ServiceBase(Service, abc.ABC):
149143
[`stop()`][frequenz.core.asyncio.ServiceBase.stop] method, as the base
150144
implementation does not collect any results and re-raises all exceptions.
151145

152-
Example:
146+
Example: Simple single-task example
153147
```python
154148
import datetime
155149
import asyncio
150+
from typing_extensions import override
156151

157152
class Clock(ServiceBase):
158153
def __init__(self, resolution_s: float, *, unique_id: str | None = None) -> None:
159154
super().__init__(unique_id=unique_id)
160155
self._resolution_s = resolution_s
161156

162-
def start(self) -> None:
163-
# Managed tasks are automatically saved, so there is no need to hold a
164-
# reference to them if you don't need to further interact with them.
165-
self.create_task(self._tick())
166-
167-
async def _tick(self) -> None:
157+
@override
158+
async def main(self) -> None:
168159
while True:
169160
await asyncio.sleep(self._resolution_s)
170161
print(datetime.datetime.now())
@@ -182,6 +173,49 @@ async def main() -> None:
182173

183174
asyncio.run(main())
184175
```
176+
177+
Example: Multi-tasks example
178+
```python
179+
import asyncio
180+
import datetime
181+
from typing_extensions import override
182+
183+
class MultiTaskService(ServiceBase):
184+
185+
async def _print_every(self, *, seconds: float) -> None:
186+
while True:
187+
await asyncio.sleep(seconds)
188+
print(datetime.datetime.now())
189+
190+
async def _fail_after(self, *, seconds: float) -> None:
191+
await asyncio.sleep(seconds)
192+
raise ValueError("I failed")
193+
194+
@override
195+
async def main(self) -> None:
196+
self.create_task(self._print_every(seconds=1), name="print_1")
197+
self.create_task(self._print_every(seconds=11), name="print_11")
198+
failing = self.create_task(self._fail_after(seconds=5), name=f"fail_5")
199+
200+
async for task in self.task_group.as_completed():
201+
assert task.done() # For demonstration purposes only
202+
try:
203+
task.result()
204+
except ValueError as error:
205+
if failing == task:
206+
failing = self.create_task(
207+
self._fail_after(seconds=5), name=f"fail_5"
208+
)
209+
else:
210+
raise
211+
212+
async def main() -> None:
213+
async with MultiTaskService():
214+
await asyncio.sleep(11)
215+
216+
asyncio.run(main())
217+
```
218+
185219
"""
186220

187221
def __init__(
@@ -201,13 +235,10 @@ def __init__(
201235
# [2:] is used to remove the '0x' prefix from the hex representation of the id,
202236
# as it doesn't add any uniqueness to the string.
203237
self._unique_id: str = hex(id(self))[2:] if unique_id is None else unique_id
204-
self._tasks: set[asyncio.Task[Any]] = set()
205-
self._task_creator: TaskCreator = task_creator
206-
207-
@override
208-
@abc.abstractmethod
209-
def start(self) -> None:
210-
"""Start this service."""
238+
self._main_task: asyncio.Task[None] | None = None
239+
self._task_group: PersistentTaskGroup = PersistentTaskGroup(
240+
unique_id=self._unique_id, task_creator=task_creator
241+
)
211242

212243
@property
213244
@override
@@ -216,9 +247,22 @@ def unique_id(self) -> str:
216247
return self._unique_id
217248

218249
@property
219-
def tasks(self) -> collections.abc.Set[asyncio.Task[Any]]:
220-
"""The set of running tasks spawned by this service."""
221-
return self._tasks
250+
def task_group(self) -> PersistentTaskGroup:
251+
"""The task group managing the tasks of this service."""
252+
return self._task_group
253+
254+
@abc.abstractmethod
255+
async def main(self) -> None:
256+
"""Execute the service logic."""
257+
258+
@override
259+
def start(self) -> None:
260+
"""Start this service."""
261+
if self.is_running:
262+
return
263+
self._main_task = self._task_group.task_creator.create_task(
264+
self.main(), name=str(self)
265+
)
222266

223267
@property
224268
@override
@@ -227,7 +271,7 @@ def is_running(self) -> bool:
227271

228272
A service is considered running when at least one task is running.
229273
"""
230-
return any(not task.done() for task in self._tasks)
274+
return self._main_task is not None and not self._main_task.done()
231275

232276
def create_task(
233277
self,
@@ -242,8 +286,8 @@ def create_task(
242286
A reference to the task will be held by the service, so there is no need to save
243287
the task object.
244288

245-
Tasks can be retrieved via the
246-
[`tasks`][frequenz.core.asyncio.ServiceBase.tasks] property.
289+
Tasks are created using the
290+
[`task_group`][frequenz.core.asyncio.ServiceBase.task_group].
247291

248292
Managed tasks always have a `name` including information about the service
249293
itself. If you need to retrieve the final name of the task you can always do so
@@ -268,24 +312,9 @@ def create_task(
268312
"""
269313
if not name:
270314
name = hex(id(coro))[2:]
271-
task = self._task_creator.create_task(
272-
coro, name=f"{self}:{name}", context=context
315+
return self._task_group.create_task(
316+
coro, name=f"{self}:{name}", context=context, log_exception=log_exception
273317
)
274-
self._tasks.add(task)
275-
task.add_done_callback(self._tasks.discard)
276-
277-
if log_exception:
278-
279-
def _log_exception(task: asyncio.Task[TaskReturnT]) -> None:
280-
try:
281-
task.result()
282-
except asyncio.CancelledError:
283-
pass
284-
except BaseException: # pylint: disable=broad-except
285-
_logger.exception("%s: Task %r raised an exception", self, task)
286-
287-
task.add_done_callback(_log_exception)
288-
return task
289318

290319
@override
291320
def cancel(self, msg: str | None = None) -> None:
@@ -294,8 +323,9 @@ def cancel(self, msg: str | None = None) -> None:
294323
Args:
295324
msg: The message to be passed to the tasks being cancelled.
296325
"""
297-
for task in self._tasks:
298-
task.cancel(msg)
326+
if self._main_task is not None:
327+
self._main_task.cancel(msg)
328+
self._task_group.cancel(msg)
299329

300330
@override
301331
async def stop(self, msg: str | None = None) -> None:
@@ -311,8 +341,6 @@ async def stop(self, msg: str | None = None) -> None:
311341
BaseExceptionGroup: If any of the tasks spawned by this service raised an
312342
exception.
313343
"""
314-
if not self._tasks:
315-
return
316344
self.cancel(msg)
317345
try:
318346
await self
@@ -369,28 +397,21 @@ async def _wait(self) -> None:
369397
exception (`CancelError` is not considered an error and not returned in
370398
the exception group).
371399
"""
372-
# We need to account for tasks that were created between when we started
373-
# awaiting and we finished awaiting.
374-
while self._tasks:
375-
done, pending = await asyncio.wait(self._tasks)
376-
assert not pending
377-
378-
# We remove the done tasks, but there might be new ones created after we
379-
# started waiting.
380-
self._tasks = self._tasks - done
381-
382-
exceptions: list[BaseException] = []
383-
for task in done:
384-
try:
385-
# This will raise a CancelledError if the task was cancelled or any
386-
# other exception if the task raised one.
387-
_ = task.result()
388-
except BaseException as error: # pylint: disable=broad-except
389-
exceptions.append(error)
390-
if exceptions:
391-
raise BaseExceptionGroup(
392-
f"Error while stopping service {self}", exceptions
393-
)
400+
exceptions: list[BaseException] = []
401+
402+
if self._main_task is not None:
403+
try:
404+
await self._main_task
405+
except BaseException as error: # pylint: disable=broad-except
406+
exceptions.append(error)
407+
408+
try:
409+
await self._task_group
410+
except BaseExceptionGroup as exc_group:
411+
exceptions.append(exc_group)
412+
413+
if exceptions:
414+
raise BaseExceptionGroup(f"Error while stopping {self}", exceptions)
394415

395416
@override
396417
def __await__(self) -> collections.abc.Generator[None, None, None]:
@@ -416,7 +437,13 @@ def __repr__(self) -> str:
416437
Returns:
417438
A string representation of this instance.
418439
"""
419-
return f"{type(self).__name__}<{self._unique_id} tasks={self._tasks!r}>"
440+
details = "main"
441+
if not self.is_running:
442+
details += " not"
443+
details += " running"
444+
if self._task_group.is_running:
445+
details += f", {len(self._task_group.tasks)} extra tasks"
446+
return f"{type(self).__name__}<{self._unique_id} {details}>"
420447

421448
def __str__(self) -> str:
422449
"""Return a string representation of this instance.

tests/asyncio/test_service.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import async_solipsism
1010
import pytest
11+
from typing_extensions import override
1112

1213
from frequenz.core.asyncio import ServiceBase
1314

@@ -34,33 +35,32 @@ def __init__(
3435
self._sleep = sleep
3536
self._exc = exc
3637

37-
def start(self) -> None:
38-
"""Start this service."""
39-
40-
async def nop() -> None:
41-
if self._sleep is not None:
42-
await asyncio.sleep(self._sleep)
43-
if self._exc is not None:
44-
raise self._exc
45-
46-
self._tasks.add(asyncio.create_task(nop(), name="nop"))
38+
@override
39+
async def main(self) -> None:
40+
"""Run this service."""
41+
if self._sleep is not None:
42+
await asyncio.sleep(self._sleep)
43+
if self._exc is not None:
44+
raise self._exc
4745

4846

4947
async def test_construction_defaults() -> None:
5048
"""Test the construction of a service with default arguments."""
5149
fake_service = FakeService()
5250
assert fake_service.unique_id == hex(id(fake_service))[2:]
53-
assert fake_service.tasks == set()
51+
assert fake_service.task_group.tasks == set()
5452
assert fake_service.is_running is False
5553
assert str(fake_service) == f"FakeService:{fake_service.unique_id}"
56-
assert repr(fake_service) == f"FakeService<{fake_service.unique_id} tasks=set()>"
54+
assert (
55+
repr(fake_service) == f"FakeService<{fake_service.unique_id} main not running>"
56+
)
5757

5858

5959
async def test_construction_custom() -> None:
6060
"""Test the construction of a service with a custom unique ID."""
6161
fake_service = FakeService(unique_id="test")
6262
assert fake_service.unique_id == "test"
63-
assert fake_service.tasks == set()
63+
assert fake_service.task_group.tasks == set()
6464
assert fake_service.is_running is False
6565

6666

0 commit comments

Comments
 (0)