diff --git a/appdaemon/appdaemon.py b/appdaemon/appdaemon.py index 7d74b86c2..9fa340550 100755 --- a/appdaemon/appdaemon.py +++ b/appdaemon/appdaemon.py @@ -119,6 +119,7 @@ def __init__( self.booted = "booting" self.logger = logging.get_logger() self.logging.register_ad(self) # needs to go last to reference the config object + self._shutdown_logger = self.logging.get_child("_shutdown") self.stop_event = asyncio.Event() self.global_vars: Any = {} @@ -390,7 +391,7 @@ async def stop(self) -> None: - :meth:`Scheduler ` - :meth:`State ` """ - self.logger.info("Stopping AppDaemon") + self._shutdown_logger.info("Stopping AppDaemon") self.stopping = True # Subsystems are able to create tasks during their stop methods @@ -398,14 +399,14 @@ async def stop(self) -> None: try: await asyncio.wait_for(self.app_management.stop(), timeout=3) except asyncio.TimeoutError: - self.logger.warning("AppManagement stop timed out, continuing shutdown") + self._shutdown_logger.warning("AppManagement stop timed out, continuing shutdown") if self.thread_async is not None: self.thread_async.stop() if self.plugins is not None: try: await asyncio.wait_for(self.plugins.stop(), timeout=1) except asyncio.TimeoutError: - self.logger.warning("Timed out stopping plugins, continuing shutdown") + self._shutdown_logger.warning("Timed out stopping plugins, continuing shutdown") self.sched.stop() self.state.stop() self.threading.stop() @@ -420,7 +421,20 @@ async def stop(self) -> None: all_coro = asyncio.wait(running_tasks, return_when=asyncio.ALL_COMPLETED, timeout=3) gather_task = asyncio.create_task(all_coro, name="appdaemon_stop_tasks") gather_task.add_done_callback(lambda _: self.logger.debug("All tasks finished")) - self.logger.debug("Waiting for tasks to finish...") + self._shutdown_logger.debug("Waiting for tasks %s to finish...", len(running_tasks)) + + # These is left here for future debugging purposes + # await asyncio.sleep(2.0) + # still_running = [ + # task + # for task in asyncio.all_tasks() + # if task is not current_task and task is not gather_task and not task.done() + # ] + # self._shutdown_logger.debug("%s tasks still running after 2 seconds", len(still_running)) + # if still_running: + # for task in still_running: + # self._shutdown_logger.debug("Still running: %s", task.get_name()) + await gather_task # diff --git a/appdaemon/models/config/plugin.py b/appdaemon/models/config/plugin.py index 6e348a730..bf1580e09 100644 --- a/appdaemon/models/config/plugin.py +++ b/appdaemon/models/config/plugin.py @@ -85,7 +85,7 @@ class StartupConditions(BaseModel): event: EventStartupCondition | None = None -class HASSConfig(PluginConfig): +class HASSConfig(PluginConfig, extra="forbid"): ha_url: str = "http://supervisor/core" token: SecretStr ha_key: Annotated[SecretStr, deprecated("'ha_key' is deprecated. Please use long lived tokens instead")] | None = None @@ -101,6 +101,7 @@ class HASSConfig(PluginConfig): commtype: Annotated[str, deprecated("'commtype' is deprecated")] | None = None ws_timeout: ParsedTimedelta = timedelta(seconds=10) """Default timeout for waiting for responses from the websocket connection""" + ws_max_msg_size: int = 4 * 1024 * 1024 suppress_log_messages: bool = False services_sleep_time: ParsedTimedelta = timedelta(seconds=60) """The sleep time in the background task that updates the internal list of available services every once in a while""" diff --git a/appdaemon/plugins/hass/exceptions.py b/appdaemon/plugins/hass/exceptions.py index cd865f5ca..3da340dc5 100644 --- a/appdaemon/plugins/hass/exceptions.py +++ b/appdaemon/plugins/hass/exceptions.py @@ -41,3 +41,10 @@ def __str__(self): if self.namespace != "default": res += f" with namespace '{self.namespace}'" return res + +@dataclass +class HassConnectionError(ade.AppDaemonException): + msg: str + + def __str__(self) -> str: + return self.msg diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index ec51e3a94..119a4593b 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -6,7 +6,7 @@ import functools import json import ssl -from collections.abc import AsyncGenerator, Iterable +from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime, timedelta @@ -14,7 +14,7 @@ from typing import Any, Literal, Optional import aiohttp -from aiohttp import ClientResponse, ClientResponseError, RequestInfo, WSMsgType +from aiohttp import ClientResponse, ClientResponseError, RequestInfo, WSMsgType, WebSocketError from pydantic import BaseModel import appdaemon.utils as utils @@ -22,8 +22,8 @@ from appdaemon.models.config.plugin import HASSConfig, StartupConditions from appdaemon.plugin_management import PluginBase -from .exceptions import HAEventsSubError -from .utils import ServiceCallStatus, hass_check, looped_coro +from .exceptions import HAEventsSubError, HassConnectionError +from .utils import ServiceCallStatus, hass_check class HASSWebsocketResponse(BaseModel): @@ -81,6 +81,9 @@ class HassPlugin(PluginBase): _result_futures: dict[int, asyncio.Future] _silent_results: dict[int, bool] startup_conditions: list[StartupWaitCondition] + maintenance_tasks: list[asyncio.Task] + """List of tasks that run in the background as part of the plugin operation. These are tracked because they might + need to get cancelled during shutdown.""" start: float @@ -96,6 +99,7 @@ def __init__(self, ad: "AppDaemon", name: str, config: HASSConfig): self._result_futures = {} self._silent_results = {} self.startup_conditions = [] + self.maintenance_tasks = [] self.service_logger = self.diag.getChild("services") self.logger.info("HASS Plugin initialization complete") @@ -107,6 +111,12 @@ async def stop(self): await self.session.close() self.logger.debug("aiohttp session closed for '%s'", self.name) + def _create_maintenance_task(self, coro: Coroutine, name: str) -> asyncio.Task: + task = self.AD.loop.create_task(coro, name=name) + self.maintenance_tasks.append(task) + task.add_done_callback(lambda t: self.maintenance_tasks.remove(t)) + return task + def create_session(self) -> aiohttp.ClientSession: """Handles creating an :py:class:`~aiohttp.ClientSession` with the cert information from the plugin config and the authorization headers for the `REST API `_. @@ -142,30 +152,36 @@ async def websocket_msg_factory(self) -> AsyncGenerator[aiohttp.WSMessage]: self.start = perf_counter() async with self.create_session() as self.session: try: - async with self.session.ws_connect(self.config.websocket_url) as self.ws: + async with self.session.ws_connect( + url=self.config.websocket_url, + max_msg_size=self.config.ws_max_msg_size, + ) as self.ws: + if (exc := self.ws.exception()) is not None: + raise HassConnectionError("Failed to connect to Home Assistant websocket") from exc + async for msg in self.ws: - self.updates_recv += 1 - self.bytes_recv += len(msg.data) yield msg finally: self.connect_event.clear() - async def match_ws_msg(self, msg: aiohttp.WSMessage) -> dict: + async def match_ws_msg(self, msg: aiohttp.WSMessage) -> None: """Uses a :py:ref:`match ` statement on :py:class:`~aiohttp.WSMessage`. Uses :py:meth:`~HassPlugin.process_websocket_json` on :py:attr:`~aiohttp.WSMsgType.TEXT` messages. """ match msg: - case aiohttp.WSMessage(type=WSMsgType.TEXT): + case aiohttp.WSMessage(type=WSMsgType.TEXT, data=str(data)): # create a separate task for processing messages to keep the message reading task unblocked - self.AD.loop.create_task(self.process_websocket_json(msg.json())) - case aiohttp.WSMessage(type=WSMsgType.ERROR): - self.logger.error("Error from aiohttp websocket: %s", msg.json()) + self.updates_recv += 1 + self.bytes_recv += len(data) + # Intentionally not using self._create_maintenance_task here + self.AD.loop.create_task(self.process_websocket_json(msg.json()), name="process_ws_msg") + case aiohttp.WSMessage(type=WSMsgType.ERROR, data=WebSocketError() as err): + self.logger.error("Error from aiohttp websocket: %s", err) case aiohttp.WSMessage(type=WSMsgType.CLOSE): self.logger.debug("Received %s message", msg.type) case _: self.logger.warning("Unhandled websocket message type: %s", msg.type) - return msg.json() @utils.warning_decorator(error_text="Error during processing jSON", reraise=True) async def process_websocket_json(self, resp: dict[str, Any]) -> None: @@ -182,7 +198,7 @@ async def process_websocket_json(self, resp: dict[str, Any]) -> None: case {"type": "auth_ok", "ha_version": ha_version}: self.logger.info("Authenticated to Home Assistant %s", ha_version) # Creating a task here allows the plugin to still receive events as it waits for the startup conditions - self.AD.loop.create_task(self.__post_auth__()) + self._create_maintenance_task(self.__post_auth__(), name="post_auth") case {"type": "auth_invalid", "message": message}: self.logger.error("Failed to authenticate to Home Assistant: %s", message) await self.ws.close() @@ -218,11 +234,14 @@ async def __post_auth__(self) -> None: case _: raise HAEventsSubError(-1, f"Unknown response from subscribe_events: {res}") - config_coro = looped_coro(self.get_hass_config, self.config.config_sleep_time.total_seconds()) - self.AD.loop.create_task(config_coro(self)) - - service_coro = looped_coro(self.get_hass_services, self.config.services_sleep_time.total_seconds()) - self.AD.loop.create_task(service_coro(self)) + self._create_maintenance_task( + self.looped_coro(self.get_hass_config, self.config.config_sleep_time.total_seconds()), + name="get_hass_config loop" + ) + self._create_maintenance_task( + self.looped_coro(self.get_hass_services, self.config.services_sleep_time.total_seconds()), + name="get_hass_services loop" + ) if self.first_time: conditions = self.config.appdaemon_startup_conditions @@ -413,7 +432,7 @@ async def websocket_send_json( ad_status = ServiceCallStatus.TERMINATING result = {"success": False} if not silent: - self.logger.warning(f"AppDaemon cancelled waiting for the response from the request: {request}") + self.logger.debug(f"AppDaemon cancelled waiting for the response from the request: {request}") else: ad_status = ServiceCallStatus.OK @@ -527,14 +546,16 @@ async def wait_for_conditions(self, conditions: StartupConditions | None) -> Non ) tasks: list[asyncio.Task[Literal[True] | None]] = [ - self.AD.loop.create_task(cond.event.wait()) + self._create_maintenance_task(cond.event.wait(), name=f"startup condition: {cond}") for cond in self.startup_conditions ] # fmt: skip if delay := conditions.delay: self.logger.info(f"Adding a {delay:.0f}s delay to the {self.name} startup") - sleep = self.AD.utility.sleep(delay, timeout_ok=True) - task = self.AD.loop.create_task(sleep) + task = self._create_maintenance_task( + self.AD.utility.sleep(delay, timeout_ok=True), + name="startup delay" + ) tasks.append(task) self.logger.info(f"Waiting for {len(tasks)} startup condition tasks after {self.time_str()}") @@ -555,7 +576,7 @@ async def get_updates(self): async for msg in self.websocket_msg_factory(): await self.match_ws_msg(msg) continue - raise ValueError + raise HassConnectionError("Websocket connection lost") except Exception as exc: if not self.AD.stopping: self.error.error(exc) @@ -568,7 +589,17 @@ async def get_updates(self): # always do this block, no matter what finally: + for task in self.maintenance_tasks: + if not task.done(): + task.cancel() + if not self.AD.stopping: + for fut in self._result_futures.values(): + if not fut.done(): + fut.cancel() + self._result_futures.clear() + self._silent_results.clear() + # remove callback from getting local events await self.AD.callbacks.clear_callbacks(self.name) @@ -605,6 +636,22 @@ async def check_register_service( # self.logger.debug("Utility (currently unused)") # return None + async def looped_coro(self, coro: Callable[..., Coroutine], sleep: float): + """Run a coroutine in a loop with a sleep interval. + + This is a utility function that can be used to run a coroutine in a loop with a sleep interval. It is used + internally to run the `get_hass_config` and + """ + while not self.AD.stopping: + try: + await coro() + except asyncio.CancelledError: + pass + except Exception as e: + self.logger.error("Error in looped coroutine: %s", e) + finally: + await self.AD.utility.sleep(sleep, timeout_ok=True) + @utils.warning_decorator(error_text="Unexpected error while getting hass config") async def get_hass_config(self) -> dict[str, Any] | None: resp = await self.websocket_send_json(type="get_config") diff --git a/docs/HASS_API_REFERENCE.rst b/docs/HASS_API_REFERENCE.rst index 13d3d4a4f..ac34fa63c 100644 --- a/docs/HASS_API_REFERENCE.rst +++ b/docs/HASS_API_REFERENCE.rst @@ -82,6 +82,11 @@ This is the full list of configuration options available for the `Hass` plugin. - Timeout for waiting for Home Assistant response from the websocket API. This is the time between when a websocket message is first sent and when Home Assistant responds with some kind of acknowledgement/result. Config values are parsed with :py:func:`parse_timedelta `. Defaults to 10 seconds. + * - ``ws_max_msg_size`` + - optional + - Maximum size in bytes for incoming websocket messages. Defaults to 4MB. Increase this if you have very large + entities (e.g., many attributes) that cause messages to exceed this size. You can also allow any message size by + setting this to 0, but that may cause other unforeseen issues. * - ``suppress_log_messages`` - optional - If ``true``, suppress log messages related to :py:meth:`call_service `. diff --git a/docs/HISTORY.md b/docs/HISTORY.md index c71e7c0ba..1baa4b591 100644 --- a/docs/HISTORY.md +++ b/docs/HISTORY.md @@ -4,12 +4,13 @@ **Features** -None +- Added the ``ws_max_msg_size`` config option to the Hass plugin **Fixes** -Fix for sunrise and sunset with offsets - contributed by [ekutner](https://github.com/ekutner) -Fix for random MQTT disconnects - contributed by [Xsandor](https://github.com/Xsandor) +- Better error handling for receiving huge websocket messages in the Hass plugin +- Fix for sunrise and sunset with offsets - contributed by [ekutner](https://github.com/ekutner) +- Fix for random MQTT disconnects - contributed by [Xsandor](https://github.com/Xsandor) **Breaking Changes**