Skip to content

Commit d1e11cf

Browse files
committed
Merge branch 'python313' into uv
2 parents d6005be + 229d077 commit d1e11cf

File tree

25 files changed

+1218
-378
lines changed

25 files changed

+1218
-378
lines changed

.github/workflows/python-tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
uses: actions/checkout@v5
2525
- name: Install uv and set the python version
2626
id: setup-uv
27-
uses: astral-sh/setup-uv@v6
27+
uses: astral-sh/setup-uv@v7
2828
with:
2929
# https://docs.astral.sh/uv/guides/integration/github/#using-uv-in-github-actions
3030
# It is considered best practice to pin to a specific uv version
@@ -68,7 +68,7 @@ jobs:
6868
uses: actions/checkout@v5
6969
- name: Install uv and set the python version
7070
id: setup-uv
71-
uses: astral-sh/setup-uv@v6
71+
uses: astral-sh/setup-uv@v7
7272
with:
7373
python-version: ${{ matrix.python-version }}
7474
# https://docs.astral.sh/uv/guides/integration/github/#using-uv-in-github-actions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ nosetests.xml
4646
coverage.xml
4747
*,cover
4848
.hypothesis/
49+
tests/conf/namespaces
4950

5051
# Translations
5152
*.mo

.python-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.12
1+
3.13

appdaemon/adapi.py

Lines changed: 94 additions & 79 deletions
Large diffs are not rendered by default.

appdaemon/appdaemon.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
self.booted = "booting"
120120
self.logger = logging.get_logger()
121121
self.logging.register_ad(self) # needs to go last to reference the config object
122+
self._shutdown_logger = self.logging.get_child("_shutdown")
122123
self.stop_event = asyncio.Event()
123124

124125
self.global_vars: Any = {}
@@ -375,6 +376,7 @@ def start(self) -> None:
375376
self.thread_async.start()
376377
self.sched.start()
377378
self.utility.start()
379+
self.state.start()
378380

379381
if self.apps_enabled:
380382
self.app_management.start()
@@ -390,22 +392,22 @@ async def stop(self) -> None:
390392
- :meth:`Scheduler <appdaemon.scheduler.Scheduler.stop>`
391393
- :meth:`State <appdaemon.state.State.stop>`
392394
"""
393-
self.logger.info("Stopping AppDaemon")
395+
self._shutdown_logger.info("Stopping AppDaemon")
394396
self.stopping = True
395397

396398
# Subsystems are able to create tasks during their stop methods
397399
if self.apps_enabled:
398400
try:
399401
await asyncio.wait_for(self.app_management.stop(), timeout=3)
400402
except asyncio.TimeoutError:
401-
self.logger.warning("AppManagement stop timed out, continuing shutdown")
403+
self._shutdown_logger.warning("AppManagement stop timed out, continuing shutdown")
402404
if self.thread_async is not None:
403405
self.thread_async.stop()
404406
if self.plugins is not None:
405407
try:
406408
await asyncio.wait_for(self.plugins.stop(), timeout=1)
407409
except asyncio.TimeoutError:
408-
self.logger.warning("Timed out stopping plugins, continuing shutdown")
410+
self._shutdown_logger.warning("Timed out stopping plugins, continuing shutdown")
409411
self.sched.stop()
410412
self.state.stop()
411413
self.threading.stop()
@@ -420,7 +422,20 @@ async def stop(self) -> None:
420422
all_coro = asyncio.wait(running_tasks, return_when=asyncio.ALL_COMPLETED, timeout=3)
421423
gather_task = asyncio.create_task(all_coro, name="appdaemon_stop_tasks")
422424
gather_task.add_done_callback(lambda _: self.logger.debug("All tasks finished"))
423-
self.logger.debug("Waiting for tasks to finish...")
425+
self._shutdown_logger.debug("Waiting for tasks %s to finish...", len(running_tasks))
426+
427+
# These is left here for future debugging purposes
428+
# await asyncio.sleep(2.0)
429+
# still_running = [
430+
# task
431+
# for task in asyncio.all_tasks()
432+
# if task is not current_task and task is not gather_task and not task.done()
433+
# ]
434+
# self._shutdown_logger.debug("%s tasks still running after 2 seconds", len(still_running))
435+
# if still_running:
436+
# for task in still_running:
437+
# self._shutdown_logger.debug("Still running: %s", task.get_name())
438+
424439
await gather_task
425440

426441
#

appdaemon/models/config/misc.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import json
2-
from datetime import datetime
2+
from datetime import datetime, timedelta
33
from pathlib import Path
44
from typing import Any, Literal
55

66
from pydantic import BaseModel, Field, model_validator
77

8+
from appdaemon.utils import ADWritebackType
9+
10+
from .common import ParsedTimedelta
11+
812
LEVELS = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
913

1014

@@ -31,23 +35,24 @@ class FilterConfig(BaseModel):
3135

3236

3337
class NamespaceConfig(BaseModel):
34-
writeback: Literal["safe", "hybrid"] | None = None
38+
writeback: ADWritebackType | None = None
3539
persist: bool = Field(default=False, alias="persistent")
40+
save_interval: ParsedTimedelta = Field(default=timedelta(seconds=1))
3641

3742
@model_validator(mode="before")
3843
@classmethod
3944
def validate_persistence(cls, values: Any):
4045
"""Sets persistence to True if writeback is set to safe or hybrid."""
4146
match values:
4247
case {"writeback": wb} if wb is not None:
43-
values["persistent"] = True
48+
values["persist"] = True
4449
case _ if getattr(values, "writeback", None) is not None:
45-
values.persistent = True
50+
values.persist = True
4651
return values
4752

4853
@model_validator(mode="after")
4954
def validate_writeback(self):
5055
"""Makes the writeback safe by default if persist is set to True."""
5156
if self.persist and self.writeback is None:
52-
self.writeback = "safe"
57+
self.writeback = ADWritebackType.safe
5358
return self

appdaemon/models/config/plugin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class StartupConditions(BaseModel):
8585
event: EventStartupCondition | None = None
8686

8787

88-
class HASSConfig(PluginConfig):
88+
class HASSConfig(PluginConfig, extra="forbid"):
8989
ha_url: str = "http://supervisor/core"
9090
token: SecretStr
9191
ha_key: Annotated[SecretStr, deprecated("'ha_key' is deprecated. Please use long lived tokens instead")] | None = None
@@ -101,6 +101,7 @@ class HASSConfig(PluginConfig):
101101
commtype: Annotated[str, deprecated("'commtype' is deprecated")] | None = None
102102
ws_timeout: ParsedTimedelta = timedelta(seconds=10)
103103
"""Default timeout for waiting for responses from the websocket connection"""
104+
ws_max_msg_size: int = 4 * 1024 * 1024
104105
suppress_log_messages: bool = False
105106
services_sleep_time: ParsedTimedelta = timedelta(seconds=60)
106107
"""The sleep time in the background task that updates the internal list of available services every once in a while"""

appdaemon/plugins/hass/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,10 @@ def __str__(self):
4141
if self.namespace != "default":
4242
res += f" with namespace '{self.namespace}'"
4343
return res
44+
45+
@dataclass
46+
class HassConnectionError(ade.AppDaemonException):
47+
msg: str
48+
49+
def __str__(self) -> str:
50+
return self.msg

appdaemon/plugins/hass/hassplugin.py

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,24 @@
66
import functools
77
import json
88
import ssl
9-
from collections.abc import AsyncGenerator, Iterable
9+
from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable
1010
from copy import deepcopy
1111
from dataclasses import dataclass, field
1212
from datetime import datetime, timedelta
1313
from time import perf_counter
1414
from typing import Any, Literal, Optional
1515

1616
import aiohttp
17-
from aiohttp import ClientResponse, ClientResponseError, RequestInfo, WSMsgType
17+
from aiohttp import ClientResponse, ClientResponseError, RequestInfo, WSMsgType, WebSocketError
1818
from pydantic import BaseModel
1919

2020
import appdaemon.utils as utils
2121
from appdaemon.appdaemon import AppDaemon
2222
from appdaemon.models.config.plugin import HASSConfig, StartupConditions
2323
from appdaemon.plugin_management import PluginBase
2424

25-
from .exceptions import HAEventsSubError
26-
from .utils import ServiceCallStatus, hass_check, looped_coro
25+
from .exceptions import HAEventsSubError, HassConnectionError
26+
from .utils import ServiceCallStatus, hass_check
2727

2828

2929
class HASSWebsocketResponse(BaseModel):
@@ -81,6 +81,9 @@ class HassPlugin(PluginBase):
8181
_result_futures: dict[int, asyncio.Future]
8282
_silent_results: dict[int, bool]
8383
startup_conditions: list[StartupWaitCondition]
84+
maintenance_tasks: list[asyncio.Task]
85+
"""List of tasks that run in the background as part of the plugin operation. These are tracked because they might
86+
need to get cancelled during shutdown."""
8487

8588
start: float
8689

@@ -96,6 +99,7 @@ def __init__(self, ad: "AppDaemon", name: str, config: HASSConfig):
9699
self._result_futures = {}
97100
self._silent_results = {}
98101
self.startup_conditions = []
102+
self.maintenance_tasks = []
99103

100104
self.service_logger = self.diag.getChild("services")
101105
self.logger.info("HASS Plugin initialization complete")
@@ -107,15 +111,18 @@ async def stop(self):
107111
await self.session.close()
108112
self.logger.debug("aiohttp session closed for '%s'", self.name)
109113

114+
def _create_maintenance_task(self, coro: Coroutine, name: str) -> asyncio.Task:
115+
task = self.AD.loop.create_task(coro, name=name)
116+
self.maintenance_tasks.append(task)
117+
task.add_done_callback(lambda t: self.maintenance_tasks.remove(t))
118+
return task
119+
110120
def create_session(self) -> aiohttp.ClientSession:
111121
"""Handles creating an :py:class:`~aiohttp.ClientSession` with the cert information from the plugin config
112122
and the authorization headers for the `REST API <https://developers.home-assistant.io/docs/api/rest>`_.
113123
"""
114-
if self.config.cert_path is not None:
115-
ssl_context = ssl.create_default_context(capath=self.config.cert_path)
116-
conn = aiohttp.TCPConnector(ssl_context=ssl_context, verify_ssl=self.config.cert_verify)
117-
else:
118-
conn = aiohttp.TCPConnector(ssl=False)
124+
ssl_context = ssl.create_default_context(capath=self.config.cert_path)
125+
conn = aiohttp.TCPConnector(ssl_context=ssl_context)
119126

120127
connect_timeout_secs = self.config.connect_timeout.total_seconds()
121128
return aiohttp.ClientSession(
@@ -142,30 +149,36 @@ async def websocket_msg_factory(self) -> AsyncGenerator[aiohttp.WSMessage]:
142149
self.start = perf_counter()
143150
async with self.create_session() as self.session:
144151
try:
145-
async with self.session.ws_connect(self.config.websocket_url) as self.ws:
152+
async with self.session.ws_connect(
153+
url=self.config.websocket_url,
154+
max_msg_size=self.config.ws_max_msg_size,
155+
) as self.ws:
156+
if (exc := self.ws.exception()) is not None:
157+
raise HassConnectionError("Failed to connect to Home Assistant websocket") from exc
158+
146159
async for msg in self.ws:
147-
self.updates_recv += 1
148-
self.bytes_recv += len(msg.data)
149160
yield msg
150161
finally:
151162
self.connect_event.clear()
152163

153-
async def match_ws_msg(self, msg: aiohttp.WSMessage) -> dict:
164+
async def match_ws_msg(self, msg: aiohttp.WSMessage) -> None:
154165
"""Uses a :py:ref:`match <class-patterns>` statement on :py:class:`~aiohttp.WSMessage`.
155166
156167
Uses :py:meth:`~HassPlugin.process_websocket_json` on :py:attr:`~aiohttp.WSMsgType.TEXT` messages.
157168
"""
158169
match msg:
159-
case aiohttp.WSMessage(type=WSMsgType.TEXT):
170+
case aiohttp.WSMessage(type=WSMsgType.TEXT, data=str(data)):
160171
# create a separate task for processing messages to keep the message reading task unblocked
161-
self.AD.loop.create_task(self.process_websocket_json(msg.json()))
162-
case aiohttp.WSMessage(type=WSMsgType.ERROR):
163-
self.logger.error("Error from aiohttp websocket: %s", msg.json())
172+
self.updates_recv += 1
173+
self.bytes_recv += len(data)
174+
# Intentionally not using self._create_maintenance_task here
175+
self.AD.loop.create_task(self.process_websocket_json(msg.json()), name="process_ws_msg")
176+
case aiohttp.WSMessage(type=WSMsgType.ERROR, data=WebSocketError() as err):
177+
self.logger.error("Error from aiohttp websocket: %s", err)
164178
case aiohttp.WSMessage(type=WSMsgType.CLOSE):
165179
self.logger.debug("Received %s message", msg.type)
166180
case _:
167181
self.logger.warning("Unhandled websocket message type: %s", msg.type)
168-
return msg.json()
169182

170183
@utils.warning_decorator(error_text="Error during processing jSON", reraise=True)
171184
async def process_websocket_json(self, resp: dict[str, Any]) -> None:
@@ -182,7 +195,7 @@ async def process_websocket_json(self, resp: dict[str, Any]) -> None:
182195
case {"type": "auth_ok", "ha_version": ha_version}:
183196
self.logger.info("Authenticated to Home Assistant %s", ha_version)
184197
# Creating a task here allows the plugin to still receive events as it waits for the startup conditions
185-
self.AD.loop.create_task(self.__post_auth__())
198+
self._create_maintenance_task(self.__post_auth__(), name="post_auth")
186199
case {"type": "auth_invalid", "message": message}:
187200
self.logger.error("Failed to authenticate to Home Assistant: %s", message)
188201
await self.ws.close()
@@ -218,11 +231,14 @@ async def __post_auth__(self) -> None:
218231
case _:
219232
raise HAEventsSubError(-1, f"Unknown response from subscribe_events: {res}")
220233

221-
config_coro = looped_coro(self.get_hass_config, self.config.config_sleep_time.total_seconds())
222-
self.AD.loop.create_task(config_coro(self))
223-
224-
service_coro = looped_coro(self.get_hass_services, self.config.services_sleep_time.total_seconds())
225-
self.AD.loop.create_task(service_coro(self))
234+
self._create_maintenance_task(
235+
self.looped_coro(self.get_hass_config, self.config.config_sleep_time.total_seconds()),
236+
name="get_hass_config loop"
237+
)
238+
self._create_maintenance_task(
239+
self.looped_coro(self.get_hass_services, self.config.services_sleep_time.total_seconds()),
240+
name="get_hass_services loop"
241+
)
226242

227243
if self.first_time:
228244
conditions = self.config.appdaemon_startup_conditions
@@ -413,7 +429,7 @@ async def websocket_send_json(
413429
ad_status = ServiceCallStatus.TERMINATING
414430
result = {"success": False}
415431
if not silent:
416-
self.logger.warning(f"AppDaemon cancelled waiting for the response from the request: {request}")
432+
self.logger.debug(f"AppDaemon cancelled waiting for the response from the request: {request}")
417433
else:
418434
ad_status = ServiceCallStatus.OK
419435

@@ -527,14 +543,16 @@ async def wait_for_conditions(self, conditions: StartupConditions | None) -> Non
527543
)
528544

529545
tasks: list[asyncio.Task[Literal[True] | None]] = [
530-
self.AD.loop.create_task(cond.event.wait())
546+
self._create_maintenance_task(cond.event.wait(), name=f"startup condition: {cond}")
531547
for cond in self.startup_conditions
532548
] # fmt: skip
533549

534550
if delay := conditions.delay:
535551
self.logger.info(f"Adding a {delay:.0f}s delay to the {self.name} startup")
536-
sleep = self.AD.utility.sleep(delay, timeout_ok=True)
537-
task = self.AD.loop.create_task(sleep)
552+
task = self._create_maintenance_task(
553+
self.AD.utility.sleep(delay, timeout_ok=True),
554+
name="startup delay"
555+
)
538556
tasks.append(task)
539557

540558
self.logger.info(f"Waiting for {len(tasks)} startup condition tasks after {self.time_str()}")
@@ -555,7 +573,7 @@ async def get_updates(self):
555573
async for msg in self.websocket_msg_factory():
556574
await self.match_ws_msg(msg)
557575
continue
558-
raise ValueError
576+
raise HassConnectionError("Websocket connection lost")
559577
except Exception as exc:
560578
if not self.AD.stopping:
561579
self.error.error(exc)
@@ -568,7 +586,17 @@ async def get_updates(self):
568586

569587
# always do this block, no matter what
570588
finally:
589+
for task in self.maintenance_tasks:
590+
if not task.done():
591+
task.cancel()
592+
571593
if not self.AD.stopping:
594+
for fut in self._result_futures.values():
595+
if not fut.done():
596+
fut.cancel()
597+
self._result_futures.clear()
598+
self._silent_results.clear()
599+
572600
# remove callback from getting local events
573601
await self.AD.callbacks.clear_callbacks(self.name)
574602

@@ -605,6 +633,22 @@ async def check_register_service(
605633
# self.logger.debug("Utility (currently unused)")
606634
# return None
607635

636+
async def looped_coro(self, coro: Callable[..., Coroutine], sleep: float):
637+
"""Run a coroutine in a loop with a sleep interval.
638+
639+
This is a utility function that can be used to run a coroutine in a loop with a sleep interval. It is used
640+
internally to run the `get_hass_config` and
641+
"""
642+
while not self.AD.stopping:
643+
try:
644+
await coro()
645+
except asyncio.CancelledError:
646+
pass
647+
except Exception as e:
648+
self.logger.error("Error in looped coroutine: %s", e)
649+
finally:
650+
await self.AD.utility.sleep(sleep, timeout_ok=True)
651+
608652
@utils.warning_decorator(error_text="Unexpected error while getting hass config")
609653
async def get_hass_config(self) -> dict[str, Any] | None:
610654
resp = await self.websocket_send_json(type="get_config")

0 commit comments

Comments
 (0)