Skip to content

Commit 422a26f

Browse files
authored
feat: use trio for async_to_sync (#220)
1 parent 2b9dcd1 commit 422a26f

File tree

6 files changed

+61
-113
lines changed

6 files changed

+61
-113
lines changed

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ install_requires =
3232
pydantic[dotenv]>=1.8.2,<1.10
3333
readerwriterlock==1.0.9
3434
sqlparse>=0.4.2
35+
trio
3536
python_requires = >=3.7
3637
include_package_data = True
3738
package_dir =
@@ -55,6 +56,7 @@ dev =
5556
pytest-mock==3.6.1
5657
pytest-timeout==2.1.0
5758
pytest-xdist==2.5.0
59+
trio-typing[mypy]==0.6.*
5860
types-cryptography==3.3.18
5961

6062
[options.package_data]

src/firebolt/async_db/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ def __exit__(
679679

680680
class Cursor(BaseCursor):
681681
"""
682-
Executes asyncio queries to Firebolt Database.
682+
Executes async queries to Firebolt Database.
683683
Should not be created directly;
684684
use :py:func:`connection.cursor <firebolt.async_db.connection.Connection>`
685685

src/firebolt/db/connection.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from firebolt.async_db.connection import async_connect_factory
1212
from firebolt.db.cursor import Cursor
1313
from firebolt.utils.exception import ConnectionClosedError
14-
from firebolt.utils.util import AsyncJobThread, async_to_sync
14+
from firebolt.utils.util import async_to_sync
1515

1616

1717
class Connection(AsyncBaseConnection):
@@ -31,7 +31,7 @@ class Connection(AsyncBaseConnection):
3131
are not implemented.
3232
"""
3333

34-
__slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock", "_async_job_thread")
34+
__slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock",)
3535

3636
cursor_class = Cursor
3737

@@ -40,18 +40,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
4040
# Holding this lock for write means that connection is closing itself.
4141
# cursor() should hold this lock for read to read/write state
4242
self._closing_lock = RWLockWrite()
43-
self._async_job_thread = AsyncJobThread()
4443

4544
def cursor(self) -> Cursor:
4645
with self._closing_lock.gen_rlock():
47-
c = super()._cursor(async_job_thread=self._async_job_thread)
46+
c = super()._cursor()
4847
assert isinstance(c, Cursor) # typecheck
4948
return c
5049

5150
@wraps(AsyncBaseConnection._aclose)
5251
def close(self) -> None:
5352
with self._closing_lock.gen_wlock():
54-
async_to_sync(self._aclose, self._async_job_thread)()
53+
async_to_sync(self._aclose)()
5554

5655
# Context manager support
5756
def __enter__(self) -> Connection:

src/firebolt/db/cursor.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
check_not_closed,
1515
check_query_executed,
1616
)
17-
from firebolt.utils.util import AsyncJobThread, async_to_sync
17+
from firebolt.utils.util import async_to_sync
1818

1919

2020
class Cursor(AsyncBaseCursor):
@@ -34,13 +34,11 @@ class Cursor(AsyncBaseCursor):
3434
__slots__ = AsyncBaseCursor.__slots__ + (
3535
"_query_lock",
3636
"_idx_lock",
37-
"_async_job_thread",
3837
)
3938

4039
def __init__(self, *args: Any, **kwargs: Any) -> None:
4140
self._query_lock = RWLockWrite()
4241
self._idx_lock = Lock()
43-
self._async_job_thread: AsyncJobThread = kwargs.pop("async_job_thread")
4442
super().__init__(*args, **kwargs)
4543

4644
@wraps(AsyncBaseCursor.execute)
@@ -52,7 +50,7 @@ def execute(
5250
async_execution: Optional[bool] = False,
5351
) -> Union[int, str]:
5452
with self._query_lock.gen_wlock():
55-
return async_to_sync(super().execute, self._async_job_thread)(
53+
return async_to_sync(super().execute)(
5654
query, parameters, skip_parsing, async_execution
5755
)
5856

@@ -64,7 +62,7 @@ def executemany(
6462
async_execution: Optional[bool] = False,
6563
) -> Union[int, str]:
6664
with self._query_lock.gen_wlock():
67-
return async_to_sync(super().executemany, self._async_job_thread)(
65+
return async_to_sync(super().executemany)(
6866
query, parameters_seq, async_execution
6967
)
7068

@@ -106,9 +104,9 @@ def __iter__(self) -> Generator[List[ColType], None, None]:
106104
@wraps(AsyncBaseCursor.get_status)
107105
def get_status(self, query_id: str) -> QueryStatus:
108106
with self._query_lock.gen_rlock():
109-
return async_to_sync(super().get_status, self._async_job_thread)(query_id)
107+
return async_to_sync(super().get_status)(query_id)
110108

111109
@wraps(AsyncBaseCursor.cancel)
112110
def cancel(self, query_id: str) -> None:
113111
with self._query_lock.gen_rlock():
114-
return async_to_sync(super().cancel, self._async_job_thread)(query_id)
112+
return async_to_sync(super().cancel)(query_id)

src/firebolt/utils/util.py

Lines changed: 5 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,7 @@
1-
from asyncio import (
2-
AbstractEventLoop,
3-
get_event_loop,
4-
new_event_loop,
5-
set_event_loop,
6-
)
7-
from functools import lru_cache, wraps
8-
from threading import Thread
9-
from typing import (
10-
TYPE_CHECKING,
11-
Any,
12-
Callable,
13-
Coroutine,
14-
Optional,
15-
Type,
16-
TypeVar,
17-
)
1+
from functools import lru_cache, partial, wraps
2+
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
183

4+
import trio
195
from httpx import URL
206

217
T = TypeVar("T")
@@ -84,100 +70,19 @@ def fix_url_schema(url: str) -> str:
8470
return url if url.startswith("http") else f"https://{url}"
8571

8672

87-
class AsyncJobThread:
88-
"""Thread runner that allows running async tasks synchronously in a separate thread.
89-
90-
Caches loop to be reused in all threads.
91-
It allows running async functions synchronously inside a running event loop.
92-
Since nesting loops is not allowed, we create a separate thread for a new event loop
93-
94-
Attributes:
95-
result (Any): Value, returned by coroutine execution
96-
exception (Optional[BaseException]): If any, exception that occurred
97-
during coroutine execution
98-
"""
99-
100-
def __init__(self) -> None:
101-
self._loop: Optional[AbstractEventLoop] = None
102-
self.result: Any = None
103-
self.exception: Optional[BaseException] = None
104-
105-
def _initialize_loop(self) -> None:
106-
"""Initialize a loop once to use for later execution.
107-
108-
Tries to get a running loop.
109-
Creates a new loop if no active one, and sets it as active.
110-
"""
111-
if not self._loop:
112-
try:
113-
# despite the docs, this function fails if no loop is set
114-
self._loop = get_event_loop()
115-
except RuntimeError:
116-
self._loop = new_event_loop()
117-
set_event_loop(self._loop)
118-
119-
def _run(self, coro: Coroutine) -> None:
120-
"""Run coroutine in an event loop.
121-
122-
Execution return value is stored into ``result`` field.
123-
If an exception occurs, it will be caught and stored into ``exception`` field.
124-
125-
Args:
126-
coro (Coroutine): Coroutine to execute
127-
"""
128-
try:
129-
self._initialize_loop()
130-
assert self._loop is not None
131-
self.result = self._loop.run_until_complete(coro)
132-
except BaseException as e:
133-
self.exception = e
134-
135-
def execute(self, coro: Coroutine) -> Any:
136-
"""Execute coroutine in a separate thread.
137-
138-
Args:
139-
coro (Coroutine): Coroutine to execute
140-
141-
Returns:
142-
Any: Coroutine execution return value
143-
144-
Raises:
145-
exception: Exeption, occured within coroutine
146-
"""
147-
thread = Thread(target=self._run, args=[coro])
148-
thread.start()
149-
thread.join()
150-
if self.exception:
151-
raise self.exception
152-
return self.result
153-
154-
155-
def async_to_sync(f: Callable, async_job_thread: AsyncJobThread = None) -> Callable:
73+
def async_to_sync(f: Callable) -> Callable:
15674
"""Convert async function to sync.
15775
15876
Args:
15977
f (Callable): function to convert
160-
async_job_thread (AsyncJobThread): Job thread instance to use for async excution
161-
(Default value = None)
16278
16379
Returns:
16480
Callable: regular function, which can be executed synchronously
16581
"""
16682

16783
@wraps(f)
16884
def sync(*args: Any, **kwargs: Any) -> Any:
169-
try:
170-
loop = get_event_loop()
171-
except RuntimeError:
172-
loop = new_event_loop()
173-
set_event_loop(loop)
174-
# We are inside a running loop
175-
if loop.is_running():
176-
nonlocal async_job_thread
177-
if not async_job_thread:
178-
async_job_thread = AsyncJobThread()
179-
return async_job_thread.execute(f(*args, **kwargs))
180-
return loop.run_until_complete(f(*args, **kwargs))
85+
return trio.run(partial(f, *args, **kwargs))
18186

18287
return sync
18388

tests/integration/dbapi/sync/test_queries.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,3 +464,47 @@ async def test_server_side_async_execution_get_status(
464464
# assert (
465465
# type(status) is QueryStatus,
466466
# ), "get_status() did not return a QueryStatus object."
467+
468+
469+
def test_multi_thread_connection_sharing(
470+
engine_url: str,
471+
database_name: str,
472+
password_auth: Auth,
473+
account_name: str,
474+
api_endpoint: str,
475+
) -> None:
476+
"""
477+
Test to verify sharing the same connection between different
478+
threads works. With asyncio synching an async function this used
479+
to fail due to a different loop having exclusive rights to the
480+
Httpx client. Trio fixes this issue.
481+
"""
482+
483+
exceptions = []
484+
485+
connection = connect(
486+
auth=password_auth,
487+
database=database_name,
488+
account_name=account_name,
489+
engine_url=engine_url,
490+
api_endpoint=api_endpoint,
491+
)
492+
493+
def run_query():
494+
try:
495+
cursor = connection.cursor()
496+
cursor.execute("select 1")
497+
cursor.fetchall()
498+
except BaseException as e:
499+
exceptions.append(e)
500+
501+
thread_1 = Thread(target=run_query)
502+
thread_2 = Thread(target=run_query)
503+
504+
thread_1.start()
505+
thread_1.join()
506+
thread_2.start()
507+
thread_2.join()
508+
509+
connection.close()
510+
assert not exceptions

0 commit comments

Comments
 (0)