|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import asyncio |
| 8 | +import contextlib |
8 | 9 | import enum |
9 | 10 | import inspect |
10 | 11 | import logging |
@@ -36,7 +37,9 @@ def __bool__(self) -> bool: # pragma: no cover |
36 | 37 | class _SyncWorkerThread(threading.Thread): |
37 | 38 | work_queue: queue.Queue[ |
38 | 39 | t.Union[ |
39 | | - t.Tuple[t.Union[t.AsyncIterator, t.Coroutine], Context], |
| 40 | + t.Tuple[ |
| 41 | + t.Union[t.AsyncIterator, t.Coroutine, t.AsyncContextManager], Context |
| 42 | + ], |
40 | 43 | _Sentinel, |
41 | 44 | ] |
42 | 45 | ] |
@@ -72,6 +75,11 @@ def run(self) -> None: |
72 | 75 | coro, ctx = item |
73 | 76 | if inspect.isasyncgen(coro): |
74 | 77 | ctx.run(loop.run_until_complete, self.agen_wrapper(coro)) # type: ignore[arg-type] |
| 78 | + elif isinstance(coro, t.AsyncContextManager): |
| 79 | + ctx.run( |
| 80 | + loop.run_until_complete, |
| 81 | + self.async_context_manager_wrapper(coro), |
| 82 | + ) |
75 | 83 | else: |
76 | 84 | try: |
77 | 85 | # FIXME: Once python/mypy#12756 is resolved, remove the type-ignore tag. |
@@ -131,14 +139,57 @@ def execute_generator(self, async_gen: t.AsyncIterator[_Item]) -> t.Iterator[_It |
131 | 139 | if item is sentinel: |
132 | 140 | break |
133 | 141 | if isinstance(item, Exception): |
134 | | - self.work_queue.put(sentinel) # initial loop closing |
135 | 142 | raise item |
136 | 143 | yield item |
137 | 144 | finally: |
138 | 145 | self.stream_block.set() |
139 | 146 | self.stream_queue.task_done() |
140 | 147 | finally: |
141 | 148 | del ctx |
| 149 | + self.work_queue.put(sentinel) # initial loop closing |
| 150 | + |
| 151 | + def _update_context(self, context: Context) -> None: |
| 152 | + for var, value in context.items(): |
| 153 | + var.set(value) |
| 154 | + |
| 155 | + @contextlib.contextmanager |
| 156 | + def execute_async_context_generator( |
| 157 | + self, async_context_manager: t.AsyncContextManager, context_update: bool = True |
| 158 | + ) -> t.Generator: |
| 159 | + ctx = copy_context() # preserve context for the worker thread |
| 160 | + |
| 161 | + try: |
| 162 | + self.work_queue.put((async_context_manager, ctx)) |
| 163 | + item, updated_ctx = self.stream_queue.get() # type:ignore[misc] |
| 164 | + |
| 165 | + try: |
| 166 | + if isinstance(item, Exception): |
| 167 | + raise item |
| 168 | + |
| 169 | + if context_update: |
| 170 | + self._update_context(updated_ctx) |
| 171 | + |
| 172 | + yield item |
| 173 | + finally: |
| 174 | + if updated_ctx: |
| 175 | + del updated_ctx |
| 176 | + self._update_context(ctx) |
| 177 | + |
| 178 | + self.stream_block.set() |
| 179 | + self.stream_queue.task_done() |
| 180 | + finally: |
| 181 | + del ctx |
| 182 | + self.work_queue.put(sentinel) # initial loop closing |
| 183 | + |
| 184 | + async def async_context_manager_wrapper(self, agen: t.AsyncContextManager) -> None: |
| 185 | + try: |
| 186 | + async with agen as s: |
| 187 | + self.stream_block.clear() |
| 188 | + self.stream_queue.put((s, copy_context())) |
| 189 | + # flow-control the generator. |
| 190 | + self.stream_block.wait() |
| 191 | + except Exception as e: |
| 192 | + self.stream_queue.put((e, None)) |
142 | 193 |
|
143 | 194 | def interrupt_generator(self) -> None: |
144 | 195 | self.agen_shutdown = True |
@@ -169,3 +220,19 @@ def execute_async_gen_with_sync_worker( |
169 | 220 |
|
170 | 221 | _worker_thread.work_queue.put(sentinel) |
171 | 222 | _worker_thread.join() |
| 223 | + |
| 224 | + |
| 225 | +@contextlib.contextmanager # type:ignore[arg-type] |
| 226 | +def execute_async_context_manager_with_sync_worker( # type:ignore[misc] |
| 227 | + async_gen: t.AsyncContextManager, context_update: bool = True |
| 228 | +) -> t.ContextManager: |
| 229 | + _worker_thread = _SyncWorkerThread() |
| 230 | + _worker_thread.start() |
| 231 | + |
| 232 | + with _worker_thread.execute_async_context_generator( |
| 233 | + async_gen, context_update |
| 234 | + ) as item: |
| 235 | + yield item |
| 236 | + |
| 237 | + _worker_thread.work_queue.put(sentinel) |
| 238 | + _worker_thread.join() |
0 commit comments