Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion appdaemon/plugins/hass/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field


from appdaemon import exceptions as ade


Expand All @@ -16,7 +17,11 @@ class HAAuthenticationError(ade.AppDaemonException):

@dataclass
class HAEventsSubError(ade.AppDaemonException):
pass
code: int
msg: str

def __str__(self) -> str:
return f"{self.code}: {self.msg}"


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions appdaemon/plugins/hass/hassapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,9 +677,9 @@ async def get_history(
days: int | None = None,
start_time: datetime | str | None = None,
end_time: datetime | str | None = None,
minimal_response: bool | None = None,
no_attributes: bool | None = None,
significant_changes_only: bool | None = None,
minimal_response: bool = False,
no_attributes: bool = False,
significant_changes_only: bool = False,
callback: Callable | None = None,
namespace: str | None = None,
) -> list[list[dict[str, Any]]] | None:
Expand Down
59 changes: 31 additions & 28 deletions appdaemon/plugins/hass/hassplugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ async def __post_auth__(self) -> None:
utils.format_timedelta(ad_duration),
)
case {"success": False, "error": {"code": code, "message": msg}}:
raise HAEventsSubError(f"{code}: {msg}")
raise HAEventsSubError(code, msg)
case _:
raise HAEventsSubError(f"Unknown response from subscribe_events: {res}")
raise HAEventsSubError(-1, f"Unknown response from subscribe_events: {res}")

config_coro = looped_coro(self.get_hass_config, self.config.config_sleep_time)
self.AD.loop.create_task(config_coro(self))
Expand Down Expand Up @@ -340,7 +340,7 @@ async def websocket_send_json(
Returns:
A dict containing the response from Home Assistant.
"""
request = utils.clean_kwargs(**request)
request = dict(utils.clean_kwargs(**request))

if not self.connect_event.is_set():
self.logger.debug("Not connected to websocket, skipping JSON send.")
Expand Down Expand Up @@ -426,52 +426,55 @@ async def http_method(
Returns:
dict | None: _description_
"""
kwargs = utils.clean_kwargs(**kwargs)
kwargs = dict(utils.clean_http_kwargs(**kwargs))
url = utils.make_endpoint(self.config.ha_url, endpoint)

try:
self.update_perf(
bytes_sent=len(url) + len(json.dumps(kwargs).encode("utf-8")),
requests_sent=1,
)

self.logger.debug(f"Hass {method.upper()} {endpoint}: {kwargs}")
match method.lower():
case "get":
coro = self.session.get(url=url, params=kwargs)
http_method = functools.partial(self.session.get, params=kwargs)
case "post":
coro = self.session.post(url=url, json=kwargs)
http_method = functools.partial(self.session.post, json=kwargs)
case "delete":
coro = self.session.delete(url=url, json=kwargs)
http_method = functools.partial(self.session.delete, json=kwargs)
case _:
raise ValueError(f"Invalid method: {method}")

timeout = utils.parse_timedelta(timeout)
resp = await asyncio.wait_for(coro, timeout=timeout.total_seconds())
client_timeout = aiohttp.ClientTimeout(total=timeout.total_seconds())
async with http_method(url=url, timeout=client_timeout) as resp:
self.logger.debug(f"HTTP {method.upper()} {resp.url}")
self.update_perf(bytes_recv=resp.content_length, updates_recv=1)
match resp.status:
case 200 | 201:
if endpoint.endswith("template"):
return await resp.text()
else:
return await resp.json()
case 400 | 401 | 403 | 404 | 405:
try:
msg = (await resp.json())["message"]
except Exception:
msg = await resp.text()
self.logger.error(f"Bad response from {url}: {msg}")
case 500 | 502:
text = await resp.text()
self.logger.error("Internal server error %s: %s", url, text)
case _:
raise NotImplementedError("Unhandled error: HTTP %s", resp.status)
return resp
except asyncio.TimeoutError:
self.logger.error("Timed out waiting for %s", url)
except asyncio.CancelledError:
self.logger.debug("Task cancelled during %s", method.upper())
except aiohttp.ServerDisconnectedError:
self.logger.error("HASS disconnected unexpectedly during %s to %s", method.upper(), url)
else:
self.update_perf(bytes_recv=resp.content_length, updates_recv=1)
match resp.status:
case 200 | 201:
if endpoint.endswith("template"):
return await resp.text()
else:
return await resp.json()
case 400 | 401 | 403 | 404 | 405:
try:
msg = (await resp.json())["message"]
except Exception:
msg = await resp.text()
self.logger.error(f"Bad response from {url}: {msg}")
case 500 | 502:
text = await resp.text()
self.logger.error("Internal server error %s: %s", url, text)
case _:
raise NotImplementedError("Unhandled error: HTTP %s", resp.status)
return resp

async def wait_for_conditions(self, conditions: StartupConditions | None) -> None:
if conditions is None:
Expand Down
59 changes: 41 additions & 18 deletions appdaemon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
import sys
import threading
import traceback
from collections.abc import Awaitable, Generator, Iterable
from collections.abc import Awaitable, Generator, Iterable, Mapping
from datetime import datetime, time, timedelta, tzinfo
from functools import wraps
from logging import Logger
from pathlib import Path
from time import perf_counter
from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Literal,
ParamSpec, Protocol, TypeVar)
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, ParamSpec, Protocol, TypeVar

import dateutil.parser
import tomli
Expand Down Expand Up @@ -1071,28 +1070,52 @@ def time_str(start: float, now: float | None = None) -> str:
return format_timedelta((now or perf_counter()) - start)


def clean_kwargs(**kwargs):
"""Converts everything to strings and removes null values"""
def clean_kwargs(**kwargs: Any) -> Generator[tuple[str, Any]]:
"""Recursively clean a dict of kwargs.

def clean_value(val: Any) -> str:
Conversions:
- None values are removed
- datetime values are converted to ISO format strings
- bool values are converted to lowercase strings
- int, float, and str values are converted to strings
- Iterable values (like lists and tuples) are converted to lists of cleaned values
- Mapping values (like dicts) are converted to dicts of cleaned key-value pairs
"""

def _clean_value(val: bool | datetime | Any) -> str:
match val:
case int() | float() | str():
return val
case datetime():
return val.isoformat()
case dict():
return clean_kwargs(**val)
case Iterable():
return [clean_value(v) for v in val]
case bool():
return str(val).lower()
case _:
return str(val)

kwargs = {
k: clean_value(v)
for k, v in kwargs.items()
if v is not None
} # fmt: skip
return kwargs
for key, val in kwargs.items():
match val:
case None:
continue
case str():
# This case needs to be before the Iterable case because strings are iterable
yield key, _clean_value(val)
case Mapping():
# This case needs to be before the Iterable case because Mappings like dicts are iterable
yield key, dict(clean_kwargs(**val))
case Iterable():
yield key, list(map(_clean_value, val))
case _:
yield key, _clean_value(val)


def clean_http_kwargs(**kwargs: Any) -> Generator[tuple[str, Any]]:
"""Recursively cleans the kwarg dict to prepare it for use in HTTP requests."""
for key, val in clean_kwargs(**kwargs):
match val:
case "false" | None:
# Filter out values that are False or None
continue
case _:
yield key, val


def make_endpoint(base: str, endpoint: str) -> str:
Expand Down
40 changes: 40 additions & 0 deletions tests/unit/test_kwarg_clean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from copy import deepcopy
from datetime import datetime

import pytest
import pytz
from appdaemon.utils import clean_http_kwargs, clean_kwargs

pytestmark = [
pytest.mark.ci,
pytest.mark.unit,
]


BASE = {"a": 1, "b": 2.0, "c": "three", "d": True, "e": False, "f": datetime(2025, 9, 22, 12, 0, 0, tzinfo=pytz.utc), "g": None}


def test_clean_kwargs():
cleaned = dict(clean_kwargs(**BASE))
for v in cleaned.values():
assert isinstance(v, str), f"Value {v} should be a string after cleaning"

assert cleaned["e"] == "false"
assert "g" not in cleaned

kwargs = deepcopy(BASE)

kwargs["nested"] = deepcopy(BASE)
kwargs["nested"]["extra"] = deepcopy(BASE)
cleaned = dict(clean_kwargs(**kwargs))
for v in cleaned["nested"]["extra"].values():
assert isinstance(v, str), f"Value {v} should be a string after cleaning"


def test_clean_http_kwargs():
cleaned = dict(clean_http_kwargs(**BASE))
for v in cleaned.values():
assert isinstance(v, str), f"Value {v} should be a string after cleaning"

assert "e" not in cleaned
assert "g" not in cleaned
Loading