|
4 | 4 |
|
5 | 5 | import asyncio |
6 | 6 | from collections.abc import AsyncIterator, Callable, Coroutine |
7 | | -from concurrent.futures import CancelledError, Future |
8 | 7 | import copy |
9 | 8 | from dataclasses import dataclass, replace |
10 | 9 | from io import BytesIO |
|
14 | 13 | from queue import SimpleQueue |
15 | 14 | import tarfile |
16 | 15 | import threading |
17 | | -from typing import IO, Any, Self, cast |
| 16 | +from typing import IO, Any, cast |
18 | 17 |
|
19 | 18 | import aiohttp |
20 | 19 | from securetar import SecureTarError, SecureTarFile, SecureTarReadError |
|
23 | 22 | from homeassistant.core import HomeAssistant |
24 | 23 | from homeassistant.exceptions import HomeAssistantError |
25 | 24 | from homeassistant.util import dt as dt_util |
| 25 | +from homeassistant.util.async_iterator import ( |
| 26 | + Abort, |
| 27 | + AsyncIteratorReader, |
| 28 | + AsyncIteratorWriter, |
| 29 | +) |
26 | 30 | from homeassistant.util.json import JsonObjectType, json_loads_object |
27 | 31 |
|
28 | 32 | from .const import BUF_SIZE, LOGGER |
@@ -59,12 +63,6 @@ class BackupEmpty(DecryptError): |
59 | 63 | _message = "No tar files found in the backup." |
60 | 64 |
|
61 | 65 |
|
62 | | -class AbortCipher(HomeAssistantError): |
63 | | - """Abort the cipher operation.""" |
64 | | - |
65 | | - _message = "Abort cipher operation." |
66 | | - |
67 | | - |
68 | 66 | def make_backup_dir(path: Path) -> None: |
69 | 67 | """Create a backup directory if it does not exist.""" |
70 | 68 | path.mkdir(exist_ok=True) |
@@ -166,106 +164,6 @@ def validate_password(path: Path, password: str | None) -> bool: |
166 | 164 | return False |
167 | 165 |
|
168 | 166 |
|
169 | | -class AsyncIteratorReader: |
170 | | - """Wrap an AsyncIterator.""" |
171 | | - |
172 | | - def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None: |
173 | | - """Initialize the wrapper.""" |
174 | | - self._aborted = False |
175 | | - self._hass = hass |
176 | | - self._stream = stream |
177 | | - self._buffer: bytes | None = None |
178 | | - self._next_future: Future[bytes | None] | None = None |
179 | | - self._pos: int = 0 |
180 | | - |
181 | | - async def _next(self) -> bytes | None: |
182 | | - """Get the next chunk from the iterator.""" |
183 | | - return await anext(self._stream, None) |
184 | | - |
185 | | - def abort(self) -> None: |
186 | | - """Abort the reader.""" |
187 | | - self._aborted = True |
188 | | - if self._next_future is not None: |
189 | | - self._next_future.cancel() |
190 | | - |
191 | | - def read(self, n: int = -1, /) -> bytes: |
192 | | - """Read data from the iterator.""" |
193 | | - result = bytearray() |
194 | | - while n < 0 or len(result) < n: |
195 | | - if not self._buffer: |
196 | | - self._next_future = asyncio.run_coroutine_threadsafe( |
197 | | - self._next(), self._hass.loop |
198 | | - ) |
199 | | - if self._aborted: |
200 | | - self._next_future.cancel() |
201 | | - raise AbortCipher |
202 | | - try: |
203 | | - self._buffer = self._next_future.result() |
204 | | - except CancelledError as err: |
205 | | - raise AbortCipher from err |
206 | | - self._pos = 0 |
207 | | - if not self._buffer: |
208 | | - # The stream is exhausted |
209 | | - break |
210 | | - chunk = self._buffer[self._pos : self._pos + n] |
211 | | - result.extend(chunk) |
212 | | - n -= len(chunk) |
213 | | - self._pos += len(chunk) |
214 | | - if self._pos == len(self._buffer): |
215 | | - self._buffer = None |
216 | | - return bytes(result) |
217 | | - |
218 | | - def close(self) -> None: |
219 | | - """Close the iterator.""" |
220 | | - |
221 | | - |
222 | | -class AsyncIteratorWriter: |
223 | | - """Wrap an AsyncIterator.""" |
224 | | - |
225 | | - def __init__(self, hass: HomeAssistant) -> None: |
226 | | - """Initialize the wrapper.""" |
227 | | - self._aborted = False |
228 | | - self._hass = hass |
229 | | - self._pos: int = 0 |
230 | | - self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1) |
231 | | - self._write_future: Future[bytes | None] | None = None |
232 | | - |
233 | | - def __aiter__(self) -> Self: |
234 | | - """Return the iterator.""" |
235 | | - return self |
236 | | - |
237 | | - async def __anext__(self) -> bytes: |
238 | | - """Get the next chunk from the iterator.""" |
239 | | - if data := await self._queue.get(): |
240 | | - return data |
241 | | - raise StopAsyncIteration |
242 | | - |
243 | | - def abort(self) -> None: |
244 | | - """Abort the writer.""" |
245 | | - self._aborted = True |
246 | | - if self._write_future is not None: |
247 | | - self._write_future.cancel() |
248 | | - |
249 | | - def tell(self) -> int: |
250 | | - """Return the current position in the iterator.""" |
251 | | - return self._pos |
252 | | - |
253 | | - def write(self, s: bytes, /) -> int: |
254 | | - """Write data to the iterator.""" |
255 | | - self._write_future = asyncio.run_coroutine_threadsafe( |
256 | | - self._queue.put(s), self._hass.loop |
257 | | - ) |
258 | | - if self._aborted: |
259 | | - self._write_future.cancel() |
260 | | - raise AbortCipher |
261 | | - try: |
262 | | - self._write_future.result() |
263 | | - except CancelledError as err: |
264 | | - raise AbortCipher from err |
265 | | - self._pos += len(s) |
266 | | - return len(s) |
267 | | - |
268 | | - |
269 | 167 | def validate_password_stream( |
270 | 168 | input_stream: IO[bytes], |
271 | 169 | password: str | None, |
@@ -342,7 +240,7 @@ def decrypt_backup( |
342 | 240 | finally: |
343 | 241 | # Write an empty chunk to signal the end of the stream |
344 | 242 | output_stream.write(b"") |
345 | | - except AbortCipher: |
| 243 | + except Abort: |
346 | 244 | LOGGER.debug("Cipher operation aborted") |
347 | 245 | finally: |
348 | 246 | on_done(error) |
@@ -430,7 +328,7 @@ def encrypt_backup( |
430 | 328 | finally: |
431 | 329 | # Write an empty chunk to signal the end of the stream |
432 | 330 | output_stream.write(b"") |
433 | | - except AbortCipher: |
| 331 | + except Abort: |
434 | 332 | LOGGER.debug("Cipher operation aborted") |
435 | 333 | finally: |
436 | 334 | on_done(error) |
@@ -557,8 +455,8 @@ def on_done(error: Exception | None) -> None: |
557 | 455 | self._hass.loop.call_soon_threadsafe(worker_status.done.set) |
558 | 456 |
|
559 | 457 | stream = await self._open_stream() |
560 | | - reader = AsyncIteratorReader(self._hass, stream) |
561 | | - writer = AsyncIteratorWriter(self._hass) |
| 458 | + reader = AsyncIteratorReader(self._hass.loop, stream) |
| 459 | + writer = AsyncIteratorWriter(self._hass.loop) |
562 | 460 | worker = threading.Thread( |
563 | 461 | target=self._cipher_func, |
564 | 462 | args=[ |
|
0 commit comments