|
8 | 8 | from collections.abc import Callable |
9 | 9 | from dataclasses import dataclass |
10 | 10 | from datetime import timedelta |
11 | | -from typing import Any, Awaitable |
| 11 | +from typing import Any, Awaitable, cast |
12 | 12 |
|
13 | 13 | from frequenz.channels import Broadcast, Receiver, Sender, select |
14 | | -from frequenz.client.dispatch.types import TargetComponents |
| 14 | +from frequenz.client.common.microgrid.components import ComponentCategory |
| 15 | +from frequenz.client.microgrid import ComponentId |
15 | 16 | from frequenz.sdk.actor import Actor, BackgroundService |
16 | 17 |
|
17 | 18 | from ._dispatch import Dispatch |
18 | 19 |
|
19 | 20 | _logger = logging.getLogger(__name__) |
20 | 21 |
|
| 22 | +TargetComponents = list[ComponentId] | list[ComponentCategory] |
| 23 | +"""One or more target components specifying which components a dispatch targets. |
| 24 | +
|
| 25 | +It can be a list of component IDs or a list of categories. |
| 26 | +""" |
| 27 | + |
21 | 28 |
|
22 | 29 | @dataclass(frozen=True, kw_only=True) |
23 | 30 | class DispatchInfo: |
@@ -46,7 +53,6 @@ class ActorDispatcher(BackgroundService): |
46 | 53 | import asyncio |
47 | 54 | from typing import override |
48 | 55 | from frequenz.dispatch import Dispatcher, ActorDispatcher, DispatchInfo |
49 | | - from frequenz.client.dispatch.types import TargetComponents |
50 | 56 | from frequenz.client.common.microgrid.components import ComponentCategory |
51 | 57 | from frequenz.channels import Receiver, Broadcast, select, selected_from |
52 | 58 | from frequenz.sdk.actor import Actor, run |
@@ -236,10 +242,21 @@ def start(self) -> None: |
236 | 242 | """Start the background service.""" |
237 | 243 | self._tasks.add(asyncio.create_task(self._run())) |
238 | 244 |
|
| 245 | + def _get_target_components_from_dispatch( |
| 246 | + self, dispatch: Dispatch |
| 247 | + ) -> TargetComponents: |
| 248 | + if all(isinstance(comp, int) for comp in dispatch.target): |
| 249 | + # We've confirmed all elements are integers, so we can cast. |
| 250 | + int_components = cast(list[int], dispatch.target) |
| 251 | + return [ComponentId(cid) for cid in int_components] |
| 252 | + # If not all are ints, then it must be a list of ComponentCategory |
| 253 | + # based on the definition of ClientTargetComponents. |
| 254 | + return cast(list[ComponentCategory], dispatch.target) |
| 255 | + |
239 | 256 | async def _start_actor(self, dispatch: Dispatch) -> None: |
240 | 257 | """Start the actor the given dispatch refers to.""" |
241 | 258 | dispatch_update = DispatchInfo( |
242 | | - components=dispatch.target, |
| 259 | + components=self._get_target_components_from_dispatch(dispatch), |
243 | 260 | dry_run=dispatch.dry_run, |
244 | 261 | options=dispatch.payload, |
245 | 262 | _src=dispatch, |
|
0 commit comments