
Commit 04ad6e6

Properly keep references and await for concurrent tasks (#984)
The data sourcing actor parallelizes sending samples to the stream senders, but it was not keeping references to those tasks or awaiting them. References were probably kept by `gather()`, but the `gather()` call itself was never awaited, so it could potentially not run at all (although it apparently did, since things were working).

This commit uses a `TaskGroup` to keep references to the tasks in `process_msg()`, and then makes each `process_msg()` call a task too. We don't use a task group at that level because we don't want to wait for one batch of messages to be sent before receiving the next one; instead, we keep sending messages in the background. We still need to clean up these tasks as soon as they are done, though. This also replaces the event previously used to wait until there were no more pending messages, since awaiting the (pending) tasks is enough synchronization.

This was discovered while looking at warnings in the tests (#982).
2 parents: 72c25a8 + 0c207a6

File tree

2 files changed: +26 -22 lines


RELEASE_NOTES.md

Lines changed: 1 addition & 0 deletions

@@ -18,3 +18,4 @@
 ## Bug Fixes

 - Fixed a typing issue that occurs in some cases when composing formulas with constants.
+- Fixed a bug where sending tasks in the data sourcing actor might have not been properly awaited.

src/frequenz/sdk/actor/_data_sourcing/microgrid_api_source.py

Lines changed: 25 additions & 22 deletions

@@ -398,32 +398,35 @@ async def _handle_data_stream(
         )
         api_data_receiver: Receiver[Any] = self.comp_data_receivers[comp_id]

-        senders_done: asyncio.Event = asyncio.Event()
-        pending_messages = 0
-
-        def process_msg(data: Any) -> None:
-            tasks = []
-            for extractor, senders in stream_senders:
-                for sender in senders:
-                    tasks.append(
-                        sender.send(
-                            Sample(data.timestamp, Quantity(extractor(data)))
-                        )
+        async def process_msg(data: Any) -> None:
+            async with asyncio.TaskGroup() as tg:
+                for extractor, senders in stream_senders:
+                    for sender in senders:
+                        sample = Sample(data.timestamp, Quantity(extractor(data)))
+                        name = f"send:ts={sample.timestamp}:cid={comp_id}"
+                        tg.create_task(sender.send(sample), name=name)
+
+        sending_tasks: set[asyncio.Task[None]] = set()
+
+        async def clean_tasks(
+            sending_tasks: set[asyncio.Task[None]],
+        ) -> set[asyncio.Task[None]]:
+            done, pending = await asyncio.wait(sending_tasks, timeout=0)
+            for task in done:
+                if error := task.exception():
+                    _logger.error(
+                        "Error while processing message in task %s",
+                        task.get_name(),
+                        exc_info=error,
                     )
-            asyncio.gather(*tasks)
-            nonlocal pending_messages
-            pending_messages -= 1
-            if pending_messages == 0:
-                senders_done.set()
+            return pending

         async for data in api_data_receiver:
-            pending_messages += 1
-            senders_done.clear()
-            process_msg(data)
-
-        while pending_messages > 0:
-            await senders_done.wait()
+            name = f"process_msg:cid={comp_id}"
+            sending_tasks.add(asyncio.create_task(process_msg(data), name=name))
+            sending_tasks = await clean_tasks(sending_tasks)

+        await asyncio.gather(*sending_tasks)
         await asyncio.gather(
             *[
                 self._registry.close_and_remove(r.get_channel_name())
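For contrast, here is a minimal sketch of the bug the removed code had (the names `work`, `buggy`, and `fixed` are illustrative, not the SDK's): calling `asyncio.gather()` without awaiting it does wrap the coroutines in scheduled tasks, which is why things appeared to work, but nothing guarantees they finish before the caller moves on or the loop shuts down. Awaiting the gather does.

```python
import asyncio

results: list[int] = []


async def work(i: int) -> None:
    await asyncio.sleep(0)
    results.append(i)


async def buggy() -> None:
    # Schedules work(1) and work(2) as tasks but never awaits them:
    # if the event loop shuts down right after, they may never finish.
    asyncio.gather(work(1), work(2))


async def fixed() -> None:
    # Awaiting the gather guarantees both tasks have completed.
    await asyncio.gather(work(1), work(2))


asyncio.run(buggy())
# No guarantee anything was appended to `results` at this point.
results.clear()
asyncio.run(fixed())
# Here both appends are guaranteed to have happened.
```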
