Skip to content

Commit 880f6f1

Browse files
committed
type hints and docstrings
1 parent 15d3ca7 commit 880f6f1

File tree

1 file changed

+77
-62
lines changed

1 file changed

+77
-62
lines changed

appdaemon/entity.py

Lines changed: 77 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
11
import asyncio
2+
from collections import defaultdict
23
import uuid
34
from collections.abc import Iterable
5+
from dataclasses import dataclass, field
46
from datetime import datetime, timedelta
57
from logging import Logger
6-
from typing import TYPE_CHECKING, Any, overload
8+
from typing import TYPE_CHECKING, Any
79
from collections.abc import Callable
810

11+
from .state import StateCallback
912
import appdaemon.utils as utils
1013
from appdaemon.exceptions import TimeOutException
1114

1215
if TYPE_CHECKING:
1316
from appdaemon.appdaemon import AppDaemon
1417

1518

19+
@dataclass
1620
class Entity:
21+
"""Dataclass to wrap the logic for interacting with a certain entity.
22+
23+
Primarily stores the namespace, app name, and entity id in order to pre-fill calls to the AppDaemon internals.
24+
"""
25+
26+
logger: Logger
1727
AD: "AppDaemon"
1828
name: str
19-
logger: Logger
20-
entity_id: str
2129
namespace: str
30+
entity_id: str | None
31+
_async_events: dict[str, asyncio.Event] = field(default_factory=lambda: defaultdict(asyncio.Event))
2232
# states_attrs = EntityAttrs()
2333

24-
def __init__(self, logger: Logger, ad: "AppDaemon", name: str, namespace: str, entity_id: str):
25-
self.AD = ad
26-
self.name = name
27-
self.logger = logger
28-
self.entity_id = entity_id
29-
self.namespace = namespace
30-
self._async_events = {}
31-
3234
def set_namespace(self, namespace: str) -> None:
3335
"""Sets a new namespace for the Entity to use from that point forward.
3436
It should be noted that when this function is used, a different entity will be referenced.
@@ -51,17 +53,14 @@ def set_namespace(self, namespace: str) -> None:
5153
"""
5254
self.namespace = namespace
5355

54-
@overload
56+
@utils.sync_decorator
5557
async def set_state(
5658
self,
5759
state: Any | None,
58-
attributes: dict,
59-
replace: bool,
60+
attributes: dict | None = None,
61+
replace: bool = False,
6062
**kwargs
61-
) -> dict: ...
62-
63-
@utils.sync_decorator
64-
async def set_state(self, state: Any | None = None, **kwargs) -> dict:
63+
) -> dict:
6564
"""Updates the state of the specified entity.
6665
6766
Args:
@@ -96,14 +95,16 @@ async def set_state(self, state: Any | None = None, **kwargs) -> dict:
9695
namespace=self.namespace,
9796
entity=self.entity_id,
9897
state=state,
98+
attributes=attributes,
99+
replace=replace,
99100
**kwargs
100101
)
101102

102103
@utils.sync_decorator
103104
async def get_state(
104105
self,
105-
attribute: str = None,
106-
default: Any | None= None,
106+
attribute: str | None = None,
107+
default: Any | None = None,
107108
copy: bool = True
108109
) -> Any:
109110
"""Gets the state of any entity within AD.
@@ -145,29 +146,29 @@ async def get_state(
145146
146147
"""
147148
self.logger.debug("get state: %s, %s from %s", self.entity_id, self.namespace, self.name)
148-
return await self.AD.state.get_state(self.name, self.namespace, self.entity_id, attribute, default, copy)
149-
150-
@overload
151-
async def listen_state(
152-
self,
153-
callback: Callable,
154-
new: str | Callable | None = None,
155-
old: str | Callable | None = None,
156-
duration: int | None = None,
157-
attribute: str | None = None,
158-
timeout: int | None = None,
159-
immediate: bool | None = None,
160-
oneshot: bool | None = None,
161-
pin: bool | None = None,
162-
pin_thread: int | None = None,
163-
**kwargs: Any | None
164-
) -> str | list[str]: ...
149+
return await self.AD.state.get_state(
150+
name=self.name,
151+
namespace=self.namespace,
152+
entity_id=self.entity_id,
153+
attribute=attribute,
154+
default=default,
155+
copy=copy
156+
)
165157

166158
@utils.sync_decorator
167159
async def listen_state(
168160
self,
169-
callback: Callable,
170-
**kwargs: Any | None
161+
callback: StateCallback,
162+
new: str | Callable[[Any], bool] | None = None,
163+
old: str | Callable[[Any], bool] | None = None,
164+
duration: str | int | float | timedelta | None = None,
165+
attribute: str| None = None,
166+
timeout: str | int | float | timedelta | None = None,
167+
immediate: bool = False,
168+
oneshot: bool = False,
169+
pin: bool = False,
170+
pin_thread: int | None = None,
171+
**kwargs: Any
171172
) -> str:
172173
"""Registers a callback to react to state changes.
173174
@@ -271,10 +272,20 @@ async def listen_state(
271272
272273
>>> self.handle = self.my_entity.listen_state(self.my_callback, new = "on", duration = 60, immediate = True)
273274
"""
274-
kwargs.pop("namespace", None)
275-
275+
kwargs = dict(
276+
new=new,
277+
old=old,
278+
duration=duration,
279+
attribute=attribute,
280+
timeout=timeout,
281+
immediate=immediate,
282+
oneshot=oneshot,
283+
pin=pin,
284+
pin_thread=pin_thread,
285+
**kwargs
286+
)
287+
kwargs = {k: v for k, v in kwargs.items() if v is not None}
276288
self.logger.debug("Calling listen_state for %s, %s from %s", self.entity_id, kwargs, self.name)
277-
278289
return await self.AD.state.add_state_callback(
279290
name=self.name,
280291
namespace=self.namespace,
@@ -319,18 +330,14 @@ def exists(self) -> bool:
319330
"""Checks the existence of the entity in AD."""
320331
return self.AD.state.entity_exists(self.namespace, self.entity_id)
321332

322-
@overload
333+
@utils.sync_decorator
323334
async def call_service(
324335
self,
325336
service: str,
326-
namespace: str,
327-
return_result: bool,
328-
callback: Callable,
329-
**kwargs
330-
) -> Any: ...
331-
332-
@utils.sync_decorator
333-
async def call_service(self, service: str, namespace: str | None = None, **kwargs: Any | None) -> Any:
337+
timeout: str | int | float | None = None, # Used by utils.sync_decorator
338+
callback: Callable[[Any], Any] | None = None,
339+
**data: Any,
340+
) -> Any:
334341
"""Calls an entity supported Service within AppDaemon.
335342
336343
This function can call only services that are tied to the entity, and provide any required parameters.
@@ -355,15 +362,23 @@ async def call_service(self, service: str, namespace: str | None = None, **kwarg
355362
>>> self.my_entity.call_service("turn_on", color_name="red")
356363
357364
"""
358-
namespace = namespace or self.namespace
359-
kwargs["entity_id"] = self.entity_id
365+
kwargs = dict(
366+
entity_id=self.entity_id,
367+
**data
368+
)
369+
kwargs = {k: v for k, v in kwargs.items() if v is not None}
360370
self.logger.debug("call_service: %s/%s, %s", self.domain, service, kwargs)
361-
return await self.AD.services.call_service(
362-
namespace=namespace,
371+
coro = self.AD.services.call_service(
372+
namespace=self.namespace,
363373
domain=self.domain,
364374
service=service,
365-
data=kwargs
375+
data=data
366376
) # fmt: skip
377+
if callback is None:
378+
return await coro
379+
else:
380+
task = self.AD.loop.create_task(coro)
381+
task.add_done_callback(lambda f: callback(f.result()))
367382

368383
async def wait_state(
369384
self,
@@ -379,9 +394,9 @@ async def wait_state(
379394
380395
Args:
381396
state (Any): The state to wait for, for the entity to be in before continuing
382-
attribute (str): The entity's attribute to use, if not using the entity's state
383-
duration (int|float): How long the state is to hold, before continuing
384-
timeout (int|float): How long to wait for the state to be achieved, before timing out.
397+
attribute (str, optional): The entity's attribute to use, if not using the entity's state
398+
duration (int, float): How long the state is to hold, before continuing
399+
timeout (int, float): How long to wait for the state to be achieved, before timing out.
385400
When it times out, a appdaemon.exceptions.TimeOutException is raised
386401
387402
Returns:
@@ -401,8 +416,7 @@ async def wait_state(
401416
"""
402417

403418
wait_id = uuid.uuid4().hex
404-
async_event = asyncio.Event()
405-
self._async_events[wait_id] = async_event
419+
async_event = self._async_events[wait_id]
406420

407421
try:
408422
handle = await self.listen_state(
@@ -416,11 +430,12 @@ async def wait_state(
416430
wait_id=wait_id,
417431
)
418432
await asyncio.wait_for(async_event.wait(), timeout=timeout)
419-
420433
except asyncio.TimeoutError as e:
421434
await self.AD.state.cancel_state_callback(handle, self.name)
422435
self.logger.warning(f"State Wait for {self.entity_id} Timed Out")
423436
raise TimeOutException("The entity timed out") from e
437+
finally:
438+
self._async_events.pop(wait_id)
424439

425440
async def entity_state_changed(self, *args, wait_id: str, **kwargs) -> None:
426441
"""The entity state changed"""

0 commit comments

Comments
 (0)