|
3 | 3 | import threading |
4 | 4 | import traceback |
5 | 5 | import uuid |
6 | | -from collections.abc import Mapping |
| 6 | +from collections.abc import Awaitable, Callable, Mapping |
7 | 7 | from copy import copy, deepcopy |
8 | 8 | from datetime import timedelta |
9 | 9 | from enum import Enum |
10 | 10 | from logging import Logger |
11 | 11 | from pathlib import Path |
12 | | -from typing import TYPE_CHECKING, Any, Awaitable, List, Literal, Optional, Protocol, Set, overload |
| 12 | +from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, overload |
13 | 13 |
|
14 | 14 | from . import exceptions as ade |
15 | 15 | from . import utils |
@@ -51,7 +51,7 @@ class State: |
51 | 51 | name: str = "_state" |
52 | 52 | state: dict[str, dict[str, Any] | utils.PersistentDict] |
53 | 53 |
|
54 | | - app_added_namespaces: Set[str] |
| 54 | + app_added_namespaces: set[str] |
55 | 55 |
|
56 | 56 | def __init__(self, ad: "AppDaemon"): |
57 | 57 | self.AD = ad |
@@ -234,7 +234,7 @@ async def remove_persistent_namespace(self, namespace: str, state: utils.Persist |
234 | 234 | self.logger.error('Error removing namespace file %s: %s', ns_file.name, e) |
235 | 235 | continue |
236 | 236 |
|
237 | | - def list_namespaces(self) -> List[str]: |
| 237 | + def list_namespaces(self) -> list[str]: |
238 | 238 | return list(self.state.keys()) |
239 | 239 |
|
240 | 240 | def list_namespace_entities(self, namespace: str) -> list[str]: |
@@ -836,21 +836,25 @@ async def set_state( |
836 | 836 |
|
837 | 837 | plugin = self.AD.plugins.get_plugin_object(namespace) |
838 | 838 |
|
839 | | - if set_plugin_state := getattr(plugin, "set_plugin_state", False): |
| 839 | + plugin_handled = False |
| 840 | + set_plugin_state: Callable[..., Awaitable[dict[str, Any] | None]] | None |
| 841 | + if (set_plugin_state := getattr(plugin, "set_plugin_state", None)) is not None: |
840 | 842 | # We assume that the state change will come back to us via the plugin |
841 | 843 | self.logger.debug("sending event to plugin") |
842 | 844 |
|
843 | | - result = await set_plugin_state( # pyright: ignore[reportCallIssue] |
| 845 | + result = await set_plugin_state( |
844 | 846 | namespace, |
845 | 847 | entity, |
846 | 848 | state=new_state["state"], |
847 | | - attributes=new_state["attributes"] |
848 | | - ) # fmt: skip |
| 849 | + attributes=new_state["attributes"], |
| 850 | + ) |
849 | 851 | if result is not None: |
850 | 852 | if "entity_id" in result: |
851 | 853 | result.pop("entity_id") |
852 | 854 | self.state[namespace][entity] = self.parse_state(namespace, entity, **result) |
853 | | - else: |
| 855 | + plugin_handled = True |
| 856 | + |
| 857 | + if not plugin_handled: |
854 | 858 | # Set the state locally |
855 | 859 | self.state[namespace][entity] = new_state |
856 | 860 | # Fire the event locally |
|
0 commit comments