diff --git a/src/frequenz/dispatch/_managing_actor.py b/src/frequenz/dispatch/_managing_actor.py index 6ed4e57..e14c85f 100644 --- a/src/frequenz/dispatch/_managing_actor.py +++ b/src/frequenz/dispatch/_managing_actor.py @@ -4,12 +4,15 @@ """Helper class to manage actors based on dispatches.""" import logging +from abc import abstractmethod +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Set +from typing import Any from frequenz.channels import Receiver, Sender from frequenz.client.dispatch.types import TargetComponents from frequenz.sdk.actor import Actor +from typing_extensions import override from ._dispatch import Dispatch @@ -38,29 +41,62 @@ class DispatchManagingActor(Actor): ```python import os import asyncio - from frequenz.dispatch import Dispatcher, DispatchManagingActor, DispatchUpdate + from typing import override + from frequenz.dispatch import Dispatcher, DispatchManagingActor, DispatchUpdate, DispatchableActor from frequenz.client.dispatch.types import TargetComponents from frequenz.client.common.microgrid.components import ComponentCategory - - from frequenz.channels import Receiver, Broadcast + from frequenz.channels import Receiver, Broadcast, select, selected_from + from frequenz.sdk.actor import Actor, run class MyActor(Actor): - def __init__(self, updates_channel: Receiver[DispatchUpdate]): - super().__init__() - self._updates_channel = updates_channel - self._dry_run: bool - self._options : dict[str, Any] - + def __init__( + self, + *, + name: str | None = None, + ) -> None: + super().__init__(name=name) + self._dispatch_updates_receiver: Receiver[DispatchUpdate] | None = None + self._dry_run: bool = False + self._options: dict[str, Any] = {} + + @classmethod + def new_with_dispatch( + cls, + initial_dispatch: DispatchUpdate, + dispatch_updates_receiver: Receiver[DispatchUpdate], + *, + name: str | None = None, + ) -> Self: + self = cls(name=name) + self._dispatch_updates_receiver = dispatch_updates_receiver + self._update_dispatch_information(initial_dispatch) + return self + + @override async def _run(self) -> None: - while True: - update = await self._updates_channel.receive() - print("Received update:", update) + other_recv: Receiver[Any] = ... - self.set_components(update.components) - self._dry_run = update.dry_run - self._options = update.options - - def set_components(self, components: TargetComponents) -> None: + if self._dispatch_updates_receiver is None: + async for msg in other: + # do stuff + ... + else: + await self._run_with_dispatch(other_recv) + + async def _run_with_dispatch(self, other_recv: Receiver[Any]) -> None: + async for selected in select(self._dispatch_updates_receiver, other_recv): + if selected_from(selected, self._dispatch_updates_receiver): + self._update_dispatch_information(selected.message) + elif selected_from(selected, other_recv): + # do stuff + ... + else: + assert False, f"Unexpected selected receiver: {selected}" + + def _update_dispatch_information(self, dispatch_update: DispatchUpdate) -> None: + print("Received update:", dispatch_update) + self._dry_run = dispatch_update.dry_run + self._options = dispatch_update.options match components: case [int(), *_] as component_ids: print("Dispatch: Setting components to %s", components) @@ -84,6 +120,7 @@ async def run(): server_url=url, key=key ) + dispatcher.start() # Create update channel to receive dispatch update events pre-start and mid-run dispatch_updates_channel = Broadcast[DispatchUpdate](name="dispatch_updates_channel") @@ -94,19 +131,21 @@ async def run(): status_receiver = dispatcher.running_status_change.new_receiver() managing_actor = DispatchManagingActor( - actor=my_actor, + actor_factory=labda initial_dispatch: MyActor.new_with_dispatch( + initial_dispatch, dispatch_updates_channel.new_receiver(), + ), dispatch_type="EXAMPLE", running_status_receiver=status_receiver, updates_sender=dispatch_updates_channel.new_sender(), ) - await asyncio.gather(dispatcher.start(), managing_actor.start()) + await run(managing_actor) ``` """ def __init__( self, - actor: Actor | Set[Actor], + actor_factory: Callable[[DispatchUpdate], Actor], dispatch_type: str, running_status_receiver: Receiver[Dispatch], updates_sender: Sender[DispatchUpdate] | None = None, @@ -114,38 +153,48 @@ def __init__( """Initialize the dispatch handler. Args: - actor: A set of actors or a single actor to manage. + actor_factory: A callable that creates an actor with some initial dispatch + information. dispatch_type: The type of dispatches to handle. running_status_receiver: The receiver for dispatch running status changes. updates_sender: The sender for dispatch events """ super().__init__() self._dispatch_rx = running_status_receiver - self._actors: frozenset[Actor] = frozenset( - [actor] if isinstance(actor, Actor) else actor - ) + self._actor_factory = actor_factory + self._actor: Actor | None = None self._dispatch_type = dispatch_type self._updates_sender = updates_sender - def _start_actors(self) -> None: + async def _start_actor(self, dispatch_update: DispatchUpdate) -> None: """Start all actors.""" - for actor in self._actors: - if actor.is_running: - _logger.warning("Actor %s is already running", actor.name) - else: - actor.start() + if self._actor is None: + sent_str = "" + if self._updates_sender is not None: + sent_str = ", sent a dispatch update instead of creating a new actor" + await self._updates_sender.send(dispatch_update) + _logger.warning( + "Actor for dispatch type %r is already running%s", + self._dispatch_type, + sent_str, + ) + else: + self._actor = self._actor_factory(dispatch_update) + self._actor.start() - async def _stop_actors(self, msg: str) -> None: + async def _stop_actor(self, msg: str) -> None: """Stop all actors. Args: msg: The message to be passed to the actors being stopped. """ - for actor in self._actors: - if actor.is_running: - await actor.stop(msg) - else: - _logger.warning("Actor %s is not running", actor.name) + if self._actor is None: + _logger.warning( + "Actor for dispatch type %r is not running", self._dispatch_type + ) + else: + await self._actor.stop(msg) + self._actor = None async def _run(self) -> None: """Wait for dispatches and handle them.""" @@ -159,22 +208,40 @@ async def _handle_dispatch(self, dispatch: Dispatch) -> None: dispatch: The dispatch to handle. """ if dispatch.type != self._dispatch_type: - _logger.debug("Ignoring dispatch %s", dispatch.id) + _logger.debug( + "Ignoring dispatch %s, handled type is %r but received %r", + dispatch.id, + self._dispatch_type, + dispatch.type, + ) return if dispatch.started: - if self._updates_sender is not None: - _logger.info("Updated by dispatch %s", dispatch.id) - await self._updates_sender.send( - DispatchUpdate( - components=dispatch.target, - dry_run=dispatch.dry_run, - options=dispatch.payload, - ) + dispatch_update = DispatchUpdate( + components=dispatch.target, + dry_run=dispatch.dry_run, + options=dispatch.payload, + ) + if self._actor is None: + _logger.info( + "A new dispatch with ID %s became active for type %r and the " + "actor was not running, starting...", + dispatch.id, + self._dispatch_type, ) - - _logger.info("Started by dispatch %s", dispatch.id) - self._start_actors() + self._actor = self._actor_factory(dispatch_update) + elif self._updates_sender is not None: + _logger.info( + "A new dispatch with ID %s became active for type %r and the " + "actor was running, sending update...", + dispatch.id, + self._dispatch_type, + ) + await self._updates_sender.send(dispatch_update) else: - _logger.info("Stopped by dispatch %s", dispatch.id) - await self._stop_actors("Dispatch stopped") + _logger.info( + "Actor for dispatch type %r stopped by dispatch ID %s", + self._dispatch_type, + dispatch.id, + ) + await self._stop_actor("Dispatch stopped")