Skip to content

Commit 2b4b46e

Browse files
authored
Add async_iterator util (home-assistant#153194)
1 parent 40b9dae commit 2b4b46e

File tree

5 files changed

+265
-116
lines changed

5 files changed

+265
-116
lines changed

homeassistant/components/backup/http.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from homeassistant.exceptions import HomeAssistantError
1818
from homeassistant.helpers import frame
1919
from homeassistant.util import slugify
20+
from homeassistant.util.async_iterator import AsyncIteratorReader, AsyncIteratorWriter
2021

2122
from . import util
2223
from .agent import BackupAgent
@@ -144,15 +145,15 @@ async def _send_backup_with_password(
144145
return Response(status=HTTPStatus.NOT_FOUND)
145146
else:
146147
stream = await agent.async_download_backup(backup_id)
147-
reader = cast(IO[bytes], util.AsyncIteratorReader(hass, stream))
148+
reader = cast(IO[bytes], AsyncIteratorReader(hass.loop, stream))
148149

149150
worker_done_event = asyncio.Event()
150151

151152
def on_done(error: Exception | None) -> None:
152153
"""Call by the worker thread when it's done."""
153154
hass.loop.call_soon_threadsafe(worker_done_event.set)
154155

155-
stream = util.AsyncIteratorWriter(hass)
156+
stream = AsyncIteratorWriter(hass.loop)
156157
worker = threading.Thread(
157158
target=util.decrypt_backup,
158159
args=[backup, reader, stream, password, on_done, 0, []],

homeassistant/components/backup/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from homeassistant.helpers.json import json_bytes
4040
from homeassistant.util import dt as dt_util, json as json_util
41+
from homeassistant.util.async_iterator import AsyncIteratorReader
4142

4243
from . import util as backup_util
4344
from .agent import (
@@ -72,7 +73,6 @@
7273
)
7374
from .store import BackupStore
7475
from .util import (
75-
AsyncIteratorReader,
7676
DecryptedBackupStreamer,
7777
EncryptedBackupStreamer,
7878
make_backup_dir,
@@ -1525,7 +1525,7 @@ async def async_can_decrypt_on_download(
15251525
reader = await self.hass.async_add_executor_job(open, path.as_posix(), "rb")
15261526
else:
15271527
backup_stream = await agent.async_download_backup(backup_id)
1528-
reader = cast(IO[bytes], AsyncIteratorReader(self.hass, backup_stream))
1528+
reader = cast(IO[bytes], AsyncIteratorReader(self.hass.loop, backup_stream))
15291529
try:
15301530
await self.hass.async_add_executor_job(
15311531
validate_password_stream, reader, password

homeassistant/components/backup/util.py

Lines changed: 10 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import asyncio
66
from collections.abc import AsyncIterator, Callable, Coroutine
7-
from concurrent.futures import CancelledError, Future
87
import copy
98
from dataclasses import dataclass, replace
109
from io import BytesIO
@@ -14,7 +13,7 @@
1413
from queue import SimpleQueue
1514
import tarfile
1615
import threading
17-
from typing import IO, Any, Self, cast
16+
from typing import IO, Any, cast
1817

1918
import aiohttp
2019
from securetar import SecureTarError, SecureTarFile, SecureTarReadError
@@ -23,6 +22,11 @@
2322
from homeassistant.core import HomeAssistant
2423
from homeassistant.exceptions import HomeAssistantError
2524
from homeassistant.util import dt as dt_util
25+
from homeassistant.util.async_iterator import (
26+
Abort,
27+
AsyncIteratorReader,
28+
AsyncIteratorWriter,
29+
)
2630
from homeassistant.util.json import JsonObjectType, json_loads_object
2731

2832
from .const import BUF_SIZE, LOGGER
@@ -59,12 +63,6 @@ class BackupEmpty(DecryptError):
5963
_message = "No tar files found in the backup."
6064

6165

62-
class AbortCipher(HomeAssistantError):
63-
"""Abort the cipher operation."""
64-
65-
_message = "Abort cipher operation."
66-
67-
6866
def make_backup_dir(path: Path) -> None:
6967
"""Create a backup directory if it does not exist."""
7068
path.mkdir(exist_ok=True)
@@ -166,106 +164,6 @@ def validate_password(path: Path, password: str | None) -> bool:
166164
return False
167165

168166

169-
class AsyncIteratorReader:
170-
"""Wrap an AsyncIterator."""
171-
172-
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
173-
"""Initialize the wrapper."""
174-
self._aborted = False
175-
self._hass = hass
176-
self._stream = stream
177-
self._buffer: bytes | None = None
178-
self._next_future: Future[bytes | None] | None = None
179-
self._pos: int = 0
180-
181-
async def _next(self) -> bytes | None:
182-
"""Get the next chunk from the iterator."""
183-
return await anext(self._stream, None)
184-
185-
def abort(self) -> None:
186-
"""Abort the reader."""
187-
self._aborted = True
188-
if self._next_future is not None:
189-
self._next_future.cancel()
190-
191-
def read(self, n: int = -1, /) -> bytes:
192-
"""Read data from the iterator."""
193-
result = bytearray()
194-
while n < 0 or len(result) < n:
195-
if not self._buffer:
196-
self._next_future = asyncio.run_coroutine_threadsafe(
197-
self._next(), self._hass.loop
198-
)
199-
if self._aborted:
200-
self._next_future.cancel()
201-
raise AbortCipher
202-
try:
203-
self._buffer = self._next_future.result()
204-
except CancelledError as err:
205-
raise AbortCipher from err
206-
self._pos = 0
207-
if not self._buffer:
208-
# The stream is exhausted
209-
break
210-
chunk = self._buffer[self._pos : self._pos + n]
211-
result.extend(chunk)
212-
n -= len(chunk)
213-
self._pos += len(chunk)
214-
if self._pos == len(self._buffer):
215-
self._buffer = None
216-
return bytes(result)
217-
218-
def close(self) -> None:
219-
"""Close the iterator."""
220-
221-
222-
class AsyncIteratorWriter:
223-
"""Wrap an AsyncIterator."""
224-
225-
def __init__(self, hass: HomeAssistant) -> None:
226-
"""Initialize the wrapper."""
227-
self._aborted = False
228-
self._hass = hass
229-
self._pos: int = 0
230-
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
231-
self._write_future: Future[bytes | None] | None = None
232-
233-
def __aiter__(self) -> Self:
234-
"""Return the iterator."""
235-
return self
236-
237-
async def __anext__(self) -> bytes:
238-
"""Get the next chunk from the iterator."""
239-
if data := await self._queue.get():
240-
return data
241-
raise StopAsyncIteration
242-
243-
def abort(self) -> None:
244-
"""Abort the writer."""
245-
self._aborted = True
246-
if self._write_future is not None:
247-
self._write_future.cancel()
248-
249-
def tell(self) -> int:
250-
"""Return the current position in the iterator."""
251-
return self._pos
252-
253-
def write(self, s: bytes, /) -> int:
254-
"""Write data to the iterator."""
255-
self._write_future = asyncio.run_coroutine_threadsafe(
256-
self._queue.put(s), self._hass.loop
257-
)
258-
if self._aborted:
259-
self._write_future.cancel()
260-
raise AbortCipher
261-
try:
262-
self._write_future.result()
263-
except CancelledError as err:
264-
raise AbortCipher from err
265-
self._pos += len(s)
266-
return len(s)
267-
268-
269167
def validate_password_stream(
270168
input_stream: IO[bytes],
271169
password: str | None,
@@ -342,7 +240,7 @@ def decrypt_backup(
342240
finally:
343241
# Write an empty chunk to signal the end of the stream
344242
output_stream.write(b"")
345-
except AbortCipher:
243+
except Abort:
346244
LOGGER.debug("Cipher operation aborted")
347245
finally:
348246
on_done(error)
@@ -430,7 +328,7 @@ def encrypt_backup(
430328
finally:
431329
# Write an empty chunk to signal the end of the stream
432330
output_stream.write(b"")
433-
except AbortCipher:
331+
except Abort:
434332
LOGGER.debug("Cipher operation aborted")
435333
finally:
436334
on_done(error)
@@ -557,8 +455,8 @@ def on_done(error: Exception | None) -> None:
557455
self._hass.loop.call_soon_threadsafe(worker_status.done.set)
558456

559457
stream = await self._open_stream()
560-
reader = AsyncIteratorReader(self._hass, stream)
561-
writer = AsyncIteratorWriter(self._hass)
458+
reader = AsyncIteratorReader(self._hass.loop, stream)
459+
writer = AsyncIteratorWriter(self._hass.loop)
562460
worker = threading.Thread(
563461
target=self._cipher_func,
564462
args=[
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Async iterator utilities."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
from collections.abc import AsyncIterator
7+
from concurrent.futures import CancelledError, Future
8+
from typing import Self
9+
10+
11+
class Abort(Exception):
12+
"""Raised when abort is requested."""
13+
14+
15+
class AsyncIteratorReader:
16+
"""Allow reading from an AsyncIterator using blocking I/O.
17+
18+
The class implements a blocking read method reading from the async iterator,
19+
and a close method.
20+
21+
In addition, the abort method can be used to abort any ongoing read operation.
22+
"""
23+
24+
def __init__(
25+
self,
26+
loop: asyncio.AbstractEventLoop,
27+
stream: AsyncIterator[bytes],
28+
) -> None:
29+
"""Initialize the wrapper."""
30+
self._aborted = False
31+
self._loop = loop
32+
self._stream = stream
33+
self._buffer: bytes | None = None
34+
self._next_future: Future[bytes | None] | None = None
35+
self._pos: int = 0
36+
37+
async def _next(self) -> bytes | None:
38+
"""Get the next chunk from the iterator."""
39+
return await anext(self._stream, None)
40+
41+
def abort(self) -> None:
42+
"""Abort the reader."""
43+
self._aborted = True
44+
if self._next_future is not None:
45+
self._next_future.cancel()
46+
47+
def read(self, n: int = -1, /) -> bytes:
48+
"""Read up to n bytes of data from the iterator.
49+
50+
The read method returns 0 bytes when the iterator is exhausted.
51+
"""
52+
result = bytearray()
53+
while n < 0 or len(result) < n:
54+
if not self._buffer:
55+
self._next_future = asyncio.run_coroutine_threadsafe(
56+
self._next(), self._loop
57+
)
58+
if self._aborted:
59+
self._next_future.cancel()
60+
raise Abort
61+
try:
62+
self._buffer = self._next_future.result()
63+
except CancelledError as err:
64+
raise Abort from err
65+
self._pos = 0
66+
if not self._buffer:
67+
# The stream is exhausted
68+
break
69+
chunk = self._buffer[self._pos : self._pos + n]
70+
result.extend(chunk)
71+
n -= len(chunk)
72+
self._pos += len(chunk)
73+
if self._pos == len(self._buffer):
74+
self._buffer = None
75+
return bytes(result)
76+
77+
def close(self) -> None:
78+
"""Close the iterator."""
79+
80+
81+
class AsyncIteratorWriter:
82+
"""Allow writing to an AsyncIterator using blocking I/O.
83+
84+
The class implements a blocking write method writing to the async iterator,
85+
as well as a close and tell methods.
86+
87+
In addition, the abort method can be used to abort any ongoing write operation.
88+
"""
89+
90+
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
91+
"""Initialize the wrapper."""
92+
self._aborted = False
93+
self._loop = loop
94+
self._pos: int = 0
95+
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
96+
self._write_future: Future[bytes | None] | None = None
97+
98+
def __aiter__(self) -> Self:
99+
"""Return the iterator."""
100+
return self
101+
102+
async def __anext__(self) -> bytes:
103+
"""Get the next chunk from the iterator."""
104+
if data := await self._queue.get():
105+
return data
106+
raise StopAsyncIteration
107+
108+
def abort(self) -> None:
109+
"""Abort the writer."""
110+
self._aborted = True
111+
if self._write_future is not None:
112+
self._write_future.cancel()
113+
114+
def tell(self) -> int:
115+
"""Return the current position in the iterator."""
116+
return self._pos
117+
118+
def write(self, s: bytes, /) -> int:
119+
"""Write data to the iterator.
120+
121+
To signal the end of the stream, write a zero-length bytes object.
122+
"""
123+
self._write_future = asyncio.run_coroutine_threadsafe(
124+
self._queue.put(s), self._loop
125+
)
126+
if self._aborted:
127+
self._write_future.cancel()
128+
raise Abort
129+
try:
130+
self._write_future.result()
131+
except CancelledError as err:
132+
raise Abort from err
133+
self._pos += len(s)
134+
return len(s)

0 commit comments

Comments
 (0)