diff --git a/appdaemon/models/config/plugin.py b/appdaemon/models/config/plugin.py index b689741f4..340fec91e 100644 --- a/appdaemon/models/config/plugin.py +++ b/appdaemon/models/config/plugin.py @@ -92,7 +92,7 @@ class HASSConfig(PluginConfig): enable_started_event: bool = True """If true, the plugin will wait for the 'homeassistant_started' event before starting the plugin.""" cert_path: CoercedPath | None = None - cert_verify: bool | None = None + cert_verify: bool = True commtype: str = "WS" q_timeout: int = 30 ws_timeout: Annotated[ diff --git a/appdaemon/plugins/hass/hassapi.py b/appdaemon/plugins/hass/hassapi.py index e624ec3e3..74bc1d142 100644 --- a/appdaemon/plugins/hass/hassapi.py +++ b/appdaemon/plugins/hass/hassapi.py @@ -58,10 +58,14 @@ def __init__(self, ad: AppDaemon, config_model: "AppConfig"): self.register_constraint("constrain_input_select") @utils.sync_decorator - async def ping(self) -> float: + async def ping(self) -> float | None: """Gets the number of seconds """ if (plugin := self._plugin) is not None: - return (await plugin.ping())['ad_duration'] + match await plugin.ping(): + case {"ad_status": "OK", "ad_duration": ad_duration}: + return ad_duration + case _: + return None @utils.sync_decorator async def check_for_entity(self, entity_id: str, namespace: str | None = None) -> bool: @@ -539,16 +543,14 @@ def get_service_info(self, service: str) -> dict | None: Returns: Information about the service in a dict with the following keys: ``name``, ``description``, ``target``, and ``fields``. - """ - if (plugin := self._plugin) is not None: - domain, service_name = service.split("/", 2) - for service_def in plugin.services: - if service_def.get("domain") == domain: - if (services := service_def.get("services")) is not None: - return deepcopy(services.get(service_name)) - else: - self.logger.warning("Service info not found for domain '%s", domain) + match self._plugin: + case HassPlugin() as plugin: + domain, service_name = service.split("/", 2) + if info := plugin.services.get(domain, {}).get(service_name): + # Return a copy of the info dict to prevent accidental modification + return deepcopy(info) + self.logger.warning("Service info not found for domain '%s", domain) # Methods that use self.call_service @@ -680,7 +682,7 @@ async def get_history( significant_changes_only: bool | None = None, callback: Callable | None = None, namespace: str | None = None, - ) -> list[list[dict[str, Any]]]: + ) -> list[list[dict[str, Any]]] | None: """Gets access to the HA Database. This is a convenience function that allows accessing the HA Database, so the history state of a device can be retrieved. It allows for a level of flexibility @@ -755,31 +757,27 @@ async def get_history( end_time = end_time or await self.get_now() start_time = end_time - timedelta(days=days) - - plugin: "HassPlugin" = self.AD.plugins.get_plugin_object( - namespace or self.namespace - ) - - if plugin is not None: - coro = plugin.get_history( - filter_entity_id=entity_id, - timestamp=start_time, - end_time=end_time, - minimal_response=minimal_response, - no_attributes=no_attributes, - significant_changes_only=significant_changes_only, - ) - - if callback is not None and callable(callback): - self.create_task(coro, callback) - else: - return await coro - - else: - self.logger.warning( - "Wrong Namespace selected, as %s has no database plugin attached to it", - namespace, - ) + plugin = self.AD.plugins.get_plugin_object(namespace or self.namespace) + match plugin: + case HassPlugin(): + coro = plugin.get_history( + filter_entity_id=entity_id, + timestamp=start_time, + end_time=end_time, + minimal_response=minimal_response, + no_attributes=no_attributes, + significant_changes_only=significant_changes_only, + ) + + if callback is not None and callable(callback): + self.create_task(coro, callback) + else: + return await coro + case _: + self.logger.warning( + "Wrong Namespace selected, as %s has no database plugin attached to it", + namespace, + ) @utils.sync_decorator async def get_logbook( diff --git a/appdaemon/plugins/hass/hassplugin.py b/appdaemon/plugins/hass/hassplugin.py index 24ed932b0..0d92f8f0c 100644 --- a/appdaemon/plugins/hass/hassplugin.py +++ b/appdaemon/plugins/hass/hassplugin.py @@ -4,9 +4,9 @@ import asyncio import datetime +import functools import json import ssl -from collections.abc import Iterable from copy import deepcopy from dataclasses import dataclass, field from time import perf_counter @@ -69,7 +69,14 @@ class HassPlugin(PluginBase): ws: aiohttp.ClientWebSocketResponse """websocket dedicated for event loop""" metadata: dict[str, Any] - services: list[dict[str, Any]] + services: dict[ + str, # Domain + dict[ + str, # Service name + dict[ + str, # Field name + Any # Field information + ]]] _result_futures: dict[int, asyncio.Future] _silent_results: dict[int, bool] @@ -85,7 +92,7 @@ def __init__(self, ad: "AppDaemon", name: str, config: HASSConfig): self.id = 0 self.metadata = {} - self.services = [] + self.services = {} self._result_futures = {} self._silent_results = {} self.startup_conditions = [] @@ -107,10 +114,11 @@ def create_session(self) -> aiohttp.ClientSession: """Handles creating an ``aiohttp.ClientSession`` with the cert information from the plugin config and the authorization headers for the REST API. """ - ssl_context = None if self.config.cert_verify else False - if self.config.cert_verify and self.config.cert_path: + if self.config.cert_path is not None: ssl_context = ssl.create_default_context(capath=self.config.cert_path) - conn = aiohttp.TCPConnector(ssl=ssl_context) + conn = aiohttp.TCPConnector(ssl_context=ssl_context, verify_ssl=self.config.cert_verify) + else: + conn = aiohttp.TCPConnector(ssl=False) return aiohttp.ClientSession( connector=conn, headers=self.config.auth_headers, @@ -147,49 +155,50 @@ async def match_ws_msg(self, msg: aiohttp.WSMessage) -> dict: return msg_json @utils.warning_decorator(error_text="Error during processing jSON", reraise=True) - async def process_websocket_json(self, resp: dict): - """Wraps a match/case statement for the ``type`` key of the JSON received from the websocket""" - match resp["type"]: - case "auth_required": - self.logger.info("Connected to Home Assistant %s with aiohttp websocket", resp["ha_version"]) + async def process_websocket_json(self, resp: dict[str, Any]) -> None: + """Wraps a match/case statement around the JSON received from the websocket""" + match resp: + case {"type": "auth_required", "ha_version": ha_version}: + self.logger.info("Connected to Home Assistant %s with aiohttp websocket", ha_version) + # Use await here so that nothing else can happen until the post connection stuff is done await self.__post_conn__() - case "auth_ok": - self.logger.info("Authenticated to Home Assistant %s", resp["ha_version"]) + case {"type": "auth_ok", "ha_version": ha_version}: + self.logger.info("Authenticated to Home Assistant %s", ha_version) # Creating a task here allows the plugin to still receive events as it waits for the startup conditions self.AD.loop.create_task(self.__post_auth__()) - case "auth_invalid": - self.logger.error(f'Failed to authenticate to Home Assistant: {resp["message"]}') + case {"type": "auth_invalid", "message": message}: + self.logger.error('Failed to authenticate to Home Assistant: %s', message) await self.ws.close() - case "ping": + case {"type": "ping"}: await self.ping() - case "pong": - if future := self._result_futures.get(resp["id"]): + case {"type": "pong", "id": resp_id}: + if future := self._result_futures.get(resp_id): future.set_result(resp) - case "result": + case {"type": "result"}: await self.receive_result(resp) - case "event": - await self.receive_event(event=resp["event"]) - case _: - raise NotImplementedError(resp["type"]) + case {"type": "event", "event": event}: + await self.receive_event(event) + case {"type": type_}: + raise NotImplementedError(type_) - async def __post_conn__(self): + async def __post_conn__(self) -> None: """Initialization to do after getting connected to the Home Assistant websocket""" self.connect_event.set() - return await self.websocket_send_json(**self.config.auth_json) + await self.websocket_send_json(**self.config.auth_json) - async def __post_auth__(self): + async def __post_auth__(self) -> None: """Initialization to do after getting authenticated on the websocket""" res = await self.websocket_send_json(type="subscribe_events") match res: - case None: - raise HAEventsSubError("Unknown error in subscribe") - case dict(): - match res["success"]: - case False: - res = res["error"] - raise HAEventsSubError(f'{res["code"]}: {res["message"]}') - case "timeout": - raise HAEventsSubError("Timed out waiting for subscription acknowledgement") + case {"success": True, "ad_duration": ad_duration}: + self.logger.debug( + "Subscribed to Home Assistant events from the websocket in %s", + utils.format_timedelta(ad_duration) + ) + case {"success": False, "error": {"code": code, "message": msg}}: + raise HAEventsSubError(f'{code}: {msg}') + case _: + raise HAEventsSubError(f'Unknown response from subscribe_events: {res}') config_coro = looped_coro(self.get_hass_config, self.config.config_sleep_time) self.AD.loop.create_task(config_coro(self)) @@ -223,7 +232,7 @@ async def __post_auth__(self): self.logger.info(f"Completed initialization in {self.time_str()}") @hass_check - async def ping(self, timeout: float = 1.0): + async def ping(self, timeout: float = 1.0) -> dict[str, Any ] | None: """Method for testing response times over the websocket.""" # https://developers.home-assistant.io/docs/api/websocket/#pings-and-pongs return await self.websocket_send_json(timeout=timeout, type="ping") @@ -253,7 +262,7 @@ async def receive_result(self, resp: dict): self.logger.error(f"Invalid response success value: {resp['success']}") @utils.warning_decorator(error_text="Unexpected error during receive_event") - async def receive_event(self, event: dict): + async def receive_event(self, event: dict[str, Any]) -> None: self.logger.debug(f"Received event type: {event['event_type']}") meta_attrs = {"origin", "time_fired", "context"} @@ -269,49 +278,71 @@ async def receive_event(self, event: dict): if condition.conditions_met: self.logger.info(f'HASS startup condition met {condition}') - match typ := event["event_type"]: - case "homeassistant_started": + match event: + case {"event_type": "homeassistant_started"}: self.logger.info(f"Home Assistant fully started after {utils.time_str(self.start)}") self.ready_event.set() - # https://data.home-assistant.io/docs/events/#service_registered - case "service_registered": - data = event["data"] - await self.check_register_service(data["domain"], data["service"], silent=True) - case "call_service": - service_name = f'{event["data"]["domain"]}.{event["data"]["service"]}' - entity_id = event["data"]["service_data"].get('entity_id') - self.logger.debug(f'{service_name}, {entity_id}') - case 'entity_registry_updated': + case {"event_type": "service_registered", "data": {"domain": domain, "service": service}}: + # https://data.home-assistant.io/docs/events/#service_registered + await self.check_register_service(domain, service, silent=True) + # Everything below here is just for information/debug purposes + case { # + "event_type": "call_service", + "data": { + "domain": domain, + "service": service, + "service_data": { + "entity_id": entity_id, + } + } + }: + self.logger.debug(f'Service {domain}.{service} called with {entity_id}') + case {"event_type": "entity_registry_updated"}: pass - # https://data.home-assistant.io/docs/events/#state_changed - case "state_changed": - ... - case "mobile_app_notification_action": + case { # https://data.home-assistant.io/docs/events/#state_changed + "event_type": "state_changed", + "data": { + "entity_id": entity_id, + "new_state": {"state": new_state}, + "old_state": {"state": old_state}, + }, + }: + self.logger.debug(f'{entity_id} state changed from {old_state} to {new_state}') + case {"event_type": "mobile_app_notification_action", "data": {"action": action}}: + self.logger.debug('Mobile action: %s', action) + case {"event_type": "mobile_app_notification_cleared"}: ... - # action = event['data']['action'] - case "mobile_app_notification_cleared": + case {"event_type": "android.zone_entered"}: ... - case "android.zone_entered": - ... - case "component_loaded": - self.logger.debug(f'Loaded component: {event["data"]["component"]}') - case _: - if typ.startswith('recorder'): + case {"event_type": "component_loaded", "data": {"component": component}}: + self.logger.debug('Loaded component: %s', component) + case {"event_type": other_event}: + if other_event.startswith('recorder'): return - # ? 'entity_registry_updated' - self.logger.debug('Unrecognized event %s', typ) + self.logger.debug('Unrecognized event %s', other_event) @utils.warning_decorator(error_text="Unexpected error during websocket send") async def websocket_send_json( self, - timeout: str | int | float | None = None, + timeout: str | int | float | datetime.timedelta | None = None, + *, # Arguments after this are keyword-only silent: bool = False, - **request - ) -> dict: + **request: Any + ) -> dict[str, Any] | None: """ - Sends a json request over the websocket and gets the response. + Send a JSON request over the websocket and await the response. + + The `id` parameter is handled automatically and is used to match the response to the request. + + Args: + timeout (str | int | float | datetime.timedelta, optional): Length of time to wait for a response from Home + Assistant with a matching `id`. Defaults to the value of the `ws_timeout` setting in the plugin config. + silent (bool, optional): If set to `True`, the method will not log the request or response. Defaults to + `False`. + **request (Any): Zero or more keyword arguments that will make up JSON request. - Handles incrementing the `id` parameter and appends + Returns: + A dict containing the response from Home Assistant. """ request = utils.clean_kwargs(**request) @@ -339,21 +370,23 @@ async def websocket_send_json( # happens when the connection closes in the middle, which could be during shutdown except ConnectionResetError: if self.stopping: + self.logger.debug("Not connected to websocket, skipping JSON send.") return else: - raise + raise # Something bad actually happened, so raise the exception self.update_perf(bytes_sent=len(json.dumps(request)), requests_sent=1) - if request.get("type") == "auth": - return + match request: + case {"type": "auth"}: + return future = self.AD.loop.create_future() self._result_futures[self.id] = future self._silent_results[self.id] = silent try: - timeout = utils.parse_timedelta(timeout) if timeout is not None else self.config.ws_timeout + timeout = utils.parse_timedelta(self.config.ws_timeout if timeout is None else timeout) result: dict = await asyncio.wait_for(future, timeout=timeout.total_seconds()) except asyncio.TimeoutError: ad_status = ServiceCallStatus.TIMEOUT @@ -380,9 +413,9 @@ async def http_method( self, method: Literal['get', 'post', 'delete'], endpoint: str, - timeout: float = 10.0, + timeout: str | int | float | datetime.timedelta | None = 10, **kwargs - ) -> dict | ClientResponse: + ) -> str | dict[str, Any] | list[Any] | ClientResponse | None: """ https://developers.home-assistant.io/docs/api/rest @@ -418,7 +451,8 @@ async def http_method( coro = self.session.delete(url=url, json=kwargs) case _: raise ValueError(f'Invalid method: {method}') - resp = await asyncio.wait_for(coro, timeout=timeout) + timeout = utils.parse_timedelta(timeout) + resp = await asyncio.wait_for(coro, timeout=timeout.total_seconds()) except asyncio.TimeoutError: self.logger.error("Timed out waiting for %s", url) except asyncio.CancelledError: @@ -446,7 +480,7 @@ async def http_method( raise NotImplementedError('Unhandled error: HTTP %s', resp.status) return resp - async def wait_for_conditions(self, conditions: StartupConditions | None): + async def wait_for_conditions(self, conditions: StartupConditions | None) -> None: if conditions is None: return @@ -458,20 +492,21 @@ async def wait_for_conditions(self, conditions: StartupConditions | None): self.startup_conditions.append(StartupWaitCondition(event_cond_data)) if cond := conditions.state: - current_state = await self.check_for_entity(cond.entity) + current_state = await self.check_for_entity(cond.entity, local=False) if cond.value is None: - if current_state is False: - # Wait for entity to exist - self.startup_conditions.append( - StartupWaitCondition({ - 'event_type': 'state_changed', - 'data': {'entity_id': cond.entity} - })) - else: - self.logger.info(f'Startup state condition already met: {cond.entity} exists') + match current_state: + case dict(): + self.logger.info(f'Startup state condition already met: {cond.entity} exists') + case False: + # Wait for entity to exist + self.startup_conditions.append( + StartupWaitCondition({ + 'event_type': 'state_changed', + 'data': {'entity_id': cond.entity} + })) else: data = cond.model_dump(exclude_unset=True) - if utils.deep_compare(data['value'], current_state): + if isinstance(current_state, dict) and utils.deep_compare(data['value'], current_state): self.logger.info(f'Startup state condition already met: {data}') else: self.logger.info(f'Adding startup state condition: {data}') @@ -483,18 +518,16 @@ async def wait_for_conditions(self, conditions: StartupConditions | None): } })) - tasks = [ + tasks: list[asyncio.Task[Literal[True] | None]] = [ self.AD.loop.create_task(cond.event.wait()) for cond in self.startup_conditions ] if delay := conditions.delay: self.logger.info(f'Adding a {delay:.0f}s delay to the {self.name} startup') - tasks.append( - self.AD.loop.create_task( - asyncio.sleep(delay) - ) - ) + sleep = asyncio.sleep(delay) + task = self.AD.loop.create_task(sleep) + tasks.append(task) self.logger.info(f'Waiting for {len(tasks)} startup condition tasks after {self.time_str()}') if tasks: @@ -531,32 +564,21 @@ async def get_updates(self): self.logger.info("Disconnecting from Home Assistant") - async def check_register_service(self, domain: str, services: str | dict, silent: bool = False) -> bool: - """Used to check and register a service with AppDaemon if need be""" - - existing_domains = set(s["domain"] for s in self.services) - new_services = set() - match services: - case str(): - service = services # rename for clarity - if domain not in existing_domains: - self.services.append({"domain": domain, "services": {service: {}}}) - new_services = {service} - case dict(): - if domain in existing_domains: - for i, s in enumerate(self.services): - if domain == s["domain"]: - self.services[i]["services"].update(services) - new_services = set(s for s in services if s not in self.services[i]["services"]) - else: - self.services.append({"domain": domain, "services": services}) - new_services = services - pass + def _check_for_service(self, domain: str, service: str) -> bool: + return service in self.AD.services.services.get(self.namespace, {}).get(domain, {}) - for service in new_services: + async def check_register_service( + self, + domain: str, + service: str, + *, + force: bool = False, + silent: bool = False + ) -> None: + """Register a service with AppDaemon if it doesn't already exist.""" + if (not self._check_for_service(domain, service)) or force: if not silent: self.logger.debug("Registering new service %s/%s", domain, service) - self.AD.services.register_service( self.namespace, domain, @@ -564,6 +586,8 @@ async def check_register_service(self, domain: str, services: str | dict, silent self.call_plugin_service, silent=True, ) + elif not silent: + self.logger.debug("Service %s/%s already registered", domain, service) # # Utility functions @@ -574,54 +598,45 @@ async def check_register_service(self, domain: str, services: str | dict, silent # return None @utils.warning_decorator(error_text="Unexpected error while getting hass config") - async def get_hass_config(self) -> dict | None: - if meta := (await self.websocket_send_json(type="get_config")).get("result"): - HASSMetaData.model_validate(meta) - self.metadata = meta - if self.metadata.get('state') == "RUNNING": - self.ready_event.set() - return self.metadata + async def get_hass_config(self) -> dict[str, Any] | None: + resp = await self.websocket_send_json(type="get_config") + match resp: + case {"success": True, "result": meta}: + HASSMetaData.model_validate(meta) + if meta.get('state') == "RUNNING": + self.ready_event.set() + self.metadata = meta + return self.metadata + case _: + return # websocket_send_json will log warnings if something happens on the AD side @utils.warning_decorator(error_text="Unexpected error while getting hass services") - async def get_hass_services(self): - """ "Gets a fresh list of services from the websocket and updates the various internal AppDaemon entries.""" - # raise ValueError - try: - services: dict[str, dict[str, dict]] = (await self.websocket_send_json(type="get_services"))["result"] - services = [{"domain": domain, "services": services} for domain, services in services.items()] - - # manually added HASS services - new_services = {} - new_services["database"] = {"history": {}} - new_services["template"] = {"render": {}} - - # now add the services - for i, service in enumerate(deepcopy(services)): - domain = service["domain"] - if domain in new_services: - # the domain already exists - services[i]["services"].update(new_services[domain]) - - # remove from the list - del new_services[domain] - - if len(new_services) > 0: # some have not been processed - for domain, service in new_services.items(): - services.append({"domain": domain, "services": {}}) - services[-1]["services"].update(service) - - for s in services: - await self.check_register_service(s["domain"], s["services"], silent=True) - else: - self.logger.debug("Updated internal service registry") - self._dump_services("ha") - - self.services = services - return services + async def get_hass_services(self) -> dict[str, Any] | None: + """Use the `get_services` feature of the Home Assistant websocket API. - except Exception: - self.logger.warning("Error getting services - retrying") - raise + This registers a service in AppDaemon for each service that is returned by Home Assistant, and sets the + `services` attribute the a deepcopy of the services dict as returned by Home Assistant. + """ + resp = await self.websocket_send_json(type="get_services") + match resp: + case {"result": full_services, "success": True}: + with self.AD.services.services_lock: + await self._register_http_services() + self.services = deepcopy(full_services) + self._dump_services("ha") + to_register = [ + functools.partial(self.check_register_service, domain, service, silent=True) + for domain, services in full_services.items() + for service in services + if not self._check_for_service(domain, service) + ] + self.logger.debug(f'Registering {len(to_register)} new services') + for registration in to_register: + await registration() + self.logger.debug("Updated internal service registry") + return self.services + case _: + return # websocket_send_json method will log warnings if something happens on the AD side def _compare_services(self, typ: Literal["ha", "ad"]) -> dict[str, set[str]]: match typ: @@ -629,8 +644,8 @@ def _compare_services(self, typ: Literal["ha", "ad"]) -> dict[str, set[str]]: # This gets the names of all the services as they come back from the get_hass_services method that gets # called when the plugin starts and at the interval defined by services_sleep_time in the plugin config. services = { - info["domain"]: set(info["services"].keys()) - for info in self.services + domain: set(services.keys()) + for domain, services in self.services.items() } case "ad": # This gets the names of all the services as they're stored in the services subsystem @@ -647,6 +662,11 @@ def _dump_services(self, typ: Literal["ha", "ad"]) -> None: service_str = json.dumps(services, indent=4, sort_keys=True, default=str) self.service_logger.debug(f"Services ({typ}):\n{service_str}") + async def _register_http_services(self): + """Register the services that are special cases because they use the REST API instead of the websocket API.""" + await self.check_register_service(domain="database", service="history", silent=True) + await self.check_register_service(domain="render", service="template", silent=True) + def time_str(self, now: float | None = None) -> str: return utils.time_str(self.start, now) @@ -697,17 +717,17 @@ async def call_plugin_service( # if we get a request for not our namespace something has gone very wrong assert namespace == self.namespace - if domain == "database": - assert service == "history", "Use the 'history' service with 'database'" - return await self.get_history(**data) - - # Keep this just in case anyone is still using call_service() for templates - if domain == "template" and service == "render": - return await self.render_template(namespace, data) + # This match block handles the special cases for services that use the legacy service calls. These are still + # relevant because they provide services for features that can only be used through the REST API. Otherwise, + # service calls use the websocket API to call service actions in Home Assistant. + match (domain, service): + case ("database", "history"): + return await self.get_history(**data) + case ("render", "template"): + return await self.render_template(namespace, data) # https://developers.home-assistant.io/docs/api/websocket#calling-a-service-action - - req = {"type": "call_service", "domain": domain, "service": service} + req: dict[str, Any] = {"type": "call_service", "domain": domain, "service": service} service_data = data.pop('service_data', {}) service_data.update(data) @@ -716,17 +736,16 @@ async def call_plugin_service( service_properties = { prop: val - for entry in self.services # For each service entry, - if domain == entry["domain"] # if the domain matches - for name, info in entry["services"].items() - if name == service # and the service name matches, - for prop, val in info.items() # get each of the properties + for domain_, service_ in self.services.items() # For each service entry, + if domain == domain_ # if the domain matches, + for name, info in service_.items() + if name == service # and the service name matches, + for prop, val in info.items() # get each of the properties } - # if it has a response section - if resp := service_properties.get("response"): - # if the response section says it's not optional - if not resp.get("optional"): + # Set the return_response flag if doing so is not optional + match service_properties: + case {"response": {"optional": False}}: req["return_response"] = True if target is None and entity_id is not None: @@ -782,11 +801,15 @@ async def safe_delete(self: 'HassPlugin'): # @utils.warning_decorator(error_text="Unexpected error while getting hass state") - async def get_complete_state(self) -> dict[str, dict[str, Any]]: + async def get_complete_state(self) -> dict[str, dict[str, Any]] | None: """This method is needed for all AppDaemon plugins""" - hass_state = (await self.websocket_send_json(type="get_states"))["result"] - states = {s["entity_id"]: s for s in hass_state} - return states + resp = await self.websocket_send_json(type="get_states") + match resp: + case {"result": hass_state, "success": True}: + return {s["entity_id"]: s for s in hass_state} + case _: + return # websocket_send_json will log warnings if something happens on the AD side + @utils.warning_decorator(error_text='Unexpected error setting state') async def set_plugin_state( @@ -809,18 +832,40 @@ async def safe_set_state(self: 'HassPlugin'): return await safe_set_state(self) @utils.warning_decorator(error_text='Unexpected error getting state') - async def get_plugin_state(self, entity_id: str, timeout: float | None = None): - return await self.http_method('get', f'/api/states/{entity_id}', timeout) + async def get_plugin_state( + self, + entity_id: str, + timeout: str | int | float | datetime.timedelta | None = 5 + ) -> dict | None: + resp = await self.http_method('get', f'/api/states/{entity_id}', timeout) + match resp: + case ClientResponse(): + self.logger.error('Error getting state') + case dict() | None: + return resp + case _: + raise ValueError(f"Unexpected result from get_plugin_state: {resp}") - async def check_for_entity(self, entity_id: str, timeout: float | None = None) -> dict | Literal[False]: - """Tries to get the state of an entity ID to see if it exists. + async def check_for_entity( + self, + entity_id: str, + timeout: str | int | float | datetime.timedelta | None = 5, + *, # Arguments after this are keyword-only + local: bool = False, + ) -> dict | Literal[False]: + """Try to get the state of an entity ID to see if it exists. Returns a dict of the state if the entity exists. Otherwise returns False""" - resp = await self.get_plugin_state(entity_id, timeout) - if isinstance(resp, dict): - return resp - elif isinstance(resp, ClientResponse) and resp.status == 404: - return False + if local: + resp = self.AD.state.state.get(self.namespace, {}).get(entity_id, False) + else: + resp = await self.get_plugin_state(entity_id, timeout) + + match resp: + case dict(): + return resp + case _: + return False @utils.warning_decorator(error_text='Unexpected error getting history') async def get_history( @@ -831,7 +876,7 @@ async def get_history( minimal_response: bool | None = None, no_attributes: bool | None = None, significant_changes_only: bool | None = None, - ) -> list[list[dict[str, Any]]]: + ) -> list[list[dict[str, Any]]] | None: """Used to get HA's History""" if isinstance(filter_entity_id, str): filter_entity_id = [filter_entity_id] @@ -857,9 +902,9 @@ async def get_history( if error_text == 'Invalid filter_entity_id': error_text += f" '{filter_entity_id}'" self.logger.error('Error getting history: %s', error_text) - case Iterable(): + case list(): # nested comprehension to convert the datetimes for convenience - result = [ + return [ [ { k: ( @@ -874,7 +919,6 @@ async def get_history( ] for entity_res in result ] - return result case _: raise ValueError(f"Unexpected result from history: {result}") @@ -884,34 +928,42 @@ async def get_logbook( entity: str | None = None, timestamp: datetime.datetime | None = None, end_time: datetime.datetime | None = None, - ) -> list[dict[str, str | datetime.datetime]]: + ) -> list[dict[str, str | datetime.datetime]] | None: """Used to get HA's logbook""" endpoint = "/api/logbook" if timestamp is not None: endpoint += f"/{timestamp.isoformat()}" - assert await self.check_for_entity(entity_id=entity), f"'{entity}' does not exist" + if entity is not None: + assert await self.check_for_entity(entity_id=entity), f"'{entity}' does not exist" - result: list[dict[str, str]] = await self.http_method( + result = await self.http_method( "get", endpoint, entity=entity, end_time=end_time ) - result = [ - { - k: v if k != "when" else ( - datetime - .datetime - .fromisoformat(v) - .astimezone(self.AD.tz) - ) - for k, v in entry.items() - } - for entry in result - ] - return result + match result: + case ClientResponse(): + # This means that HA rejected the request + error_text = (await result.json()).get('message', 'Unknown') + if error_text == 'Invalid filter_entity_id': + error_text += f" '{entity}'" + self.logger.error('Error getting history: %s', error_text) + case list(): + return [ + { + k: v if k != "when" else ( + datetime + .datetime + .fromisoformat(v) + .astimezone(self.AD.tz) + ) + for k, v in entry.items() + } + for entry in result + ] @utils.warning_decorator(error_text='Unexpected error rendering template') async def render_template(self, namespace: str, template: str, **kwargs): diff --git a/appdaemon/utils.py b/appdaemon/utils.py index a48bb2404..1e3ef16e2 100644 --- a/appdaemon/utils.py +++ b/appdaemon/utils.py @@ -334,6 +334,11 @@ def parse_timedelta(s: str | int | float | timedelta | None) -> timedelta: case 4: day, hour, min, sec = parts return timedelta(days=day, hours=hour, minutes=min, seconds=sec) + case _: + raise ValueError( + f"Invalid string format for timedelta: {s}." + "Must be in the format 'HH:MM:SS', 'MM:SS', or 'SS'." + ) case None: return timedelta() case _: