Skip to content

Commit d45b3d8

Browse files
committed
Add upload and download backup APIs
1 parent 4263b5e commit d45b3d8

File tree

10 files changed

+191
-39
lines changed

10 files changed

+191
-39
lines changed

aiohasupervisor/backups.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
"""Backups client for supervisor."""
22

3+
from collections.abc import AsyncIterator
4+
5+
from aiohttp import MultipartWriter
6+
from multidict import MultiDict
7+
38
from .client import _SupervisorComponentClient
49
from .const import ResponseType
510
from .models.backups import (
@@ -15,6 +20,8 @@
1520
NewBackup,
1621
PartialBackupOptions,
1722
PartialRestoreOptions,
23+
UploadBackupOptions,
24+
UploadedBackup,
1825
)
1926

2027

@@ -102,4 +109,29 @@ async def partial_restore(
102109
)
103110
return BackupJob.from_dict(result.data)
104111

105-
# Omitted for now - Upload and download backup
112+
async def upload_backup(
113+
self, stream: AsyncIterator[bytes], options: UploadBackupOptions | None = None
114+
) -> str:
115+
"""Upload backup by stream and return slug."""
116+
params = MultiDict()
117+
if options and options.location:
118+
for location in options.location:
119+
params.add("location", location or "")
120+
121+
with MultipartWriter("form-data") as mp:
122+
mp.append(stream)
123+
result = await self._client.post(
124+
"backups/new/upload",
125+
params=params,
126+
data=mp,
127+
response_type=ResponseType.JSON,
128+
)
129+
130+
return UploadedBackup.from_dict(result.data).slug
131+
132+
async def download_backup(self, backup: str) -> AsyncIterator[bytes]:
133+
"""Download backup and return stream."""
134+
result = await self._client.get(
135+
f"backups/{backup}/download", response_type=ResponseType.STREAM
136+
)
137+
return result.data

aiohasupervisor/client.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ClientSession,
1313
ClientTimeout,
1414
)
15+
from multidict import MultiDict
1516
from yarl import URL
1617

1718
from .const import DEFAULT_TIMEOUT, ResponseType
@@ -27,6 +28,7 @@
2728
SupervisorTimeoutError,
2829
)
2930
from .models.base import Response, ResultType
31+
from .utils.aiohttp import ChunkAsyncStreamIterator
3032

3133
VERSION = metadata.version(__package__)
3234

@@ -53,12 +55,33 @@ class _SupervisorClient:
5355
session: ClientSession | None = None
5456
_close_session: bool = field(default=False, init=False)
5557

58+
async def _raise_on_status(self, response: ClientResponse) -> None:
59+
"""Raise appropriate exception on status."""
60+
if response.status >= HTTPStatus.BAD_REQUEST.value:
61+
exc_type: type[SupervisorError] = SupervisorError
62+
match response.status:
63+
case HTTPStatus.BAD_REQUEST:
64+
exc_type = SupervisorBadRequestError
65+
case HTTPStatus.UNAUTHORIZED:
66+
exc_type = SupervisorAuthenticationError
67+
case HTTPStatus.FORBIDDEN:
68+
exc_type = SupervisorForbiddenError
69+
case HTTPStatus.NOT_FOUND:
70+
exc_type = SupervisorNotFoundError
71+
case HTTPStatus.SERVICE_UNAVAILABLE:
72+
exc_type = SupervisorServiceUnavailableError
73+
74+
if is_json(response):
75+
result = Response.from_json(await response.text())
76+
raise exc_type(result.message, result.job_id)
77+
raise exc_type()
78+
5679
async def _request(
5780
self,
5881
method: HTTPMethod,
5982
uri: str,
6083
*,
61-
params: dict[str, str] | None,
84+
params: dict[str, str] | MultiDict[str, str] | None,
6285
response_type: ResponseType,
6386
json: dict[str, Any] | None = None,
6487
data: Any = None,
@@ -94,42 +117,28 @@ async def _request(
94117
self._close_session = True
95118

96119
try:
97-
async with self.session.request(
120+
response = await self.session.request(
98121
method.value,
99122
url,
100123
timeout=timeout,
101124
headers=headers,
102125
params=params,
103126
json=json,
104127
data=data,
105-
) as response:
106-
if response.status >= HTTPStatus.BAD_REQUEST.value:
107-
exc_type: type[SupervisorError] = SupervisorError
108-
match response.status:
109-
case HTTPStatus.BAD_REQUEST:
110-
exc_type = SupervisorBadRequestError
111-
case HTTPStatus.UNAUTHORIZED:
112-
exc_type = SupervisorAuthenticationError
113-
case HTTPStatus.FORBIDDEN:
114-
exc_type = SupervisorForbiddenError
115-
case HTTPStatus.NOT_FOUND:
116-
exc_type = SupervisorNotFoundError
117-
case HTTPStatus.SERVICE_UNAVAILABLE:
118-
exc_type = SupervisorServiceUnavailableError
119-
120-
if is_json(response):
121-
result = Response.from_json(await response.text())
122-
raise exc_type(result.message, result.job_id)
123-
raise exc_type()
124-
125-
match response_type:
126-
case ResponseType.JSON:
127-
is_json(response, raise_on_fail=True)
128-
return Response.from_json(await response.text())
129-
case ResponseType.TEXT:
130-
return Response(ResultType.OK, await response.text())
131-
case _:
132-
return Response(ResultType.OK)
128+
)
129+
await self._raise_on_status(response)
130+
match response_type:
131+
case ResponseType.JSON:
132+
is_json(response, raise_on_fail=True)
133+
return Response.from_json(await response.text())
134+
case ResponseType.TEXT:
135+
return Response(ResultType.OK, await response.text())
136+
case ResponseType.STREAM:
137+
return Response(
138+
ResultType.OK, ChunkAsyncStreamIterator(response.content)
139+
)
140+
case _:
141+
return Response(ResultType.OK)
133142

134143
except (UnicodeDecodeError, ClientResponseError) as err:
135144
raise SupervisorResponseError(
@@ -146,7 +155,7 @@ async def get(
146155
self,
147156
uri: str,
148157
*,
149-
params: dict[str, str] | None = None,
158+
params: dict[str, str] | MultiDict[str, str] | None = None,
150159
response_type: ResponseType = ResponseType.JSON,
151160
timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
152161
) -> Response:
@@ -163,7 +172,7 @@ async def post(
163172
self,
164173
uri: str,
165174
*,
166-
params: dict[str, str] | None = None,
175+
params: dict[str, str] | MultiDict[str, str] | None = None,
167176
response_type: ResponseType = ResponseType.NONE,
168177
json: dict[str, Any] | None = None,
169178
data: Any = None,
@@ -184,7 +193,7 @@ async def put(
184193
self,
185194
uri: str,
186195
*,
187-
params: dict[str, str] | None = None,
196+
params: dict[str, str] | MultiDict[str, str] | None = None,
188197
json: dict[str, Any] | None = None,
189198
timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
190199
) -> Response:
@@ -202,7 +211,7 @@ async def delete(
202211
self,
203212
uri: str,
204213
*,
205-
params: dict[str, str] | None = None,
214+
params: dict[str, str] | MultiDict[str, str] | None = None,
206215
timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
207216
) -> Response:
208217
"""Handle a DELETE request to Supervisor."""

aiohasupervisor/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ class ResponseType(StrEnum):
1313

1414
NONE = "none"
1515
JSON = "json"
16+
STREAM = "stream"
1617
TEXT = "text"

aiohasupervisor/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
NewBackup,
4141
PartialBackupOptions,
4242
PartialRestoreOptions,
43+
UploadBackupOptions,
4344
)
4445
from aiohasupervisor.models.discovery import (
4546
Discovery,
@@ -215,6 +216,7 @@
215216
"NewBackup",
216217
"PartialBackupOptions",
217218
"PartialRestoreOptions",
219+
"UploadBackupOptions",
218220
"Discovery",
219221
"DiscoveryConfig",
220222
"AccessPoint",

aiohasupervisor/models/backups.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ class BackupAddon(ResponseData):
7373
class BackupComplete(BackupBaseFields, ResponseData):
7474
"""BackupComplete model."""
7575

76-
supervisor_version: str | None
77-
homeassistant: str
76+
supervisor_version: str
77+
homeassistant: str | None
7878
addons: list[BackupAddon]
7979
repositories: list[str]
8080
folders: list[Folder]
@@ -132,9 +132,10 @@ class FullBackupOptions(Request):
132132
name: str | None = None
133133
password: str | None = None
134134
compressed: bool | None = None
135-
location: str | None = None
135+
location: list[str | None] | str | None = None
136136
homeassistant_exclude_database: bool | None = None
137137
background: bool | None = None
138+
extra: dict | None = None
138139

139140

140141
@dataclass(frozen=True, slots=True)
@@ -167,3 +168,17 @@ class FullRestoreOptions(Request):
167168
@dataclass(frozen=True, slots=True)
168169
class PartialRestoreOptions(FullRestoreOptions, PartialBackupRestoreOptions):
169170
"""PartialRestoreOptions model."""
171+
172+
173+
@dataclass(frozen=True, slots=True)
174+
class UploadBackupOptions(Request):
175+
"""UploadBackupOptions model."""
176+
177+
location: set[str | None] = None
178+
179+
180+
@dataclass(frozen=True, slots=True)
181+
class UploadedBackup(ResponseData):
182+
"""UploadedBackup model."""
183+
184+
slug: str

aiohasupervisor/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Utilities used internally in library."""

aiohasupervisor/utils/aiohttp.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Utilities for interacting with aiohttp."""
2+
3+
from typing import Self
4+
5+
from aiohttp import StreamReader
6+
7+
8+
class ChunkAsyncStreamIterator:
9+
"""Async iterator for chunked streams.
10+
11+
Based on aiohttp.streams.ChunkTupleAsyncStreamIterator, but yields
12+
bytes instead of tuple[bytes, bool].
13+
Borrowed from home-assistant/core.
14+
"""
15+
16+
__slots__ = ("_stream",)
17+
18+
def __init__(self, stream: StreamReader) -> None:
19+
"""Initialize."""
20+
self._stream = stream
21+
22+
def __aiter__(self) -> Self:
23+
"""Iterate."""
24+
return self
25+
26+
async def __anext__(self) -> bytes:
27+
"""Yield next chunk."""
28+
rv = await self._stream.readchunk()
29+
if rv == (b"", False):
30+
raise StopAsyncIteration
31+
return rv[0]

tests/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
from pathlib import Path
44

55

6+
def get_fixture_path(filename: str) -> Path:
7+
"""Get fixture path."""
8+
return Path(__package__) / "fixtures" / filename
9+
10+
611
def load_fixture(filename: str) -> str:
712
"""Load a fixture."""
8-
fixture = Path(__package__) / "fixtures" / filename
13+
fixture = get_fixture_path(filename)
914
return fixture.read_text(encoding="utf-8")
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"result": "ok",
3+
"data": { "slug": "7fed74c8" }
4+
}

tests/test_backups.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Test backups supervisor client."""
22

3+
import asyncio
4+
from collections.abc import AsyncIterator
35
from datetime import UTC, datetime
46
from typing import Any
57

@@ -15,6 +17,7 @@
1517
FullBackupOptions,
1618
PartialBackupOptions,
1719
PartialRestoreOptions,
20+
UploadBackupOptions,
1821
)
1922

2023
from . import load_fixture
@@ -265,3 +268,52 @@ async def test_partial_restore(
265268
"abc123", PartialRestoreOptions(addons={"core_ssh"})
266269
)
267270
assert result.job_id == "dc9dbc16f6ad4de592ffa72c807ca2bf"
271+
272+
273+
async def test_upload_backup(
274+
responses: aioresponses, supervisor_client: SupervisorClient
275+
) -> None:
276+
"""Test upload backup API."""
277+
responses.post(
278+
f"{SUPERVISOR_URL}/backups/new/upload",
279+
status=200,
280+
body=load_fixture("backup_uploaded.json"),
281+
)
282+
data = asyncio.StreamReader(loop=asyncio.get_running_loop())
283+
data.feed_data(b"backup test")
284+
data.feed_eof()
285+
286+
result = await supervisor_client.backups.upload_backup(data)
287+
assert result == "7fed74c8"
288+
289+
290+
async def test_upload_backup_to_locations(
291+
responses: aioresponses, supervisor_client: SupervisorClient
292+
) -> None:
293+
"""Test upload backup API with multiple locations."""
294+
responses.post(
295+
f"{SUPERVISOR_URL}/backups/new/upload?location=&location=test",
296+
status=200,
297+
body=load_fixture("backup_uploaded.json"),
298+
)
299+
data = asyncio.StreamReader(loop=asyncio.get_running_loop())
300+
data.feed_data(b"backup test")
301+
data.feed_eof()
302+
303+
result = await supervisor_client.backups.upload_backup(
304+
data, UploadBackupOptions(location={None, "test"})
305+
)
306+
assert result == "7fed74c8"
307+
308+
309+
async def test_download_backup(
310+
responses: aioresponses, supervisor_client: SupervisorClient
311+
) -> None:
312+
"""Test download backup API."""
313+
responses.get(
314+
f"{SUPERVISOR_URL}/backups/7fed74c8/download", status=200, body=b"backup test"
315+
)
316+
result = await supervisor_client.backups.download_backup("7fed74c8")
317+
assert isinstance(result, AsyncIterator)
318+
async for chunk in result:
319+
assert chunk == b"backup test"

0 commit comments

Comments
 (0)