11import asyncio
2+ from collections import defaultdict
23import uuid
34from collections .abc import Iterable
5+ from dataclasses import dataclass , field
46from datetime import datetime , timedelta
57from logging import Logger
6- from typing import TYPE_CHECKING , Any , overload
8+ from typing import TYPE_CHECKING , Any
79from collections .abc import Callable
810
11+ from .state import StateCallback
912import appdaemon .utils as utils
1013from appdaemon .exceptions import TimeOutException
1114
1215if TYPE_CHECKING :
1316 from appdaemon .appdaemon import AppDaemon
1417
1518
19+ @dataclass
1620class Entity :
21+ """Dataclass to wrap the logic for interacting with a certain entity.
22+
23+ Primarily stores the namespace, app name, and entity id in order to pre-fill calls to the AppDaemon internals.
24+ """
25+
26+ logger : Logger
1727 AD : "AppDaemon"
1828 name : str
19- logger : Logger
20- entity_id : str
2129 namespace : str
30+ entity_id : str | None
31+ _async_events : dict [str , asyncio .Event ] = field (default_factory = lambda : defaultdict (asyncio .Event ))
2232 # states_attrs = EntityAttrs()
2333
24- def __init__ (self , logger : Logger , ad : "AppDaemon" , name : str , namespace : str , entity_id : str ):
25- self .AD = ad
26- self .name = name
27- self .logger = logger
28- self .entity_id = entity_id
29- self .namespace = namespace
30- self ._async_events = {}
31-
3234 def set_namespace (self , namespace : str ) -> None :
3335 """Sets a new namespace for the Entity to use from that point forward.
3436 It should be noted that when this function is used, a different entity will be referenced.
@@ -51,17 +53,14 @@ def set_namespace(self, namespace: str) -> None:
5153 """
5254 self .namespace = namespace
5355
54- @overload
56+ @utils . sync_decorator
5557 async def set_state (
5658 self ,
5759 state : Any | None ,
58- attributes : dict ,
59- replace : bool ,
60+ attributes : dict | None = None ,
61+ replace : bool = False ,
6062 ** kwargs
61- ) -> dict : ...
62-
63- @utils .sync_decorator
64- async def set_state (self , state : Any | None = None , ** kwargs ) -> dict :
63+ ) -> dict :
6564 """Updates the state of the specified entity.
6665
6766 Args:
@@ -96,14 +95,16 @@ async def set_state(self, state: Any | None = None, **kwargs) -> dict:
9695 namespace = self .namespace ,
9796 entity = self .entity_id ,
9897 state = state ,
98+ attributes = attributes ,
99+ replace = replace ,
99100 ** kwargs
100101 )
101102
102103 @utils .sync_decorator
103104 async def get_state (
104105 self ,
105- attribute : str = None ,
106- default : Any | None = None ,
106+ attribute : str | None = None ,
107+ default : Any | None = None ,
107108 copy : bool = True
108109 ) -> Any :
109110 """Gets the state of any entity within AD.
@@ -145,29 +146,29 @@ async def get_state(
145146
146147 """
147148 self .logger .debug ("get state: %s, %s from %s" , self .entity_id , self .namespace , self .name )
148- return await self .AD .state .get_state (self .name , self .namespace , self .entity_id , attribute , default , copy )
149-
150- @overload
151- async def listen_state (
152- self ,
153- callback : Callable ,
154- new : str | Callable | None = None ,
155- old : str | Callable | None = None ,
156- duration : int | None = None ,
157- attribute : str | None = None ,
158- timeout : int | None = None ,
159- immediate : bool | None = None ,
160- oneshot : bool | None = None ,
161- pin : bool | None = None ,
162- pin_thread : int | None = None ,
163- ** kwargs : Any | None
164- ) -> str | list [str ]: ...
149+ return await self .AD .state .get_state (
150+ name = self .name ,
151+ namespace = self .namespace ,
152+ entity_id = self .entity_id ,
153+ attribute = attribute ,
154+ default = default ,
155+ copy = copy
156+ )
165157
166158 @utils .sync_decorator
167159 async def listen_state (
168160 self ,
169- callback : Callable ,
170- ** kwargs : Any | None
161+ callback : StateCallback ,
162+ new : str | Callable [[Any ], bool ] | None = None ,
163+ old : str | Callable [[Any ], bool ] | None = None ,
164+ duration : str | int | float | timedelta | None = None ,
165+ attribute : str | None = None ,
166+ timeout : str | int | float | timedelta | None = None ,
167+ immediate : bool = False ,
168+ oneshot : bool = False ,
169+ pin : bool = False ,
170+ pin_thread : int | None = None ,
171+ ** kwargs : Any
171172 ) -> str :
172173 """Registers a callback to react to state changes.
173174
@@ -271,10 +272,20 @@ async def listen_state(
271272
272273 >>> self.handle = self.my_entity.listen_state(self.my_callback, new = "on", duration = 60, immediate = True)
273274 """
274- kwargs .pop ("namespace" , None )
275-
275+ kwargs = dict (
276+ new = new ,
277+ old = old ,
278+ duration = duration ,
279+ attribute = attribute ,
280+ timeout = timeout ,
281+ immediate = immediate ,
282+ oneshot = oneshot ,
283+ pin = pin ,
284+ pin_thread = pin_thread ,
285+ ** kwargs
286+ )
287+ kwargs = {k : v for k , v in kwargs .items () if v is not None }
276288 self .logger .debug ("Calling listen_state for %s, %s from %s" , self .entity_id , kwargs , self .name )
277-
278289 return await self .AD .state .add_state_callback (
279290 name = self .name ,
280291 namespace = self .namespace ,
@@ -319,18 +330,14 @@ def exists(self) -> bool:
319330 """Checks the existence of the entity in AD."""
320331 return self .AD .state .entity_exists (self .namespace , self .entity_id )
321332
322- @overload
333+ @utils . sync_decorator
323334 async def call_service (
324335 self ,
325336 service : str ,
326- namespace : str ,
327- return_result : bool ,
328- callback : Callable ,
329- ** kwargs
330- ) -> Any : ...
331-
332- @utils .sync_decorator
333- async def call_service (self , service : str , namespace : str | None = None , ** kwargs : Any | None ) -> Any :
337+ timeout : str | int | float | None = None , # Used by utils.sync_decorator
338+ callback : Callable [[Any ], Any ] | None = None ,
339+ ** data : Any ,
340+ ) -> Any :
334341 """Calls an entity supported Service within AppDaemon.
335342
336343 This function can call only services that are tied to the entity, and provide any required parameters.
@@ -355,15 +362,23 @@ async def call_service(self, service: str, namespace: str | None = None, **kwarg
355362 >>> self.my_entity.call_service("turn_on", color_name="red")
356363
357364 """
358- namespace = namespace or self .namespace
359- kwargs ["entity_id" ] = self .entity_id
365+ kwargs = dict (
366+ entity_id = self .entity_id ,
367+ ** data
368+ )
369+ kwargs = {k : v for k , v in kwargs .items () if v is not None }
360370 self .logger .debug ("call_service: %s/%s, %s" , self .domain , service , kwargs )
361- return await self .AD .services .call_service (
362- namespace = namespace ,
371+ coro = self .AD .services .call_service (
372+ namespace = self . namespace ,
363373 domain = self .domain ,
364374 service = service ,
365- data = kwargs
375+ data = data
366376 ) # fmt: skip
377+ if callback is None :
378+ return await coro
379+ else :
380+ task = self .AD .loop .create_task (coro )
381+ task .add_done_callback (lambda f : callback (f .result ()))
367382
368383 async def wait_state (
369384 self ,
@@ -379,9 +394,9 @@ async def wait_state(
379394
380395 Args:
381396 state (Any): The state to wait for, for the entity to be in before continuing
382- attribute (str): The entity's attribute to use, if not using the entity's state
383- duration (int| float): How long the state is to hold, before continuing
384- timeout (int| float): How long to wait for the state to be achieved, before timing out.
397+ attribute (str, optional ): The entity's attribute to use, if not using the entity's state
398+ duration (int, float): How long the state is to hold, before continuing
399+ timeout (int, float): How long to wait for the state to be achieved, before timing out.
385400 When it times out, a appdaemon.exceptions.TimeOutException is raised
386401
387402 Returns:
@@ -401,8 +416,7 @@ async def wait_state(
401416 """
402417
403418 wait_id = uuid .uuid4 ().hex
404- async_event = asyncio .Event ()
405- self ._async_events [wait_id ] = async_event
419+ async_event = self ._async_events [wait_id ]
406420
407421 try :
408422 handle = await self .listen_state (
@@ -416,11 +430,12 @@ async def wait_state(
416430 wait_id = wait_id ,
417431 )
418432 await asyncio .wait_for (async_event .wait (), timeout = timeout )
419-
420433 except asyncio .TimeoutError as e :
421434 await self .AD .state .cancel_state_callback (handle , self .name )
422435 self .logger .warning (f"State Wait for { self .entity_id } Timed Out" )
423436 raise TimeOutException ("The entity timed out" ) from e
437+ finally :
438+ self ._async_events .pop (wait_id )
424439
425440 async def entity_state_changed (self , * args , wait_id : str , ** kwargs ) -> None :
426441 """The entity state changed"""
0 commit comments