Skip to content

Commit f715812

Browse files
fix: async_to_sync tests (#72)
* add async_to_sync test * fix async_to_sync test * extend async_to_sync tests
1 parent a0aaafc commit f715812

File tree

3 files changed

+63
-2
lines changed

3 files changed

+63
-2
lines changed

src/firebolt/common/util.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from asyncio import get_event_loop
1+
from asyncio import get_event_loop, new_event_loop
22
from functools import lru_cache, wraps
33
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
44

@@ -39,6 +39,15 @@ def fix_url_schema(url: str) -> str:
3939
def async_to_sync(f: Callable) -> Callable:
4040
@wraps(f)
4141
def sync(*args: Any, **kwargs: Any) -> Any:
42-
return get_event_loop().run_until_complete(f(*args, **kwargs))
42+
close = False
43+
try:
44+
loop = get_event_loop()
45+
except RuntimeError:
46+
loop = new_event_loop()
47+
close = True
48+
res = loop.run_until_complete(f(*args, **kwargs))
49+
if close:
50+
loop.close()
51+
return res
4352

4453
return sync

tests/unit/common/__init__.py

Whitespace-only changes.

tests/unit/common/test_util.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from asyncio import run
2+
from threading import Thread
3+
4+
from pytest import raises
5+
6+
from firebolt.common.util import async_to_sync
7+
8+
9+
def test_async_to_sync_happy_path():
10+
"""async_to_sync properly converts coroutine to sync function"""
11+
12+
class JobMarker(Exception):
13+
pass
14+
15+
async def task():
16+
raise JobMarker()
17+
18+
for i in range(3):
19+
with raises(JobMarker):
20+
async_to_sync(task)()
21+
22+
23+
def test_async_to_sync_thread():
24+
"""async_to_sync properly works in threads"""
25+
26+
marks = [False] * 3
27+
28+
async def task(id: int):
29+
marks[id] = True
30+
31+
ts = [Thread(target=async_to_sync(task), args=[i]) for i in range(3)]
32+
[t.start() for t in ts]
33+
[t.join() for t in ts]
34+
assert all(marks)
35+
36+
37+
def test_async_to_sync_after_run():
38+
"""async_to_sync properly runs after asyncio.run"""
39+
40+
class JobMarker(Exception):
41+
pass
42+
43+
async def task():
44+
raise JobMarker()
45+
46+
with raises(JobMarker):
47+
run(task())
48+
49+
# Here local event loop is closed by run
50+
51+
with raises(JobMarker):
52+
async_to_sync(task)()

0 commit comments

Comments
 (0)