Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions supervisor/addons/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ async def shutdown(self, stage: AddonStartup) -> None:
on_condition=AddonsJobError,
concurrency=JobConcurrency.QUEUE,
)
async def install(self, slug: str) -> None:
async def install(
self, slug: str, *, validation_complete: asyncio.Event | None = None
) -> None:
"""Install an add-on."""
self.sys_jobs.current.reference = slug

Expand All @@ -197,6 +199,10 @@ async def install(self, slug: str) -> None:

store.validate_availability()

# If being run in the background, notify caller that validation has completed
if validation_complete:
validation_complete.set()

await Addon(self.coresys, slug).install()

_LOGGER.info("Add-on '%s' successfully installed", slug)
Expand Down Expand Up @@ -226,7 +232,11 @@ async def uninstall(self, slug: str, *, remove_config: bool = False) -> None:
on_condition=AddonsJobError,
)
async def update(
self, slug: str, backup: bool | None = False
self,
slug: str,
backup: bool | None = False,
*,
validation_complete: asyncio.Event | None = None,
) -> asyncio.Task | None:
"""Update add-on.

Expand All @@ -251,6 +261,10 @@ async def update(
# Check if available, Maybe something have changed
store.validate_availability()

# If being run in the background, notify caller that validation has completed
if validation_complete:
validation_complete.set()

if backup:
await self.sys_backups.do_backup_partial(
name=f"addon_{addon.slug}_{addon.version}",
Expand Down
57 changes: 9 additions & 48 deletions supervisor/api/backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable
import errno
from io import IOBase
import logging
Expand Down Expand Up @@ -46,12 +45,9 @@
ATTR_TYPE,
ATTR_VERSION,
REQUEST_FROM,
BusEvent,
CoreState,
)
from ..coresys import CoreSysAttributes
from ..exceptions import APIError, APIForbidden, APINotFound
from ..jobs import JobSchedulerOptions, SupervisorJob
from ..mounts.const import MountUsage
from ..resolution.const import UnhealthyReason
from .const import (
Expand All @@ -61,7 +57,7 @@
ATTR_LOCATIONS,
CONTENT_TYPE_TAR,
)
from .utils import api_process, api_validate
from .utils import api_process, api_validate, background_task

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -289,41 +285,6 @@ def _validate_cloud_backup_location(
f"Location {LOCATION_CLOUD_BACKUP} is only available for Home Assistant"
)

async def _background_backup_task(
self, backup_method: Callable, *args, **kwargs
) -> tuple[asyncio.Task, str]:
"""Start backup task in background and return task and job ID."""
event = asyncio.Event()
job, backup_task = cast(
tuple[SupervisorJob, asyncio.Task],
self.sys_jobs.schedule_job(
backup_method, JobSchedulerOptions(), *args, **kwargs
),
)

async def release_on_freeze(new_state: CoreState):
if new_state == CoreState.FREEZE:
event.set()

# Wait for system to get into freeze state before returning
# If the backup fails validation it will raise before getting there
listener = self.sys_bus.register_event(
BusEvent.SUPERVISOR_STATE_CHANGE, release_on_freeze
)
try:
event_task = self.sys_create_task(event.wait())
_, pending = await asyncio.wait(
(backup_task, event_task),
return_when=asyncio.FIRST_COMPLETED,
)
# It seems backup returned early (error or something), make sure to cancel
# the event task to avoid "Task was destroyed but it is pending!" errors.
if event_task in pending:
event_task.cancel()
return (backup_task, job.uuid)
finally:
self.sys_bus.remove_listener(listener)

@api_process
async def backup_full(self, request: web.Request):
"""Create full backup."""
Expand All @@ -342,8 +303,8 @@ async def backup_full(self, request: web.Request):
body[ATTR_ADDITIONAL_LOCATIONS] = locations

background = body.pop(ATTR_BACKGROUND)
backup_task, job_id = await self._background_backup_task(
self.sys_backups.do_backup_full, **body
backup_task, job_id = await background_task(
self, self.sys_backups.do_backup_full, **body
)

if background and not backup_task.done():
Expand Down Expand Up @@ -378,8 +339,8 @@ async def backup_partial(self, request: web.Request):
body[ATTR_ADDONS] = list(self.sys_addons.local)

background = body.pop(ATTR_BACKGROUND)
backup_task, job_id = await self._background_backup_task(
self.sys_backups.do_backup_partial, **body
backup_task, job_id = await background_task(
self, self.sys_backups.do_backup_partial, **body
)

if background and not backup_task.done():
Expand All @@ -402,8 +363,8 @@ async def restore_full(self, request: web.Request):
request, body.get(ATTR_LOCATION, backup.location)
)
background = body.pop(ATTR_BACKGROUND)
restore_task, job_id = await self._background_backup_task(
self.sys_backups.do_restore_full, backup, **body
restore_task, job_id = await background_task(
self, self.sys_backups.do_restore_full, backup, **body
)

if background and not restore_task.done() or await restore_task:
Expand All @@ -422,8 +383,8 @@ async def restore_partial(self, request: web.Request):
request, body.get(ATTR_LOCATION, backup.location)
)
background = body.pop(ATTR_BACKGROUND)
restore_task, job_id = await self._background_backup_task(
self.sys_backups.do_restore_partial, backup, **body
restore_task, job_id = await background_task(
self, self.sys_backups.do_restore_partial, backup, **body
)

if background and not restore_task.done() or await restore_task:
Expand Down
24 changes: 16 additions & 8 deletions supervisor/api/homeassistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ATTR_CPU_PERCENT,
ATTR_IMAGE,
ATTR_IP_ADDRESS,
ATTR_JOB_ID,
ATTR_MACHINE,
ATTR_MEMORY_LIMIT,
ATTR_MEMORY_PERCENT,
Expand All @@ -37,8 +38,8 @@
from ..coresys import CoreSysAttributes
from ..exceptions import APIDBMigrationInProgress, APIError
from ..validate import docker_image, network_port, version_tag
from .const import ATTR_FORCE, ATTR_SAFE_MODE
from .utils import api_process, api_validate
from .const import ATTR_BACKGROUND, ATTR_FORCE, ATTR_SAFE_MODE
from .utils import api_process, api_validate, background_task

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand All @@ -61,6 +62,7 @@
{
vol.Optional(ATTR_VERSION): version_tag,
vol.Optional(ATTR_BACKUP): bool,
vol.Optional(ATTR_BACKGROUND, default=False): bool,
}
)

Expand Down Expand Up @@ -170,18 +172,24 @@ async def stats(self, request: web.Request) -> dict[Any, str]:
}

@api_process
async def update(self, request: web.Request) -> None:
async def update(self, request: web.Request) -> dict[str, str] | None:
"""Update Home Assistant."""
body = await api_validate(SCHEMA_UPDATE, request)
await self._check_offline_migration()

await asyncio.shield(
self.sys_homeassistant.core.update(
version=body.get(ATTR_VERSION, self.sys_homeassistant.latest_version),
backup=body.get(ATTR_BACKUP),
)
background = body[ATTR_BACKGROUND]
update_task, job_id = await background_task(
self,
self.sys_homeassistant.core.update,
version=body.get(ATTR_VERSION, self.sys_homeassistant.latest_version),
backup=body.get(ATTR_BACKUP),
)

if background and not update_task.done():
return {ATTR_JOB_ID: job_id}

return await update_task

@api_process
async def stop(self, request: web.Request) -> Awaitable[None]:
"""Stop Home Assistant."""
Expand Down
45 changes: 37 additions & 8 deletions supervisor/api/store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Init file for Supervisor Home Assistant RESTful API."""

import asyncio
from collections.abc import Awaitable
from pathlib import Path
from typing import Any, cast

Expand Down Expand Up @@ -36,6 +35,7 @@
ATTR_ICON,
ATTR_INGRESS,
ATTR_INSTALLED,
ATTR_JOB_ID,
ATTR_LOGO,
ATTR_LONG_DESCRIPTION,
ATTR_MAINTAINER,
Expand All @@ -57,18 +57,26 @@
from ..store.addon import AddonStore
from ..store.repository import Repository
from ..store.validate import validate_repository
from .const import CONTENT_TYPE_PNG, CONTENT_TYPE_TEXT
from .const import ATTR_BACKGROUND, CONTENT_TYPE_PNG, CONTENT_TYPE_TEXT
from .utils import background_task

SCHEMA_UPDATE = vol.Schema(
{
vol.Optional(ATTR_BACKUP): bool,
vol.Optional(ATTR_BACKGROUND, default=False): bool,
}
)

SCHEMA_ADD_REPOSITORY = vol.Schema(
{vol.Required(ATTR_REPOSITORY): vol.All(str, validate_repository)}
)

SCHEMA_INSTALL = vol.Schema(
{
vol.Optional(ATTR_BACKGROUND, default=False): bool,
}
)


def _read_static_text_file(path: Path) -> Any:
"""Read in a static text file asset for API output.
Expand Down Expand Up @@ -217,24 +225,45 @@ async def addons_list(self, request: web.Request) -> dict[str, Any]:
}

@api_process
def addons_addon_install(self, request: web.Request) -> Awaitable[None]:
async def addons_addon_install(self, request: web.Request) -> dict[str, str] | None:
"""Install add-on."""
addon = self._extract_addon(request)
return asyncio.shield(self.sys_addons.install(addon.slug))
body = await api_validate(SCHEMA_INSTALL, request)

background = body[ATTR_BACKGROUND]

install_task, job_id = await background_task(
self, self.sys_addons.install, addon.slug
)

if background and not install_task.done():
return {ATTR_JOB_ID: job_id}

return await install_task

@api_process
async def addons_addon_update(self, request: web.Request) -> None:
async def addons_addon_update(self, request: web.Request) -> dict[str, str] | None:
"""Update add-on."""
addon = self._extract_addon(request, installed=True)
if addon == request.get(REQUEST_FROM):
raise APIForbidden(f"Add-on {addon.slug} can't update itself!")

body = await api_validate(SCHEMA_UPDATE, request)
background = body[ATTR_BACKGROUND]

update_task, job_id = await background_task(
self,
self.sys_addons.update,
addon.slug,
backup=body.get(ATTR_BACKUP),
)

if background and not update_task.done():
return {ATTR_JOB_ID: job_id}

if start_task := await asyncio.shield(
self.sys_addons.update(addon.slug, backup=body.get(ATTR_BACKUP))
):
if start_task := await update_task:
await start_task
return None

@api_process
async def addons_addon_info(self, request: web.Request) -> dict[str, Any]:
Expand Down
49 changes: 48 additions & 1 deletion supervisor/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Init file for Supervisor util for RESTful API."""

import asyncio
from collections.abc import Callable
import json
from typing import Any
from typing import Any, cast

from aiohttp import web
from aiohttp.hdrs import AUTHORIZATION
Expand All @@ -23,6 +25,7 @@
)
from ..coresys import CoreSys, CoreSysAttributes
from ..exceptions import APIError, BackupFileNotFoundError, DockerAPIError, HassioError
from ..jobs import JobSchedulerOptions, SupervisorJob
from ..utils import check_exception_chain, get_message_from_exception_chain
from ..utils.json import json_dumps, json_loads as json_loads_util
from ..utils.log_format import format_message
Expand Down Expand Up @@ -198,3 +201,47 @@ async def api_validate(
data_validated[origin_value] = data[origin_value]

return data_validated


async def background_task(
coresys_obj: CoreSysAttributes,
task_method: Callable,
*args,
**kwargs,
) -> tuple[asyncio.Task, str]:
"""Start task in background and return task and job ID.

Args:
coresys_obj: Instance that accesses coresys data using CoreSysAttributes
task_method: The method to execute in the background. Must include a keyword arg 'validation_complete' of type asyncio.Event. Should set it after any initial validation has completed
*args: Arguments to pass to task_method
**kwargs: Keyword arguments to pass to task_method

Returns:
Tuple of (task, job_id)

"""
event = asyncio.Event()
job, task = cast(
tuple[SupervisorJob, asyncio.Task],
coresys_obj.sys_jobs.schedule_job(
task_method,
JobSchedulerOptions(),
*args,
validation_complete=event,
**kwargs,
),
)

# Wait for provided event before returning
# If the task fails validation it should raise before getting there
event_task = coresys_obj.sys_create_task(event.wait())
_, pending = await asyncio.wait(
(task, event_task),
return_when=asyncio.FIRST_COMPLETED,
)
# It seems task returned early (error or something), make sure to cancel
# the event task to avoid "Task was destroyed but it is pending!" errors.
if event_task in pending:
event_task.cancel()
return (task, job.uuid)
Loading