66import functools
77import json
88import ssl
9- from collections .abc import AsyncGenerator , Iterable
9+ from collections .abc import AsyncGenerator , Callable , Coroutine , Iterable
1010from copy import deepcopy
1111from dataclasses import dataclass , field
1212from datetime import datetime , timedelta
1313from time import perf_counter
1414from typing import Any , Literal , Optional
1515
1616import aiohttp
17- from aiohttp import ClientResponse , ClientResponseError , RequestInfo , WSMsgType
17+ from aiohttp import ClientResponse , ClientResponseError , RequestInfo , WSMsgType , WebSocketError
1818from pydantic import BaseModel
1919
2020import appdaemon .utils as utils
2121from appdaemon .appdaemon import AppDaemon
2222from appdaemon .models .config .plugin import HASSConfig , StartupConditions
2323from appdaemon .plugin_management import PluginBase
2424
25- from .exceptions import HAEventsSubError
26- from .utils import ServiceCallStatus , hass_check , looped_coro
25+ from .exceptions import HAEventsSubError , HassConnectionError
26+ from .utils import ServiceCallStatus , hass_check
2727
2828
2929class HASSWebsocketResponse (BaseModel ):
@@ -81,6 +81,9 @@ class HassPlugin(PluginBase):
8181 _result_futures : dict [int , asyncio .Future ]
8282 _silent_results : dict [int , bool ]
8383 startup_conditions : list [StartupWaitCondition ]
84+ maintenance_tasks : list [asyncio .Task ]
85+ """List of tasks that run in the background as part of the plugin operation. These are tracked because they might
86+ need to get cancelled during shutdown."""
8487
8588 start : float
8689
@@ -96,6 +99,7 @@ def __init__(self, ad: "AppDaemon", name: str, config: HASSConfig):
9699 self ._result_futures = {}
97100 self ._silent_results = {}
98101 self .startup_conditions = []
102+ self .maintenance_tasks = []
99103
100104 self .service_logger = self .diag .getChild ("services" )
101105 self .logger .info ("HASS Plugin initialization complete" )
@@ -107,15 +111,18 @@ async def stop(self):
107111 await self .session .close ()
108112 self .logger .debug ("aiohttp session closed for '%s'" , self .name )
109113
114+ def _create_maintenance_task (self , coro : Coroutine , name : str ) -> asyncio .Task :
115+ task = self .AD .loop .create_task (coro , name = name )
116+ self .maintenance_tasks .append (task )
117+ task .add_done_callback (lambda t : self .maintenance_tasks .remove (t ))
118+ return task
119+
110120 def create_session (self ) -> aiohttp .ClientSession :
111121 """Handles creating an :py:class:`~aiohttp.ClientSession` with the cert information from the plugin config
112122 and the authorization headers for the `REST API <https://developers.home-assistant.io/docs/api/rest>`_.
113123 """
114- if self .config .cert_path is not None :
115- ssl_context = ssl .create_default_context (capath = self .config .cert_path )
116- conn = aiohttp .TCPConnector (ssl_context = ssl_context , verify_ssl = self .config .cert_verify )
117- else :
118- conn = aiohttp .TCPConnector (ssl = False )
124+ ssl_context = ssl .create_default_context (capath = self .config .cert_path )
125+ conn = aiohttp .TCPConnector (ssl_context = ssl_context )
119126
120127 connect_timeout_secs = self .config .connect_timeout .total_seconds ()
121128 return aiohttp .ClientSession (
@@ -142,30 +149,36 @@ async def websocket_msg_factory(self) -> AsyncGenerator[aiohttp.WSMessage]:
142149 self .start = perf_counter ()
143150 async with self .create_session () as self .session :
144151 try :
145- async with self .session .ws_connect (self .config .websocket_url ) as self .ws :
152+ async with self .session .ws_connect (
153+ url = self .config .websocket_url ,
154+ max_msg_size = self .config .ws_max_msg_size ,
155+ ) as self .ws :
156+ if (exc := self .ws .exception ()) is not None :
157+ raise HassConnectionError ("Failed to connect to Home Assistant websocket" ) from exc
158+
146159 async for msg in self .ws :
147- self .updates_recv += 1
148- self .bytes_recv += len (msg .data )
149160 yield msg
150161 finally :
151162 self .connect_event .clear ()
152163
153- async def match_ws_msg (self , msg : aiohttp .WSMessage ) -> dict :
164+ async def match_ws_msg (self , msg : aiohttp .WSMessage ) -> None :
154165 """Uses a :py:ref:`match <class-patterns>` statement on :py:class:`~aiohttp.WSMessage`.
155166
156167 Uses :py:meth:`~HassPlugin.process_websocket_json` on :py:attr:`~aiohttp.WSMsgType.TEXT` messages.
157168 """
158169 match msg :
159- case aiohttp .WSMessage (type = WSMsgType .TEXT ):
170+ case aiohttp .WSMessage (type = WSMsgType .TEXT , data = str ( data ) ):
160171 # create a separate task for processing messages to keep the message reading task unblocked
161- self .AD .loop .create_task (self .process_websocket_json (msg .json ()))
162- case aiohttp .WSMessage (type = WSMsgType .ERROR ):
163- self .logger .error ("Error from aiohttp websocket: %s" , msg .json ())
172+ self .updates_recv += 1
173+ self .bytes_recv += len (data )
174+ # Intentionally not using self._create_maintenance_task here
175+ self .AD .loop .create_task (self .process_websocket_json (msg .json ()), name = "process_ws_msg" )
176+ case aiohttp .WSMessage (type = WSMsgType .ERROR , data = WebSocketError () as err ):
177+ self .logger .error ("Error from aiohttp websocket: %s" , err )
164178 case aiohttp .WSMessage (type = WSMsgType .CLOSE ):
165179 self .logger .debug ("Received %s message" , msg .type )
166180 case _:
167181 self .logger .warning ("Unhandled websocket message type: %s" , msg .type )
168- return msg .json ()
169182
170183 @utils .warning_decorator (error_text = "Error during processing jSON" , reraise = True )
171184 async def process_websocket_json (self , resp : dict [str , Any ]) -> None :
@@ -182,7 +195,7 @@ async def process_websocket_json(self, resp: dict[str, Any]) -> None:
182195 case {"type" : "auth_ok" , "ha_version" : ha_version }:
183196 self .logger .info ("Authenticated to Home Assistant %s" , ha_version )
184197 # Creating a task here allows the plugin to still receive events as it waits for the startup conditions
185- self .AD . loop . create_task (self .__post_auth__ ())
198+ self ._create_maintenance_task (self .__post_auth__ (), name = "post_auth" )
186199 case {"type" : "auth_invalid" , "message" : message }:
187200 self .logger .error ("Failed to authenticate to Home Assistant: %s" , message )
188201 await self .ws .close ()
@@ -218,11 +231,14 @@ async def __post_auth__(self) -> None:
218231 case _:
219232 raise HAEventsSubError (- 1 , f"Unknown response from subscribe_events: { res } " )
220233
221- config_coro = looped_coro (self .get_hass_config , self .config .config_sleep_time .total_seconds ())
222- self .AD .loop .create_task (config_coro (self ))
223-
224- service_coro = looped_coro (self .get_hass_services , self .config .services_sleep_time .total_seconds ())
225- self .AD .loop .create_task (service_coro (self ))
234+ self ._create_maintenance_task (
235+ self .looped_coro (self .get_hass_config , self .config .config_sleep_time .total_seconds ()),
236+ name = "get_hass_config loop"
237+ )
238+ self ._create_maintenance_task (
239+ self .looped_coro (self .get_hass_services , self .config .services_sleep_time .total_seconds ()),
240+ name = "get_hass_services loop"
241+ )
226242
227243 if self .first_time :
228244 conditions = self .config .appdaemon_startup_conditions
@@ -413,7 +429,7 @@ async def websocket_send_json(
413429 ad_status = ServiceCallStatus .TERMINATING
414430 result = {"success" : False }
415431 if not silent :
416- self .logger .warning (f"AppDaemon cancelled waiting for the response from the request: { request } " )
432+ self .logger .debug (f"AppDaemon cancelled waiting for the response from the request: { request } " )
417433 else :
418434 ad_status = ServiceCallStatus .OK
419435
@@ -527,14 +543,16 @@ async def wait_for_conditions(self, conditions: StartupConditions | None) -> Non
527543 )
528544
529545 tasks : list [asyncio .Task [Literal [True ] | None ]] = [
530- self .AD . loop . create_task (cond .event .wait ())
546+ self ._create_maintenance_task (cond .event .wait (), name = f"startup condition: { cond } " )
531547 for cond in self .startup_conditions
532548 ] # fmt: skip
533549
534550 if delay := conditions .delay :
535551 self .logger .info (f"Adding a { delay :.0f} s delay to the { self .name } startup" )
536- sleep = self .AD .utility .sleep (delay , timeout_ok = True )
537- task = self .AD .loop .create_task (sleep )
552+ task = self ._create_maintenance_task (
553+ self .AD .utility .sleep (delay , timeout_ok = True ),
554+ name = "startup delay"
555+ )
538556 tasks .append (task )
539557
540558 self .logger .info (f"Waiting for { len (tasks )} startup condition tasks after { self .time_str ()} " )
@@ -555,7 +573,7 @@ async def get_updates(self):
555573 async for msg in self .websocket_msg_factory ():
556574 await self .match_ws_msg (msg )
557575 continue
558- raise ValueError
576+ raise HassConnectionError ( "Websocket connection lost" )
559577 except Exception as exc :
560578 if not self .AD .stopping :
561579 self .error .error (exc )
@@ -568,7 +586,17 @@ async def get_updates(self):
568586
569587 # always do this block, no matter what
570588 finally :
589+ for task in self .maintenance_tasks :
590+ if not task .done ():
591+ task .cancel ()
592+
571593 if not self .AD .stopping :
594+ for fut in self ._result_futures .values ():
595+ if not fut .done ():
596+ fut .cancel ()
597+ self ._result_futures .clear ()
598+ self ._silent_results .clear ()
599+
572600 # remove callback from getting local events
573601 await self .AD .callbacks .clear_callbacks (self .name )
574602
@@ -605,6 +633,22 @@ async def check_register_service(
605633 # self.logger.debug("Utility (currently unused)")
606634 # return None
607635
636+ async def looped_coro (self , coro : Callable [..., Coroutine ], sleep : float ):
637+ """Run a coroutine in a loop with a sleep interval.
638+
639+ This is a utility function that can be used to run a coroutine in a loop with a sleep interval. It is used
640+ internally to run the `get_hass_config` and
641+ """
642+ while not self .AD .stopping :
643+ try :
644+ await coro ()
645+ except asyncio .CancelledError :
646+ pass
647+ except Exception as e :
648+ self .logger .error ("Error in looped coroutine: %s" , e )
649+ finally :
650+ await self .AD .utility .sleep (sleep , timeout_ok = True )
651+
608652 @utils .warning_decorator (error_text = "Unexpected error while getting hass config" )
609653 async def get_hass_config (self ) -> dict [str , Any ] | None :
610654 resp = await self .websocket_send_json (type = "get_config" )
0 commit comments