Skip to content

Commit 0433ebb

Browse files
committed
SyncWorker: Added worker for running coroutine function for cases where loop.is_running() is true
1 parent 5b26ee8 commit 0433ebb

File tree

4 files changed

+250
-0
lines changed

4 files changed

+250
-0
lines changed

ellar/threading/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .sync_worker import (
2+
execute_async_gen_with_sync_worker,
3+
execute_coroutine_with_sync_worker,
4+
)
5+
6+
__all__ = ["execute_async_gen_with_sync_worker", "execute_coroutine_with_sync_worker"]

ellar/threading/sync_worker.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""
2+
Copied from BackendAI - http:github.com/lablup/backend.ai
3+
https://github.com/lablup/backend.ai/blob/4a19001f9d1ae12be7244e14b304d783da9ea4f9/src/ai/backend/client/session.py#L128
4+
"""
5+
from __future__ import annotations
6+
7+
import asyncio
8+
import enum
9+
import inspect
10+
import logging
11+
import queue
12+
import threading
13+
import typing as t
14+
from contextvars import Context, copy_context
15+
16+
_Item = t.TypeVar("_Item")
17+
18+
logger = logging.getLogger("ellar.sync_worker")
19+
20+
21+
class _Sentinel(enum.Enum):
22+
"""
23+
A special type to represent a special value to indicate closing/shutdown of queues.
24+
"""
25+
26+
TOKEN = 0
27+
28+
def __bool__(self) -> bool: # pragma: no cover
29+
# It should be evaluated as False when used as a boolean expr.
30+
return False
31+
32+
33+
sentinel = _Sentinel.TOKEN
34+
35+
36+
class _SyncWorkerThread(threading.Thread):
37+
work_queue: queue.Queue[
38+
t.Union[
39+
t.Tuple[t.Union[t.AsyncIterator, t.Coroutine], Context],
40+
_Sentinel,
41+
]
42+
]
43+
done_queue: queue.Queue[t.Union[t.Any, Exception]]
44+
stream_queue: queue.Queue[t.Union[t.Any, Exception, _Sentinel]]
45+
stream_block: threading.Event
46+
agen_shutdown: bool
47+
48+
__slots__ = (
49+
"work_queue",
50+
"done_queue",
51+
"stream_queue",
52+
"stream_block",
53+
"agen_shutdown",
54+
)
55+
56+
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
57+
super().__init__(*args, **kwargs)
58+
self.work_queue = queue.Queue()
59+
self.done_queue = queue.Queue()
60+
self.stream_queue = queue.Queue()
61+
self.stream_block = threading.Event()
62+
self.agen_shutdown = False
63+
64+
def run(self) -> None:
65+
loop = asyncio.new_event_loop()
66+
asyncio.set_event_loop(loop)
67+
try:
68+
while True:
69+
item = self.work_queue.get()
70+
if item is sentinel:
71+
break
72+
coro, ctx = item
73+
if inspect.isasyncgen(coro):
74+
ctx.run(loop.run_until_complete, self.agen_wrapper(coro)) # type: ignore[arg-type]
75+
else:
76+
try:
77+
# FIXME: Once python/mypy#12756 is resolved, remove the type-ignore tag.
78+
result = ctx.run(loop.run_until_complete, coro) # type: ignore[arg-type]
79+
except Exception as e:
80+
self.done_queue.put_nowait(e)
81+
self.work_queue.task_done()
82+
raise e
83+
else:
84+
self.done_queue.put_nowait(result)
85+
self.work_queue.task_done()
86+
87+
except (SystemExit, KeyboardInterrupt): # pragma: no cover
88+
pass
89+
except Exception as ex:
90+
logger.error(ex)
91+
finally:
92+
loop.run_until_complete(loop.shutdown_asyncgens())
93+
loop.stop()
94+
loop.close()
95+
96+
def execute(self, coro: t.Coroutine) -> t.Any:
97+
ctx = copy_context() # preserve context for the worker thread
98+
try:
99+
self.work_queue.put((coro, ctx))
100+
result = self.done_queue.get()
101+
self.done_queue.task_done()
102+
if isinstance(result, Exception):
103+
raise result
104+
return result
105+
finally:
106+
del ctx
107+
108+
async def agen_wrapper(self, agen: t.Coroutine) -> None:
109+
self.agen_shutdown = False
110+
try:
111+
async for item in agen: # type: ignore[attr-defined]
112+
self.stream_block.clear()
113+
self.stream_queue.put(item)
114+
# flow-control the generator.
115+
self.stream_block.wait()
116+
if self.agen_shutdown:
117+
break
118+
except Exception as e:
119+
self.stream_queue.put(e)
120+
finally:
121+
self.stream_queue.put(sentinel)
122+
await agen.aclose() # type: ignore[attr-defined]
123+
124+
def execute_generator(self, async_gen: t.AsyncIterator[_Item]) -> t.Iterator[_Item]:
125+
ctx = copy_context() # preserve context for the worker thread
126+
try:
127+
self.work_queue.put((async_gen, ctx))
128+
while True:
129+
item = self.stream_queue.get()
130+
try:
131+
if item is sentinel:
132+
break
133+
if isinstance(item, Exception):
134+
self.work_queue.put(sentinel) # initial loop closing
135+
raise item
136+
yield item
137+
finally:
138+
self.stream_block.set()
139+
self.stream_queue.task_done()
140+
finally:
141+
del ctx
142+
143+
def interrupt_generator(self) -> None:
144+
self.agen_shutdown = True
145+
self.stream_block.set()
146+
self.stream_queue.put(sentinel)
147+
148+
149+
def execute_coroutine_with_sync_worker(coro: t.Coroutine) -> t.Any:
150+
_worker_thread = _SyncWorkerThread()
151+
_worker_thread.start()
152+
153+
res = _worker_thread.execute(coro)
154+
155+
_worker_thread.work_queue.put(sentinel)
156+
_worker_thread.join()
157+
158+
return res
159+
160+
161+
def execute_async_gen_with_sync_worker(
162+
async_gen: t.AsyncIterator[_Item],
163+
) -> t.Iterator[_Item]:
164+
_worker_thread = _SyncWorkerThread()
165+
_worker_thread.start()
166+
167+
for item in _worker_thread.execute_generator(async_gen):
168+
yield item
169+
170+
_worker_thread.work_queue.put(sentinel)
171+
_worker_thread.join()

tests/test_thread_worker/__init__.py

Whitespace-only changes.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import pytest
2+
from ellar.threading.sync_worker import (
3+
_SyncWorkerThread,
4+
execute_async_gen_with_sync_worker,
5+
execute_coroutine_with_sync_worker,
6+
sentinel,
7+
)
8+
9+
10+
async def coroutine_function():
11+
return "Coroutine Function"
12+
13+
14+
async def coroutine_function_2():
15+
raise RuntimeError()
16+
17+
18+
async def async_gen(after=None):
19+
for i in range(0, 10):
20+
if after and i > after:
21+
raise Exception("Exceeded")
22+
yield i
23+
24+
25+
async def test_run_with_sync_worker_runs_async_function_synchronously(anyio_backend):
26+
res = execute_coroutine_with_sync_worker(coroutine_function())
27+
assert res == "Coroutine Function"
28+
29+
30+
async def test_run_with_sync_worker_will_raise_an_exception(anyio_backend):
31+
with pytest.raises(RuntimeError):
32+
execute_coroutine_with_sync_worker(coroutine_function_2())
33+
34+
35+
async def test_sync_worker_exists_wait_for_work_task(anyio_backend):
36+
worker = _SyncWorkerThread()
37+
worker.start()
38+
# exist waiting for a work task
39+
worker.work_queue.put(sentinel)
40+
worker.join()
41+
42+
43+
async def test_sync_worker_execute_async_generator(anyio_backend):
44+
res = list(execute_async_gen_with_sync_worker(async_gen()))
45+
assert res == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
46+
47+
48+
async def test_sync_worker_execute_async_generator_raises_exception(anyio_backend):
49+
worker = _SyncWorkerThread()
50+
worker.start()
51+
52+
res = []
53+
with pytest.raises(Exception, match="Exceeded"):
54+
for item in worker.execute_generator(async_gen(6)):
55+
res.append(item)
56+
57+
assert res == [0, 1, 2, 3, 4, 5, 6]
58+
59+
60+
async def test_sync_worker_interrupt_function_works(anyio_backend):
61+
worker = _SyncWorkerThread()
62+
worker.start()
63+
64+
res = []
65+
for item in worker.execute_generator(async_gen()):
66+
if len(res) == 7:
67+
worker.interrupt_generator()
68+
continue
69+
res.append(item)
70+
71+
assert res == [0, 1, 2, 3, 4, 5, 6]
72+
worker.work_queue.put(sentinel)
73+
worker.join()

0 commit comments

Comments
 (0)