Skip to content

Commit 9392d10

Browse files
authored
Add background option to update/install APIs (#6134)
* Add background option to update/install APIs * Refactor to use common background_task utility in backups too * Use a validation_complete event rather then looking for bus events
1 parent 5ce62f3 commit 9392d10

File tree

9 files changed

+333
-68
lines changed

9 files changed

+333
-68
lines changed

supervisor/addons/manager.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,9 @@ async def shutdown(self, stage: AddonStartup) -> None:
184184
on_condition=AddonsJobError,
185185
concurrency=JobConcurrency.QUEUE,
186186
)
187-
async def install(self, slug: str) -> None:
187+
async def install(
188+
self, slug: str, *, validation_complete: asyncio.Event | None = None
189+
) -> None:
188190
"""Install an add-on."""
189191
self.sys_jobs.current.reference = slug
190192

@@ -197,6 +199,10 @@ async def install(self, slug: str) -> None:
197199

198200
store.validate_availability()
199201

202+
# If being run in the background, notify caller that validation has completed
203+
if validation_complete:
204+
validation_complete.set()
205+
200206
await Addon(self.coresys, slug).install()
201207

202208
_LOGGER.info("Add-on '%s' successfully installed", slug)
@@ -226,7 +232,11 @@ async def uninstall(self, slug: str, *, remove_config: bool = False) -> None:
226232
on_condition=AddonsJobError,
227233
)
228234
async def update(
229-
self, slug: str, backup: bool | None = False
235+
self,
236+
slug: str,
237+
backup: bool | None = False,
238+
*,
239+
validation_complete: asyncio.Event | None = None,
230240
) -> asyncio.Task | None:
231241
"""Update add-on.
232242
@@ -251,6 +261,10 @@ async def update(
251261
# Check if available, Maybe something have changed
252262
store.validate_availability()
253263

264+
# If being run in the background, notify caller that validation has completed
265+
if validation_complete:
266+
validation_complete.set()
267+
254268
if backup:
255269
await self.sys_backups.do_backup_partial(
256270
name=f"addon_{addon.slug}_{addon.version}",

supervisor/api/backups.py

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import asyncio
6-
from collections.abc import Callable
76
import errno
87
from io import IOBase
98
import logging
@@ -46,12 +45,9 @@
4645
ATTR_TYPE,
4746
ATTR_VERSION,
4847
REQUEST_FROM,
49-
BusEvent,
50-
CoreState,
5148
)
5249
from ..coresys import CoreSysAttributes
5350
from ..exceptions import APIError, APIForbidden, APINotFound
54-
from ..jobs import JobSchedulerOptions, SupervisorJob
5551
from ..mounts.const import MountUsage
5652
from ..resolution.const import UnhealthyReason
5753
from .const import (
@@ -61,7 +57,7 @@
6157
ATTR_LOCATIONS,
6258
CONTENT_TYPE_TAR,
6359
)
64-
from .utils import api_process, api_validate
60+
from .utils import api_process, api_validate, background_task
6561

6662
_LOGGER: logging.Logger = logging.getLogger(__name__)
6763

@@ -289,41 +285,6 @@ def _validate_cloud_backup_location(
289285
f"Location {LOCATION_CLOUD_BACKUP} is only available for Home Assistant"
290286
)
291287

292-
async def _background_backup_task(
293-
self, backup_method: Callable, *args, **kwargs
294-
) -> tuple[asyncio.Task, str]:
295-
"""Start backup task in background and return task and job ID."""
296-
event = asyncio.Event()
297-
job, backup_task = cast(
298-
tuple[SupervisorJob, asyncio.Task],
299-
self.sys_jobs.schedule_job(
300-
backup_method, JobSchedulerOptions(), *args, **kwargs
301-
),
302-
)
303-
304-
async def release_on_freeze(new_state: CoreState):
305-
if new_state == CoreState.FREEZE:
306-
event.set()
307-
308-
# Wait for system to get into freeze state before returning
309-
# If the backup fails validation it will raise before getting there
310-
listener = self.sys_bus.register_event(
311-
BusEvent.SUPERVISOR_STATE_CHANGE, release_on_freeze
312-
)
313-
try:
314-
event_task = self.sys_create_task(event.wait())
315-
_, pending = await asyncio.wait(
316-
(backup_task, event_task),
317-
return_when=asyncio.FIRST_COMPLETED,
318-
)
319-
# It seems backup returned early (error or something), make sure to cancel
320-
# the event task to avoid "Task was destroyed but it is pending!" errors.
321-
if event_task in pending:
322-
event_task.cancel()
323-
return (backup_task, job.uuid)
324-
finally:
325-
self.sys_bus.remove_listener(listener)
326-
327288
@api_process
328289
async def backup_full(self, request: web.Request):
329290
"""Create full backup."""
@@ -342,8 +303,8 @@ async def backup_full(self, request: web.Request):
342303
body[ATTR_ADDITIONAL_LOCATIONS] = locations
343304

344305
background = body.pop(ATTR_BACKGROUND)
345-
backup_task, job_id = await self._background_backup_task(
346-
self.sys_backups.do_backup_full, **body
306+
backup_task, job_id = await background_task(
307+
self, self.sys_backups.do_backup_full, **body
347308
)
348309

349310
if background and not backup_task.done():
@@ -378,8 +339,8 @@ async def backup_partial(self, request: web.Request):
378339
body[ATTR_ADDONS] = list(self.sys_addons.local)
379340

380341
background = body.pop(ATTR_BACKGROUND)
381-
backup_task, job_id = await self._background_backup_task(
382-
self.sys_backups.do_backup_partial, **body
342+
backup_task, job_id = await background_task(
343+
self, self.sys_backups.do_backup_partial, **body
383344
)
384345

385346
if background and not backup_task.done():
@@ -402,8 +363,8 @@ async def restore_full(self, request: web.Request):
402363
request, body.get(ATTR_LOCATION, backup.location)
403364
)
404365
background = body.pop(ATTR_BACKGROUND)
405-
restore_task, job_id = await self._background_backup_task(
406-
self.sys_backups.do_restore_full, backup, **body
366+
restore_task, job_id = await background_task(
367+
self, self.sys_backups.do_restore_full, backup, **body
407368
)
408369

409370
if background and not restore_task.done() or await restore_task:
@@ -422,8 +383,8 @@ async def restore_partial(self, request: web.Request):
422383
request, body.get(ATTR_LOCATION, backup.location)
423384
)
424385
background = body.pop(ATTR_BACKGROUND)
425-
restore_task, job_id = await self._background_backup_task(
426-
self.sys_backups.do_restore_partial, backup, **body
386+
restore_task, job_id = await background_task(
387+
self, self.sys_backups.do_restore_partial, backup, **body
427388
)
428389

429390
if background and not restore_task.done() or await restore_task:

supervisor/api/homeassistant.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ATTR_CPU_PERCENT,
2121
ATTR_IMAGE,
2222
ATTR_IP_ADDRESS,
23+
ATTR_JOB_ID,
2324
ATTR_MACHINE,
2425
ATTR_MEMORY_LIMIT,
2526
ATTR_MEMORY_PERCENT,
@@ -37,8 +38,8 @@
3738
from ..coresys import CoreSysAttributes
3839
from ..exceptions import APIDBMigrationInProgress, APIError
3940
from ..validate import docker_image, network_port, version_tag
40-
from .const import ATTR_FORCE, ATTR_SAFE_MODE
41-
from .utils import api_process, api_validate
41+
from .const import ATTR_BACKGROUND, ATTR_FORCE, ATTR_SAFE_MODE
42+
from .utils import api_process, api_validate, background_task
4243

4344
_LOGGER: logging.Logger = logging.getLogger(__name__)
4445

@@ -61,6 +62,7 @@
6162
{
6263
vol.Optional(ATTR_VERSION): version_tag,
6364
vol.Optional(ATTR_BACKUP): bool,
65+
vol.Optional(ATTR_BACKGROUND, default=False): bool,
6466
}
6567
)
6668

@@ -170,18 +172,24 @@ async def stats(self, request: web.Request) -> dict[Any, str]:
170172
}
171173

172174
@api_process
173-
async def update(self, request: web.Request) -> None:
175+
async def update(self, request: web.Request) -> dict[str, str] | None:
174176
"""Update Home Assistant."""
175177
body = await api_validate(SCHEMA_UPDATE, request)
176178
await self._check_offline_migration()
177179

178-
await asyncio.shield(
179-
self.sys_homeassistant.core.update(
180-
version=body.get(ATTR_VERSION, self.sys_homeassistant.latest_version),
181-
backup=body.get(ATTR_BACKUP),
182-
)
180+
background = body[ATTR_BACKGROUND]
181+
update_task, job_id = await background_task(
182+
self,
183+
self.sys_homeassistant.core.update,
184+
version=body.get(ATTR_VERSION, self.sys_homeassistant.latest_version),
185+
backup=body.get(ATTR_BACKUP),
183186
)
184187

188+
if background and not update_task.done():
189+
return {ATTR_JOB_ID: job_id}
190+
191+
return await update_task
192+
185193
@api_process
186194
async def stop(self, request: web.Request) -> Awaitable[None]:
187195
"""Stop Home Assistant."""

supervisor/api/store.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Init file for Supervisor Home Assistant RESTful API."""
22

33
import asyncio
4-
from collections.abc import Awaitable
54
from pathlib import Path
65
from typing import Any, cast
76

@@ -36,6 +35,7 @@
3635
ATTR_ICON,
3736
ATTR_INGRESS,
3837
ATTR_INSTALLED,
38+
ATTR_JOB_ID,
3939
ATTR_LOGO,
4040
ATTR_LONG_DESCRIPTION,
4141
ATTR_MAINTAINER,
@@ -57,18 +57,26 @@
5757
from ..store.addon import AddonStore
5858
from ..store.repository import Repository
5959
from ..store.validate import validate_repository
60-
from .const import CONTENT_TYPE_PNG, CONTENT_TYPE_TEXT
60+
from .const import ATTR_BACKGROUND, CONTENT_TYPE_PNG, CONTENT_TYPE_TEXT
61+
from .utils import background_task
6162

6263
SCHEMA_UPDATE = vol.Schema(
6364
{
6465
vol.Optional(ATTR_BACKUP): bool,
66+
vol.Optional(ATTR_BACKGROUND, default=False): bool,
6567
}
6668
)
6769

6870
SCHEMA_ADD_REPOSITORY = vol.Schema(
6971
{vol.Required(ATTR_REPOSITORY): vol.All(str, validate_repository)}
7072
)
7173

74+
SCHEMA_INSTALL = vol.Schema(
75+
{
76+
vol.Optional(ATTR_BACKGROUND, default=False): bool,
77+
}
78+
)
79+
7280

7381
def _read_static_text_file(path: Path) -> Any:
7482
"""Read in a static text file asset for API output.
@@ -217,24 +225,45 @@ async def addons_list(self, request: web.Request) -> dict[str, Any]:
217225
}
218226

219227
@api_process
220-
def addons_addon_install(self, request: web.Request) -> Awaitable[None]:
228+
async def addons_addon_install(self, request: web.Request) -> dict[str, str] | None:
221229
"""Install add-on."""
222230
addon = self._extract_addon(request)
223-
return asyncio.shield(self.sys_addons.install(addon.slug))
231+
body = await api_validate(SCHEMA_INSTALL, request)
232+
233+
background = body[ATTR_BACKGROUND]
234+
235+
install_task, job_id = await background_task(
236+
self, self.sys_addons.install, addon.slug
237+
)
238+
239+
if background and not install_task.done():
240+
return {ATTR_JOB_ID: job_id}
241+
242+
return await install_task
224243

225244
@api_process
226-
async def addons_addon_update(self, request: web.Request) -> None:
245+
async def addons_addon_update(self, request: web.Request) -> dict[str, str] | None:
227246
"""Update add-on."""
228247
addon = self._extract_addon(request, installed=True)
229248
if addon == request.get(REQUEST_FROM):
230249
raise APIForbidden(f"Add-on {addon.slug} can't update itself!")
231250

232251
body = await api_validate(SCHEMA_UPDATE, request)
252+
background = body[ATTR_BACKGROUND]
253+
254+
update_task, job_id = await background_task(
255+
self,
256+
self.sys_addons.update,
257+
addon.slug,
258+
backup=body.get(ATTR_BACKUP),
259+
)
260+
261+
if background and not update_task.done():
262+
return {ATTR_JOB_ID: job_id}
233263

234-
if start_task := await asyncio.shield(
235-
self.sys_addons.update(addon.slug, backup=body.get(ATTR_BACKUP))
236-
):
264+
if start_task := await update_task:
237265
await start_task
266+
return None
238267

239268
@api_process
240269
async def addons_addon_info(self, request: web.Request) -> dict[str, Any]:

supervisor/api/utils.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Init file for Supervisor util for RESTful API."""
22

3+
import asyncio
4+
from collections.abc import Callable
35
import json
4-
from typing import Any
6+
from typing import Any, cast
57

68
from aiohttp import web
79
from aiohttp.hdrs import AUTHORIZATION
@@ -23,6 +25,7 @@
2325
)
2426
from ..coresys import CoreSys, CoreSysAttributes
2527
from ..exceptions import APIError, BackupFileNotFoundError, DockerAPIError, HassioError
28+
from ..jobs import JobSchedulerOptions, SupervisorJob
2629
from ..utils import check_exception_chain, get_message_from_exception_chain
2730
from ..utils.json import json_dumps, json_loads as json_loads_util
2831
from ..utils.log_format import format_message
@@ -198,3 +201,47 @@ async def api_validate(
198201
data_validated[origin_value] = data[origin_value]
199202

200203
return data_validated
204+
205+
206+
async def background_task(
207+
coresys_obj: CoreSysAttributes,
208+
task_method: Callable,
209+
*args,
210+
**kwargs,
211+
) -> tuple[asyncio.Task, str]:
212+
"""Start task in background and return task and job ID.
213+
214+
Args:
215+
coresys_obj: Instance that accesses coresys data using CoreSysAttributes
216+
task_method: The method to execute in the background. Must include a keyword arg 'validation_complete' of type asyncio.Event. Should set it after any initial validation has completed
217+
*args: Arguments to pass to task_method
218+
**kwargs: Keyword arguments to pass to task_method
219+
220+
Returns:
221+
Tuple of (task, job_id)
222+
223+
"""
224+
event = asyncio.Event()
225+
job, task = cast(
226+
tuple[SupervisorJob, asyncio.Task],
227+
coresys_obj.sys_jobs.schedule_job(
228+
task_method,
229+
JobSchedulerOptions(),
230+
*args,
231+
validation_complete=event,
232+
**kwargs,
233+
),
234+
)
235+
236+
# Wait for provided event before returning
237+
# If the task fails validation it should raise before getting there
238+
event_task = coresys_obj.sys_create_task(event.wait())
239+
_, pending = await asyncio.wait(
240+
(task, event_task),
241+
return_when=asyncio.FIRST_COMPLETED,
242+
)
243+
# It seems task returned early (error or something), make sure to cancel
244+
# the event task to avoid "Task was destroyed but it is pending!" errors.
245+
if event_task in pending:
246+
event_task.cancel()
247+
return (task, job.uuid)

0 commit comments

Comments
 (0)