Skip to content

Commit 8101a4e

Browse files
committed
Improve laziness to not create slightly eagerly (by 1)
1 parent e315d24 commit 8101a4e

File tree

1 file changed

+49
-16
lines changed

1 file changed

+49
-16
lines changed

async_utils/gen_transform.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import asyncio
18+
from collections import deque
1819
from collections.abc import AsyncGenerator, Callable, Generator
1920
from typing import ParamSpec, TypeVar
2021

@@ -24,9 +25,36 @@
2425
YieldType = TypeVar("YieldType")
2526

2627

28+
class _PeekableQueue[T](asyncio.Queue[T]):
29+
"""This is for internal use only, tested on both 3.12 and 3.13
30+
This will be tested for 3.14 prior to 3.14's release."""
31+
32+
_get_loop: Callable[[], asyncio.AbstractEventLoop] # pyright: ignore[reportUninitializedInstanceVariable]
33+
_getters: deque[asyncio.Future[None]] # pyright: ignore[reportUninitializedInstanceVariable]
34+
_wakeup_next: Callable[[deque[asyncio.Future[None]]], None] # pyright: ignore[reportUninitializedInstanceVariable]
35+
_queue: deque[T] # pyright: ignore[reportUninitializedInstanceVariable]
36+
37+
async def peek(self) -> T:
38+
while self.empty():
39+
getter = self._get_loop().create_future()
40+
self._getters.append(getter) # type:
41+
try:
42+
await getter
43+
except:
44+
getter.cancel()
45+
try:
46+
self._getters.remove(getter)
47+
except ValueError:
48+
pass
49+
if not self.empty() and not getter.cancelled():
50+
self._wakeup_next(self._getters)
51+
raise
52+
return self._queue[0]
53+
54+
2755
def _consumer(
2856
loop: asyncio.AbstractEventLoop,
29-
queue: asyncio.Queue[YieldType],
57+
queue: _PeekableQueue[YieldType],
3058
f: Callable[P, Generator[YieldType]],
3159
*args: P.args,
3260
**kwargs: P.kwargs,
@@ -43,12 +71,12 @@ def sync_to_async_gen(
4371
*args: P.args,
4472
**kwargs: P.kwargs,
4573
) -> AsyncGenerator[YieldType]:
46-
"""async iterate over synchronous generator ran in backgroun thread.
47-
48-
Generator function and it's arguments must be threadsafe.
74+
"""Asynchronously iterate over a synchronous generator run in
75+
background thread.
4976
50-
Generators which perform cpu intensive work while holding the GIL will
51-
likely not see a benefit.
77+
The generator function and it's arguments must be threadsafe and will be
78+
iterated lazily. Generators which perform cpu intensive work while holding
79+
the GIL will likely not see a benefit.
5280
5381
Generators which rely on two-way communication (generators as coroutines)
5482
are not appropriate for this function. similarly, generator return values
@@ -57,21 +85,26 @@ def sync_to_async_gen(
5785
If your generator is actually a synchronous coroutine, that's super cool,
5886
but rewrite is as a native coroutine or use it directly then, you don't need
5987
what this function does."""
60-
# Provides backpressure, ensuring the underlying sync generator in a thread is lazy
61-
# If the user doesn't want laziness, then using this method makes little sense, they could
62-
# trivially exhaust the generator in a thread with asyncio.to_thread(lambda g: list(g()), g)
63-
# to then use the values
64-
q: asyncio.Queue[YieldType] = asyncio.Queue(maxsize=1)
88+
# Provides backpressure, ensuring the underlying sync generator in a thread
89+
# is lazy If the user doesn't want laziness, then using this method makes
90+
# little sense, they could trivially exhaust the generator in a thread with
91+
# asyncio.to_thread(lambda g: list(g()), g) to then use the values
92+
q: _PeekableQueue[YieldType] = _PeekableQueue(maxsize=1)
6593

66-
background_coro = asyncio.to_thread(_consumer, asyncio.get_running_loop(), q, f, *args, **kwargs)
94+
background_coro = asyncio.to_thread(
95+
_consumer, asyncio.get_running_loop(), q, f, *args, **kwargs
96+
)
6797
background_task = asyncio.create_task(background_coro)
6898

6999
async def gen() -> AsyncGenerator[YieldType]:
70100
while not background_task.done():
71-
q_get = asyncio.ensure_future(q.get())
72-
done, _pending = await asyncio.wait((background_task, q_get), return_when=asyncio.FIRST_COMPLETED)
73-
if q_get in done:
74-
yield (await q_get)
101+
q_peek = asyncio.ensure_future(q.peek())
102+
done, _pending = await asyncio.wait(
103+
(background_task, q_peek), return_when=asyncio.FIRST_COMPLETED
104+
)
105+
if q_peek in done:
106+
yield (await q_peek)
107+
q.get_nowait()
75108
while not q.empty():
76109
yield q.get_nowait()
77110
# ensure errors in the generator propogate *after* the last values yielded

0 commit comments

Comments
 (0)