diff --git a/appdaemon/plugins/hass/hassapi.py b/appdaemon/plugins/hass/hassapi.py index 4fe558d67..6445bb997 100644 --- a/appdaemon/plugins/hass/hassapi.py +++ b/appdaemon/plugins/hass/hassapi.py @@ -28,7 +28,7 @@ # having configured the error logger to use a different name than 'Error' Logging().get_error().warning( "Importing 'hassapi' directly is deprecated and will be removed in a future version. " - "To use the Hass plugin use 'from appdaemon.plugins import hass' instead.", + "To use the Hass plugin use 'from appdaemon.plugins.hass import Hass' instead.", ) @@ -69,8 +69,7 @@ async def ping(self) -> float | None: @utils.sync_decorator async def check_for_entity(self, entity_id: str, namespace: str | None = None) -> bool: - """Uses the REST API to check if an entity exists instead of checking - AD's internal state. + """Uses the REST API to check if an entity exists instead of checking AppDaemon's internal state. Args: entity_id (str): Fully qualified id. @@ -81,10 +80,14 @@ async def check_for_entity(self, entity_id: str, namespace: str | None = None) - Returns: Bool of whether the entity exists. """ - plugin: "HassPlugin" = self.AD.plugins.get_plugin_object( - namespace or self.namespace - ) - return await plugin.check_for_entity(entity_id) + namespace = namespace if namespace is not None else self.namespace + match self.AD.plugins.get_plugin_object(namespace): + case HassPlugin() as plugin: + match await plugin.check_for_entity(entity_id): + case dict(): + return True + return False + # # Internal Helpers @@ -457,7 +460,7 @@ async def call_service( ) -> Any: ... @utils.sync_decorator - async def call_service(self, *args, **kwargs) -> Any: + async def call_service(self, *args, timeout: str | int | float | None = None, **kwargs) -> Any: """Calls a Service within AppDaemon. Services represent specific actions, and are generally registered by plugins or provided by AppDaemon itself. @@ -540,7 +543,8 @@ async def call_service(self, *args, **kwargs) -> Any: """ # We just wrap the ADAPI.call_service method here to add some additional arguments and docstrings - return await super().call_service(*args, **kwargs) + kwargs = utils.remove_literals(kwargs, (None,)) + return await super().call_service(*args, timeout=timeout, **kwargs) def get_service_info(self, service: str) -> dict | None: """Get some information about what kind of data the service expects to receive, which is helpful for debugging. @@ -762,16 +766,13 @@ async def get_history( >>> data = self.get_history(end_time = end_time, days = 5) """ - - namespace = namespace or self._namespace - if days is not None: - end_time = end_time or await self.get_now() + end_time = self.parse_datetime(end_time) if end_time is not None else await self.get_now() start_time = end_time - timedelta(days=days) - plugin = self.AD.plugins.get_plugin_object(namespace or self.namespace) - match plugin: - case HassPlugin(): + namespace = namespace if namespace is not None else self.namespace + match self.AD.plugins.get_plugin_object(namespace): + case HassPlugin() as plugin: coro = plugin.get_history( filter_entity_id=entity_id, timestamp=start_time, @@ -786,21 +787,18 @@ async def get_history( else: return await coro case _: - self.logger.warning( - "Wrong Namespace selected, as %s has no database plugin attached to it", - namespace, - ) + self.logger.warning("HASS plugin not found in namespace '%s'", namespace) @utils.sync_decorator async def get_logbook( self, entity: str | None = None, - start_time: datetime | None = None, - end_time: datetime | None = None, + start_time: datetime | str | None = None, + end_time: datetime | str | None = None, days: int | None = None, callback: Callable | None = None, namespace: str | None = None, - ) -> list[dict[str, str | datetime]]: + ) -> list[dict[str, str | datetime]] | None: """Gets access to the HA Database. This is a convenience function that allows accessing the HA Database. Caution must be taken when using this, as depending on the size of the @@ -838,23 +836,24 @@ async def get_logbook( """ if days is not None: - end_time = end_time or await self.get_now() + end_time = self.parse_datetime(end_time) if end_time is not None else await self.get_now() start_time = end_time - timedelta(days=days) - plugin: "HassPlugin" = self.AD.plugins.get_plugin_object( - namespace or self.namespace - ) - if plugin is not None: - coro = plugin.get_logbook( - entity=entity, - timestamp=start_time, - end_time=end_time, - ) + namespace = namespace if namespace is not None else self.namespace + match self.AD.plugins.get_plugin_object(namespace): + case HassPlugin() as plugin: + coro = plugin.get_logbook( + entity=entity, + timestamp=start_time, + end_time=end_time, + ) - if callback is not None and callable(callback): - self.create_task(coro, callback) - else: - return await coro + if callback is not None and callable(callback): + self.create_task(coro, callback) + else: + return await coro + case _: + self.logger.warning("HASS plugin not found in namespace '%s'", namespace) # Input Helpers @@ -1005,7 +1004,7 @@ async def press_button(self, button_id: str, namespace: str | None = None) -> di namespace=namespace, ) - def last_pressed(self, button_id: str, namespace: str | None = None) -> datetime: + def last_pressed(self, button_id: str, namespace: str | None = None) -> datetime | None: """Only works on entities in the input_button domain""" assert button_id.split('.')[0] == 'input_button' state = self.get_state(button_id, namespace=namespace) @@ -1017,14 +1016,17 @@ def last_pressed(self, button_id: str, namespace: str | None = None) -> datetime case _: self.logger.warning(f'Unknown time: {state}') - def time_since_last_press(self, button_id: str, namespace: str | None = None) -> timedelta: + def time_since_last_press(self, button_id: str, namespace: str | None = None) -> timedelta | None: """Only works on entities in the input_button domain""" - return self.get_now() - self.last_pressed(button_id, namespace) + match self.last_pressed(button_id, namespace): + case datetime() as dt: + return self.get_now() - dt + case _: + self.logger.warning("Unknown last pressed time for %s", button_id) # # Notifications # - @utils.sync_decorator async def notify( self, @@ -1068,27 +1070,29 @@ async def notify( ) @utils.sync_decorator - async def persistent_notification(self, message: str, title=None, id=None) -> None: - kwargs = {"message": message} + async def persistent_notification(self, message: str, title: str | None = None, id: int | None = None) -> None: + kwargs: dict[str, Any] = {"message": message} if title is not None: kwargs["title"] = title if id is not None: kwargs["notification_id"] = id await self.call_service("persistent_notification/create", **kwargs) - @overload def notify_android( self, device: str, - tag: str, - title: str, - message: str, - target: str, - **data - ) -> dict: ... - - def notify_android(self, device: str, tag: str = 'appdaemon', **kwargs) -> dict: + tag: str = 'appdaemon', + title: str | None = None, + message: str | None = None, + target: str | None = None, + **kwargs: Any, + ) -> dict: """Convenience method for quickly creating mobile Android notifications""" + kwargs.update({ + 'title': title, + 'message': message, + 'target': target, + }) return self._notify_mobile_app(device, AndroidData, tag, **kwargs) def notify_ios(self, device: str, tag: str = 'appdaemon', **kwargs) -> dict: @@ -1098,20 +1102,24 @@ def notify_ios(self, device: str, tag: str = 'appdaemon', **kwargs) -> dict: def _notify_mobile_app( self, device: str, - model: str | Type[NotificationData], + type_: str | Type[NotificationData], tag: str = 'appdaemon', **kwargs ) -> dict: - match model: + match type_: case NotificationData(): pass case 'android': model = AndroidData case 'iOS' | 'ios': model = iOSData + case _: + raise ValueError(f'Unknown model type: {type_}') model = model.model_validate(kwargs) - model.data.tag = model.data.tag or tag # Fills in the tag if it's blank + if model.data is not None: + # Fills in the tag if it's blank + model.data.tag = model.data.tag or tag return self.call_service( service=f'notify/mobile_app_{device}', **model.model_dump(mode='json', exclude_none=True, by_alias=True) @@ -1134,7 +1142,8 @@ def android_tts( tts_text (str): String of text to translate into speech media_stream (optional): Defaults to ``music_stream``. critical (bool, optional): Defaults to False. If set to ``True``, the notification will use the correct - settings to have the TTS at the maximum possible volume. For more information see `Critical Notifications `_ + settings to have the TTS at the maximum possible volume. For more information see + `Critical Notifications `_ """ return self.call_service( **AndroidNotification.tts(device, tts_text, media_stream, critical).to_service_call() @@ -1145,65 +1154,159 @@ def listen_notification_action(self, callback: Callable, action: str) -> str: # Backup/Restore - @overload - def backup_full( + @utils.sync_decorator + async def backup_full( self, name: str | None = None, password: str | None = None, - compressed: bool = True, + compressed: bool | None = None, location: str | None = None, - homeassistant_exclude_database: bool = False, - timeout: int | float = 30 # Used by sync_decorator - ): ... + homeassistant_exclude_database: bool | None = None, + timeout: str | int | float = 30, # Used by sync_decorator + hass_timeout: str | int | float = 10, + ) -> dict: + """Create a full backup. - @utils.sync_decorator - async def backup_full(self, name=None, timeout: int | float = 30, **kwargs) -> dict: - # https://www.home-assistant.io/integrations/hassio/#action-hassiobackup_full - return await self.call_service("hassio/backup_full", name=name, **kwargs) + Action `hassio.backup_full `_ - @overload + Args: + name (str, optional): By default, the current date and time are used in your local time, which you have set in your general settings. + password (str, optional): Optional password for backup. + compressed (bool, optional): False to create uncompressed backups. + location (str, optional): Alternate backup location instead of using the default location for backups. + homeassistant_exclude_database (bool, optional): Exclude the Home Assistant database file from backup. + timeout (str | int | float, optional): Timeout for the app thread to wait for a response from the main + thread. + hass_timeout (str | int | float, optional): Timeout for AppDaemon waiting on a response from Home Assistant + to respond to the backup request. Cannot be set lower than the timeout value. + + Returns: + dict: Response from the backup service. + """ + return await self.call_service( + "hassio/backup_full", + name=name, + password=password, + compressed=compressed, + location=location, + homeassistant_exclude_database=homeassistant_exclude_database, + hass_timeout=max(timeout, hass_timeout), + ) + + @utils.sync_decorator async def backup_partial( self, - addons: Iterable[str] = None, - folders: Iterable[str] = None, - name: str = None, - password: str = None, - compressed: bool = True, - location: str = None, - homeassistant: bool = False, - homeassistant_exclude_database: bool = False, - timeout: int | float = 30 # Used by sync_decorator - ): ... + addons: Iterable[str] | None = None, + folders: Iterable[str] | None = None, + name: str | None = None, + password: str | None = None, + compressed: bool | None = None, + location: str | None = None, + homeassistant: bool | None = None, + homeassistant_exclude_database: bool | None = None, + timeout: str | int | float = 30, # Used by sync_decorator + hass_timeout: str | int | float = 10, + ) -> dict: + """Create a partial backup. - @utils.sync_decorator - async def backup_partial(self, name=None, timeout: int | float = 30, **kwargs) -> dict: - # https://www.home-assistant.io/integrations/hassio/#action-hassiobackup_partial - return await self.call_service("hassio/backup_partial", name=name, **kwargs) + Action `hassio.backup_partial `_ + + Args: + addons (Iterable[str], optional): List of add-on slugs to backup. + folders (Iterable[str], optional): List of directories to backup. + name (str, optional): Name of the backup file. Default is the current date and time in the user's local time. + password (str, optional): Optional password for backup. + compressed (bool, optional): False to create uncompressed backups. Defaults to True. + location (str, optional): Alternate backup location instead of using the default location for backups. + homeassistant (bool, optional): Include Home Assistant and associated config in backup. Defaults to False. + homeassistant_exclude_database (bool, optional): Exclude the Home Assistant database file from backup. + Defaults to False. + timeout (str | int | float, optional): Timeout for the app thread to wait for a response from the main + thread. + hass_timeout (str | int | float, optional): Timeout for AppDaemon waiting on a response from Home Assistant + to respond to the backup request. Cannot be set lower than the timeout value. + + Returns: + dict: Response from the backup service. + """ + return await self.call_service( + "hassio/backup_partial", + name=name, + addons=addons, + folders=folders, + password=password, + compressed=compressed, + location=location, + homeassistant=homeassistant, + homeassistant_exclude_database=homeassistant_exclude_database, + hass_timeout=max(timeout, hass_timeout), + ) @utils.sync_decorator async def restore_full( self, slug: str, password: str | None = None, - timeout: int | float = 30 # Used by sync_decorator + timeout: str | int | float = 30, # Used by sync_decorator + hass_timeout: str | int | float = 10, ) -> dict: - # https://www.home-assistant.io/integrations/hassio/#action-hassiorestore_full - return await self.call_service("hassio/restore_full", slug=slug, password=password) + """Restore from full backup. - @overload - async def restore_parial( + Action `hassio.restore_full `_ + + Args: + slug (str): Slug of backup to restore from. + password (str, optional): Optional password for backup. + timeout (str | int | float, optional): Timeout for the app thread to wait for a response from the main + thread. + hass_timeout (str | int | float, optional): Timeout for AppDaemon waiting on a response from Home Assistant + to respond to the backup request. Cannot be set lower than the timeout value. + """ + return await self.call_service( + "hassio/restore_full", + slug=slug, + password=password, + hass_timeout=max(timeout, hass_timeout), + ) + + @utils.sync_decorator + async def restore_partial( self, slug: str, - homeassistant: bool = False, - addons: Iterable[str] = None, - folders: Iterable[str] = None, - password: str = None, - timeout: int | float = 30 # Used by sync_decorator - ): ... + homeassistant: bool | None = None, + addons: Iterable[str] | None = None, + folders: Iterable[str] | None = None, + password: str | None = None, + timeout: str | int | float = 30, # Used by sync_decorator + hass_timeout: str | int | float = 10, + ) -> dict: + """Restore from partial backup. + + Action `hassio.restore_partial `_ - async def restore_parial(self, slug: str, timeout: int | float = 30, **kwargs) -> dict: - # https://www.home-assistant.io/integrations/hassio/#action-hassiorestore_partial - return await self.call_service("hassio/restore_parial", slug=slug, **kwargs) + Args: + slug (str): Slug of backup to restore from. + homeassistant (bool, optional): Whether to restore Home Assistant, true or false. Defaults to False. + addons (Iterable[str], optional): List of add-on slugs to restore. + folders (Iterable[str], optional): List of directories to restore. + password (str, optional): Optional password for backup. + timeout (str | int | float, optional): Timeout for the app thread to wait for a response from the main + thread. + hass_timeout (str | int | float, optional): Timeout for AppDaemon waiting on a response from Home Assistant + to respond to the backup request. Cannot be set lower than the timeout value. + + Returns: + dict: Response from the restore service. + """ + return await self.call_service( + "hassio/restore_partial", + slug=slug, + homeassistant=homeassistant, + addons=addons, + folders=folders, + password=password, + hass_timeout=max(timeout, hass_timeout), + ) # Media @@ -1251,10 +1354,10 @@ def get_calendar_events( self, entity_id: str = "calendar.localcalendar", days: int = 1, - hours: int = None, - minutes: int = None, + hours: int | None = None, + minutes: int | None = None, namespace: str | None = None - ) -> list[dict[str, str | datetime]]: + ) -> list[dict[str, str | datetime]] | None: """ Retrieve calendar events for a specified entity within a given number of days. @@ -1292,14 +1395,17 @@ def get_calendar_events( entity_id=entity_id, duration=duration, ) - if isinstance(res, dict) and res['success']: - return [ - { - k: datetime.fromisoformat(v) if k in ('start', 'end') else v - for k, v in event.items() - } - for event in res['result']['response'][entity_id]['events'] - ] + match res: + case {"success": True, "result": {"response": resp}}: + return [ + { + k: datetime.fromisoformat(v) if k in ('start', 'end') else v + for k, v in event.items() + } + for event in resp[entity_id]['events'] + ] + case _: + self.logger.error("Failed to get calendar events for '%s'", entity_id) # Scripts @@ -1340,8 +1446,8 @@ def run_script( service = f'{domain}/{script_name}' service_data = kwargs + namespace = namespace if namespace is not None else self.namespace try: - namespace = namespace or self.namespace return self.call_service( service, namespace, entity_id=entity_id, @@ -1385,20 +1491,22 @@ async def render_template(self, template: str, namespace: str | None = None, **k hello bob """ - plugin: "HassPlugin" = self.AD.plugins.get_plugin_object( - namespace or self.namespace - ) - result = await plugin.render_template(self.namespace, template, **kwargs) - try: - return literal_eval(result) - except (SyntaxError, ValueError): - return result - - def _template_command(self, command: str, *args: str) -> str | list[str]: + namespace = namespace if namespace is not None else self.namespace + match self.AD.plugins.get_plugin_object(namespace): + case HassPlugin() as plugin: + result = await plugin.render_template(self.namespace, template, **kwargs) + if result is not None: + try: + return literal_eval(result) + except (SyntaxError, ValueError): + return result + + def _template_command(self, command: str, *args: str) -> Any: """Internal AppDaemon function to format calling a single template command correctly.""" if len(args) == 0: return self.render_template(f'{{{{ {command}() }}}}') else: + args = tuple(a for a in args if a is not None) assert all(isinstance(i, str) for i in args), f"All inputs must be strings, got {args}" arg_str = ', '.join(f"'{i}'" for i in args) cmd_str = f'{{{{ {command}({arg_str}) }}}}' @@ -1439,7 +1547,7 @@ def is_device_attr(self, device_or_entity_id: str, attr_name: str, attr_value: s See `device functions `_ for more information. """ - return self._template_command('is_device_attr', device_or_entity_id, attr_name, attr_value) + return self._template_command('is_device_attr', device_or_entity_id, attr_name, str(attr_value)) def device_id(self, entity_id: str) -> str: """Get the device ID for a given entity ID or device name. @@ -1505,13 +1613,13 @@ def integration_entities(self, integration: str) -> list[str]: # Labels # https://www.home-assistant.io/docs/configuration/templating/#labels - def labels(self, input: str = None) -> list[str]: + def labels(self) -> list[str]: """Get the full list of label IDs, or those for a given area ID, device ID, or entity ID. See `label functions `_ for more information. """ - return self._template_command('labels', input) + return self._template_command('labels') def label_id(self, lookup_value: str) -> str: """Get the label ID for a given label name. diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index 8e6aeadd4..ec51e3a94 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -3,17 +3,18 @@ """ import asyncio -import datetime import functools import json import ssl +from collections.abc import AsyncGenerator, Iterable from copy import deepcopy from dataclasses import dataclass, field +from datetime import datetime, timedelta from time import perf_counter from typing import Any, Literal, Optional import aiohttp -from aiohttp import ClientResponse, WSMsgType +from aiohttp import ClientResponse, ClientResponseError, RequestInfo, WSMsgType from pydantic import BaseModel import appdaemon.utils as utils @@ -107,55 +108,72 @@ async def stop(self): self.logger.debug("aiohttp session closed for '%s'", self.name) def create_session(self) -> aiohttp.ClientSession: - """Handles creating an ``aiohttp.ClientSession`` with the cert information from the plugin config - and the authorization headers for the REST API. + """Handles creating an :py:class:`~aiohttp.ClientSession` with the cert information from the plugin config + and the authorization headers for the `REST API `_. """ if self.config.cert_path is not None: ssl_context = ssl.create_default_context(capath=self.config.cert_path) conn = aiohttp.TCPConnector(ssl_context=ssl_context, verify_ssl=self.config.cert_verify) else: conn = aiohttp.TCPConnector(ssl=False) + + connect_timeout_secs = self.config.connect_timeout.total_seconds() return aiohttp.ClientSession( connector=conn, headers=self.config.auth_headers, json_serialize=utils.convert_json, - conn_timeout=self.config.connect_timeout.total_seconds(), + timeout=aiohttp.ClientTimeout( + connect=connect_timeout_secs, + sock_connect=connect_timeout_secs, + ) ) - async def websocket_msg_factory(self): + async def websocket_msg_factory(self) -> AsyncGenerator[aiohttp.WSMessage]: """Async generator that yields websocket messages. - Handles creating the connection based on the HASSConfig and updates the performance counters + Uses :py:meth:`~HassPlugin.create_session` and :py:meth:`~aiohttp.ClientSession.ws_connect` to connect to Home + Assistant. + + See the :py:ref:`aiohttp websockets documentation ` for more information. + + Yields: + aiohttp.WSMessage: Incoming messages on the websocket connection """ self.start = perf_counter() async with self.create_session() as self.session: try: async with self.session.ws_connect(self.config.websocket_url) as self.ws: - self.id = 0 async for msg in self.ws: - self.update_perf(bytes_recv=len(msg.data), updates_recv=1) + self.updates_recv += 1 + self.bytes_recv += len(msg.data) yield msg finally: self.connect_event.clear() async def match_ws_msg(self, msg: aiohttp.WSMessage) -> dict: - """Wraps a match/case statement for the ``msg.type``""" - msg_json = msg.json() - match msg.type: - case WSMsgType.TEXT: + """Uses a :py:ref:`match ` statement on :py:class:`~aiohttp.WSMessage`. + + Uses :py:meth:`~HassPlugin.process_websocket_json` on :py:attr:`~aiohttp.WSMsgType.TEXT` messages. + """ + match msg: + case aiohttp.WSMessage(type=WSMsgType.TEXT): # create a separate task for processing messages to keep the message reading task unblocked - self.AD.loop.create_task(self.process_websocket_json(msg_json)) - case WSMsgType.ERROR: - self.logger.error("Error from aiohttp websocket: %s", msg_json) - case WSMsgType.CLOSE: + self.AD.loop.create_task(self.process_websocket_json(msg.json())) + case aiohttp.WSMessage(type=WSMsgType.ERROR): + self.logger.error("Error from aiohttp websocket: %s", msg.json()) + case aiohttp.WSMessage(type=WSMsgType.CLOSE): self.logger.debug("Received %s message", msg.type) case _: - self.logger.error("Unhandled websocket message type: %s", msg.type) - return msg_json + self.logger.warning("Unhandled websocket message type: %s", msg.type) + return msg.json() @utils.warning_decorator(error_text="Error during processing jSON", reraise=True) async def process_websocket_json(self, resp: dict[str, Any]) -> None: - """Wraps a match/case statement around the JSON received from the websocket""" + """Uses a :py:ref:`match ` statement around the JSON received from the websocket. + + It handles both authorization and routing the responses to :py:meth:`~HassPlugin.receive_event` and + :py:meth:`~HassPlugin.receive_result`. + """ match resp: case {"type": "auth_required", "ha_version": ha_version}: self.logger.info("Connected to Home Assistant %s with aiohttp websocket", ha_version) @@ -183,6 +201,7 @@ async def process_websocket_json(self, resp: dict[str, Any]) -> None: async def __post_conn__(self) -> None: """Initialization to do after getting connected to the Home Assistant websocket""" self.connect_event.set() + self.id = 0 await self.websocket_send_json(**self.config.auth_json) async def __post_auth__(self) -> None: @@ -320,7 +339,7 @@ async def receive_event(self, event: dict[str, Any]) -> None: @utils.warning_decorator(error_text="Unexpected error during websocket send") async def websocket_send_json( self, - timeout: str | int | float | datetime.timedelta | None = None, + timeout: str | int | float | timedelta | None = None, *, # Arguments after this are keyword-only silent: bool = False, **request: Any, @@ -331,7 +350,7 @@ async def websocket_send_json( The `id` parameter is handled automatically and is used to match the response to the request. Args: - timeout (str | int | float | datetime.timedelta, optional): Length of time to wait for a response from Home + timeout (str | int | float | timedelta, optional): Length of time to wait for a response from Home Assistant with a matching `id`. Defaults to the value of the `ws_timeout` setting in the plugin config. silent (bool, optional): If set to `True`, the method will not log the request or response. Defaults to `False`. @@ -340,7 +359,8 @@ async def websocket_send_json( Returns: A dict containing the response from Home Assistant. """ - request = dict(utils.clean_kwargs(**request)) + request = utils.clean_kwargs(request) + request = utils.remove_literals(request, (None,)) if not self.connect_event.is_set(): self.logger.debug("Not connected to websocket, skipping JSON send.") @@ -393,7 +413,7 @@ async def websocket_send_json( ad_status = ServiceCallStatus.TERMINATING result = {"success": False} if not silent: - self.logger.warning(f"AppDaemon started shut down while waiting for the response from the request: {request}") + self.logger.warning(f"AppDaemon cancelled waiting for the response from the request: {request}") else: ad_status = ServiceCallStatus.OK @@ -406,27 +426,20 @@ async def http_method( self, method: Literal["get", "post", "delete"], endpoint: str, - timeout: str | int | float | datetime.timedelta | None = 10, + timeout: str | int | float | timedelta | None = 10, **kwargs: Any, - ) -> str | dict[str, Any] | list[Any] | ClientResponse | None: - """ - - https://developers.home-assistant.io/docs/api/rest + ) -> str | dict[str, Any] | list[Any] | aiohttp.ClientResponseError | None: + """Wrapper for making HTTP requests to Home Assistant's + `REST API `_. Args: - typ (Literal['get', 'post', 'delete']): Type of HTTP method to use + method (Literal['get', 'post', 'delete']): HTTP method to use. endpoint (str): Home Assistant REST endpoint to use. For example '/api/states' timeout (float, optional): Timeout for the method in seconds. Defaults to 10s. - **kwargs (optional): Zero or more keyword arguments. These get used as the data - for the method, as appropriate. - - Raises: - NotImplementedError: _description_ - - Returns: - dict | None: _description_ + **kwargs (optional): Zero or more keyword arguments. These get used as the data for the method, as + appropriate. """ - kwargs = dict(utils.clean_http_kwargs(**kwargs)) + kwargs = utils.clean_http_kwargs(kwargs) url = utils.make_endpoint(self.config.ha_url, endpoint) try: @@ -451,24 +464,25 @@ async def http_method( async with http_method(url=url, timeout=client_timeout) as resp: self.logger.debug(f"HTTP {method.upper()} {resp.url}") self.update_perf(bytes_recv=resp.content_length, updates_recv=1) - match resp.status: - case 200 | 201: - if endpoint.endswith("template"): - return await resp.text() - else: - return await resp.json() - case 400 | 401 | 403 | 404 | 405: - try: - msg = (await resp.json())["message"] - except Exception: - msg = await resp.text() - self.logger.error(f"Bad response from {url}: {msg}") - case 500 | 502: - text = await resp.text() - self.logger.error("Internal server error %s: %s", url, text) - case _: - raise NotImplementedError("Unhandled error: HTTP %s", resp.status) - return resp + try: + resp.raise_for_status() + except aiohttp.ClientResponseError as cre: + self.logger.error("[%d] HTTP %s: %s %s", cre.status, method.upper(), cre.message, kwargs) + return cre + else: + match resp: + case ClientResponse( + content_type=content_type, + request_info=RequestInfo(url=url, method=str(meth)) + ): + self.logger.debug("%s success from %s", meth, url) + match content_type: + case "application/json": + return await resp.json() + case "text/plain": + return await resp.text() + case _: + self.logger.warning("Unhandled content type: %s", content_type) except asyncio.TimeoutError: self.logger.error("Timed out waiting for %s", url) except asyncio.CancelledError: @@ -528,6 +542,14 @@ async def wait_for_conditions(self, conditions: StartupConditions | None) -> Non await asyncio.wait(tasks) async def get_updates(self): + """Main function for running the HASS plugin. + + Combines :py:meth:`~HassPlugin.websocket_msg_factory` with :py:meth:`~HassPlugin.match_ws_msg` to process + websocket messages as they come in. This happens in a while loop that breaks on AppDaemon's internal stop event. + + This uses the :py:meth:`~appdaemon.utility_loop.Utility.sleep` utility method between retries if the connection + fails. + """ while not self.AD.stopping: try: async for msg in self.websocket_msg_factory(): @@ -535,8 +557,8 @@ async def get_updates(self): continue raise ValueError except Exception as exc: - self.error.error(exc) if not self.AD.stopping: + self.error.error(exc) self.logger.info("Attempting reconnection in %s", utils.format_timedelta(self.config.retry_secs)) if self.is_ready: # Will only run the first time through the loop after a failure @@ -750,7 +772,7 @@ async def fire_plugin_event( self, event: str, namespace: str, - timeout: str | int | float | datetime.timedelta | None = None, + timeout: str | int | float | timedelta | None = None, **kwargs: Any, ) -> dict[str, Any] | None: # fmt: skip # if we get a request for not our namespace something has gone very wrong @@ -786,10 +808,19 @@ async def safe_delete(self: "HassPlugin"): @utils.warning_decorator(error_text="Unexpected error while getting hass state") async def get_complete_state(self) -> dict[str, dict[str, Any]] | None: - """This method is needed for all AppDaemon plugins""" + """Required method for all AppDaemon plugins. + + Uses the ``/api/states`` endpoint of the `REST API `_ to + get an array of state objects. Each state has the following attributes: `entity_id`, `state`, `last_changed` and + `attributes`. + + The API natively returns the result as a list of dicts, but this turns the result into a single dict based on + `entity_id` to match what AppDaemon needs from this method. + """ resp = await self.websocket_send_json(type="get_states") match resp: - case {"result": hass_state, "success": True}: + case {"success": True, "result": hass_state}: + self.logger.debug(f"Received {len(hass_state):,} states") return {s["entity_id"]: s for s in hass_state} case _: return # websocket_send_json will log warnings if something happens on the AD side @@ -818,27 +849,35 @@ async def safe_set_state(self: "HassPlugin"): async def get_plugin_state( self, entity_id: str, - timeout: str | int | float | datetime.timedelta | None = 5, + timeout: str | int | float | timedelta | None = 5, ) -> dict | None: resp = await self.http_method("get", f"/api/states/{entity_id}", timeout) match resp: - case ClientResponse(): - self.logger.error("Error getting state") - case dict() | None: + case ClientResponseError(message=str(msg)): + self.logger.error("Error getting state: %s", msg) + case (dict() | None): return resp case _: raise ValueError(f"Unexpected result from get_plugin_state: {resp}") + @utils.warning_decorator(error_text="Unexpected error checking for entity") async def check_for_entity( self, entity_id: str, - timeout: str | int | float | datetime.timedelta | None = 5, + timeout: str | int | float | timedelta | None = 5, *, # Arguments after this are keyword-only local: bool = False, ) -> dict | Literal[False]: - """Try to get the state of an entity ID to see if it exists. + """Checks for the state of an entity to see if it exists. - Returns a dict of the state if the entity exists. Otherwise returns False""" + Args: + entity_id: Entity ID of the entity to check for + timeout: Timeout for the request to the REST API if local is `False` + local: If `True`, this will check for the entity in the local state instead of using the REST API. Defaults + to `False`. + + Returns: + dict | Literal[False]: dict of the state if the entity exists, otherwise `False`""" if local: resp = self.AD.state.state.get(self.namespace, {}).get(entity_id, False) else: @@ -853,14 +892,30 @@ async def check_for_entity( @utils.warning_decorator(error_text="Unexpected error getting history") async def get_history( self, - filter_entity_id: str | list[str], - timestamp: datetime.datetime | None = None, - end_time: datetime.datetime | None = None, + filter_entity_id: str | Iterable[str], + timestamp: datetime | None = None, + end_time: datetime | None = None, minimal_response: bool | None = None, no_attributes: bool | None = None, significant_changes_only: bool | None = None, ) -> list[list[dict[str, Any]]] | None: - """Used to get HA's History""" + """Returns an array of state changes using the ``/api/history/period`` endpoint of the + `REST API `_. Each object contains further details for the + entities. + + Args: + filter_entity_id (str, Iterable[str]): Filter on one or more entities. + timestamp (datetime, optional): Determines the beginning of the period. Defaults to 1 day before the time of + the request. + end_time (datetime, optional): + minimal_response (bool, optional): Only return last_changed and state for states other than the first and + last state (much faster). Defaults to `False` + no_attributes (bool, optional): Skip returning attributes from the database (much faster). + significant_changes_only (bool, optional): Only return significant state changes. + + Returns: + list[list[dict[str, Any]]]: List of history lists for each entity. + """ if isinstance(filter_entity_id, str): filter_entity_id = [filter_entity_id] filter_entity_id = ",".join(filter_entity_id) @@ -880,12 +935,8 @@ async def get_history( ) match result: - case ClientResponse(): - # This means that HA rejected the request - error_text = (await result.json()).get("message", "Unknown") - if error_text == "Invalid filter_entity_id": - error_text += f" '{filter_entity_id}'" - self.logger.error("Error getting history: %s", error_text) + case ClientResponseError(message=str(msg)): + self.logger.error("Error getting history: %s", msg) case list(): # nested comprehension to convert the datetimes for convenience return [ @@ -893,7 +944,6 @@ async def get_history( { k: ( datetime - .datetime .fromisoformat(v) .astimezone(self.AD.tz) ) if k.startswith("last_") else v @@ -910,42 +960,53 @@ async def get_history( async def get_logbook( self, entity: str | None = None, - timestamp: datetime.datetime | None = None, - end_time: datetime.datetime | None = None, - ) -> list[dict[str, str | datetime.datetime]] | None: - """Used to get HA's logbook""" + timestamp: datetime | None = None, + end_time: datetime | None = None, + ) -> list[dict[str, str | datetime]] | None: + """Returns an array of logbook entries using the ``/api/logbook/`` endpoint of the + `REST API `_ + + Args: + timestamp (datetime, optional): Determines the beginning of the period. Defaults to 1 day before the time of + the request. + entity (str, optional): Filter on one entity. + end_time (datetime, optional): Choose the end of period starting from the `timestamp` + """ endpoint = "/api/logbook" if timestamp is not None: endpoint += f"/{timestamp.isoformat()}" - if entity is not None: - assert await self.check_for_entity(entity_id=entity), f"'{entity}' does not exist" - - result = await self.http_method("get", endpoint, entity=entity, end_time=end_time) - - match result: - case ClientResponse(): - # This means that HA rejected the request - error_text = (await result.json()).get("message", "Unknown") - if error_text == "Invalid filter_entity_id": - error_text += f" '{entity}'" - self.logger.error("Error getting history: %s", error_text) + resp = await self.http_method("get", endpoint, entity=entity, end_time=end_time) + match resp: case list(): return [ { k: v if k != "when" else ( datetime - .datetime .fromisoformat(v) .astimezone(self.AD.tz) ) for k, v in entry.items() } - for entry in result + for entry in resp ] # fmt: skip + case ClientResponseError(status=500): + self.logger.error("Error getting logbook for '%s', it might not exist.", entity) + case ClientResponseError(message=str(msg)): + self.logger.error("Error getting logbook for '%s': %s", entity, msg) + case _: + self.logger.error("Unexpected error getting logbook: %s", resp) @utils.warning_decorator(error_text="Unexpected error rendering template") - async def render_template(self, namespace: str, template: str, **kwargs): + async def render_template(self, namespace: str, template: str, **kwargs) -> str | None: + """Render the template using the ``/api/template`` endpoint of the + `REST API `_. + + See the `template docs `_ for more information. + + If successful, this returns a str of the raw response. It should still be processed downstream with + :py:func:`~ast.literal_eval`, which will turn the result into its real type. + """ self.logger.debug( "render_template() namespace=%s data=%s", namespace, @@ -954,4 +1015,11 @@ async def render_template(self, namespace: str, template: str, **kwargs): # if we get a request for not our namespace something has gone very wrong assert namespace == self.namespace - return await self.http_method("post", "/api/template", template=template, **kwargs) + resp = await self.http_method("post", "/api/template", template=template, **kwargs) + match resp: + case str(): + return resp + case ClientResponseError(message=str(msg)): + self.logger.error("Error rendering template: %s", msg) + case _: + raise ValueError(f"Unexpected result from render_template: {resp}") diff --git a/appdaemon/utils.py b/appdaemon/utils.py index e6211fda6..720f440ff 100644 --- a/appdaemon/utils.py +++ b/appdaemon/utils.py @@ -16,7 +16,7 @@ import sys import threading import traceback -from collections.abc import Awaitable, Generator, Iterable, Mapping +from collections.abc import Awaitable, Generator, Iterable, Mapping, Sequence from datetime import datetime, time, timedelta, tzinfo from functools import wraps from logging import Logger @@ -58,6 +58,8 @@ OFFSET_SPLIT_REGEX = re.compile(r"\s*?[+-]\s*?") +T = TypeVar("T") + def has_offset(time_str: str) -> bool: """Check if a time string has an offset. @@ -1072,58 +1074,49 @@ def time_str(start: float, now: float | None = None) -> str: return format_timedelta((now or perf_counter()) - start) -def clean_kwargs(**kwargs: Any) -> Generator[tuple[str, Any]]: +def clean_kwargs(val: Any, *, http: bool = False) -> Any: """Recursively clean a dict of kwargs. Conversions: - - None values are removed - datetime values are converted to ISO format strings - - bool values are converted to lowercase strings - - int, float, and str values are converted to strings - - Iterable values (like lists and tuples) are converted to lists of cleaned values - Mapping values (like dicts) are converted to dicts of cleaned key-value pairs + - Iterable values (like lists and tuples) are converted to lists of cleaned values + - Other values are converted to strings """ - def _clean_value(val: bool | datetime | Any) -> str | int | float | bool: - match val: - case bool(): - return val - case str() | int() | float() | bool(): - return val - case datetime(): - return val.isoformat() - case _: - return str(val) - - for key, val in kwargs.items(): - match val: - case None: - continue - case str(): - # This case needs to be before the Iterable case because strings are iterable - yield key, val - case Mapping(): - # This case needs to be before the Iterable case because Mappings like dicts are iterable - yield key, dict(clean_kwargs(**val)) - case Iterable(): - yield key, list(map(_clean_value, val)) - case _: - yield key, _clean_value(val) + match val: + case True if http: + return "true" + case str() | int() | float() | bool() | None: + return val + case datetime(): + return val.isoformat() + case Mapping(): + return {k: clean_kwargs(v, http=http) for k, v in val.items()} + case Iterable(): + return [clean_kwargs(v, http=http) for v in val] + case _: + return str(val) + + +def remove_literals(val: Any, literal: Sequence[Any]) -> Any: + """Remove instances of literals from a nested data structure.""" + match val: + case str(): + return val + case Mapping(): + return {k: remove_literals(v, literal) for k, v in val.items() if v not in literal} + case Iterable(): + return [remove_literals(v, literal) for v in val if v not in literal] + case _: + return val -def clean_http_kwargs(**kwargs: Any) -> Generator[tuple[str, Any]]: +def clean_http_kwargs(val: Any) -> Any: """Recursively cleans the kwarg dict to prepare it for use in HTTP requests.""" - for key, val in clean_kwargs(**kwargs): - match val: - case None: - continue # filter None values - case False | "false": - # Filter out values that are False - continue - case True: - yield key, "true" - case _: - yield key, val + cleaned = clean_kwargs(val, http=True) + pruned = remove_literals(cleaned, (None, False)) + return pruned def make_endpoint(base: str, endpoint: str) -> str: diff --git a/docs/HISTORY.md b/docs/HISTORY.md index 1949214a4..c09326d7c 100644 --- a/docs/HISTORY.md +++ b/docs/HISTORY.md @@ -15,10 +15,13 @@ - Bumped versions in CI pipeline - uv version - Docker build/push version -- Improved error messages for failed connections to Home Assistant +- Improved error messages + - for failed connections to Home Assistant + - for failed HTTP requests to Home Assistant - Improved error messages for custom plugins - Parsing various timedeltas in config with `utils.parse_timedelta` - Add callback argument to Dashboard's call_service - contributed by [psolyca](https://github.com/psolyca) +- Added docstrings to `HassPlugin` methods and added it to the reference in the docs. **Fixes** diff --git a/tests/unit/test_kwarg_clean.py b/tests/unit/test_kwarg_clean.py index dc28b5e5a..7fbadeaa0 100644 --- a/tests/unit/test_kwarg_clean.py +++ b/tests/unit/test_kwarg_clean.py @@ -3,7 +3,7 @@ import pytest import pytz -from appdaemon.utils import clean_http_kwargs, clean_kwargs +from appdaemon.utils import clean_http_kwargs, clean_kwargs, remove_literals pytestmark = [ pytest.mark.ci, @@ -11,27 +11,83 @@ ] -BASE = {"a": 1, "b": 2.0, "c": "three", "d": True, "e": False, "f": datetime(2025, 9, 22, 12, 0, 0, tzinfo=pytz.utc), "g": None} +BASE = { + "a": 1, + "b": 2.0, + "c": "three", + "d": True, + "e": False, + "f": datetime(2025, 9, 22, 12, 0, 0, tzinfo=pytz.utc), + "g": None +} def test_clean_kwargs(): - cleaned = dict(clean_kwargs(**BASE)) + cleaned = clean_kwargs(BASE) + pruned = remove_literals(BASE, (None,)) assert isinstance(cleaned["f"], str) + assert cleaned["a"] == 1 + assert cleaned["b"] == 2.0 + assert cleaned["c"] == "three" assert cleaned["d"] is True assert cleaned["e"] is False - assert "g" not in cleaned + assert "g" not in pruned kwargs = deepcopy(BASE) kwargs["nested"] = deepcopy(BASE) kwargs["nested"]["extra"] = deepcopy(BASE) - cleaned = dict(clean_kwargs(**kwargs)) + cleaned = clean_kwargs(kwargs) assert isinstance(cleaned["nested"]["extra"]["f"], str) def test_clean_http_kwargs(): - cleaned = dict(clean_http_kwargs(**BASE)) + cleaned = clean_http_kwargs(BASE) assert isinstance(cleaned["f"], str) + assert cleaned["d"] == "true" assert "e" not in cleaned assert "g" not in cleaned + + +SERVICE_CALL = { + 'type': 'call_service', + 'domain': 'notify', + 'service': 'mobile_app_pixel_9a', + 'service_data': { + 'message': 'Phobos Initialized', + 'data': { + 'push': {'sound': {'name': 'Alert_Health_Haptic.caf', 'volume': 0.6, 'critical': 1}}, + 'tag': 'phobos-alert', + 'actions': [ + {'action': 'stop_alarms', 'title': 'Stop alarms'}, + {'action': 'silence', 'title': 'Silence'}, + ] + }, + } +} + + +def test_websocket_service_call_kwargs(): + cleaned = clean_kwargs(SERVICE_CALL) + match cleaned: + case { + "service_data": + {"data": + {"actions": list(actions), + "push": {"sound": {"volume": float(vol)}}, + }, + }, + }: + assert vol == 0.6 + for action in actions: + match action: + case {"action": str(), "title": str()}: + pass + case _: + assert False, "Action format incorrect" + case _: + assert False, "Action format incorrect" + + pruned = remove_literals(SERVICE_CALL, (None,)) + assert "timeout" not in pruned["service_data"]