Skip to content

Commit 1ade712

Browse files
fix: Jupyter runtime (#120)
* fix jupyter notebook * extend dbapi examples * address comments * clean outputs * extend comments * extend comments 2
1 parent 6d635dd commit 1ade712

File tree

6 files changed

+224
-22
lines changed

6 files changed

+224
-22
lines changed

examples/dbapi.ipynb

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
"outputs": [],
1717
"source": [
1818
"from firebolt.db import connect\n",
19-
"from firebolt.client import DEFAULT_API_URL"
19+
"from firebolt.client import DEFAULT_API_URL\n",
20+
"from datetime import datetime"
2021
]
2122
},
2223
{
@@ -36,7 +37,6 @@
3637
"source": [
3738
"# Only one of these two parameters should be specified\n",
3839
"engine_url = \"\"\n",
39-
"engine_name = \"\"\n",
4040
"assert bool(engine_url) != bool(\n",
4141
" engine_name\n",
4242
"), \"Specify only one of engine_name and engine_url\"\n",
@@ -98,8 +98,32 @@
9898
" \"insert into test_table values (1, 'hello', '2021-01-01 01:01:01'),\"\n",
9999
" \"(2, 'world', '2022-02-02 02:02:02'),\"\n",
100100
" \"(3, '!', '2023-03-03 03:03:03')\"\n",
101+
")"
102+
]
103+
},
104+
{
105+
"cell_type": "markdown",
106+
"id": "b356295a",
107+
"metadata": {},
108+
"source": [
109+
"### Parameterized query"
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": null,
115+
"id": "929f5221",
116+
"metadata": {},
117+
"outputs": [],
118+
"source": [
119+
"cursor.execute(\n",
120+
" \"insert into test_table values (?, ?, ?)\",\n",
121+
" (3, \"single parameter set\", datetime.now()),\n",
101122
")\n",
102-
"cursor.execute(\"select * from test_table\")"
123+
"cursor.executemany(\n",
124+
" \"insert into test_table values (?, ?, ?)\",\n",
125+
" ((4, \"multiple\", datetime.now()), (5, \"parameter sets\", datetime.fromtimestamp(0))),\n",
126+
")"
103127
]
104128
},
105129
{
@@ -117,6 +141,7 @@
117141
"metadata": {},
118142
"outputs": [],
119143
"source": [
144+
"cursor.execute(\"select * from test_table\")\n",
120145
"print(\"Description: \", cursor.description)\n",
121146
"print(\"Rowcount: \", cursor.rowcount)"
122147
]
@@ -141,6 +166,67 @@
141166
"print(cursor.fetchall())"
142167
]
143168
},
169+
{
170+
"cell_type": "markdown",
171+
"id": "efc4ff0a",
172+
"metadata": {},
173+
"source": [
174+
"## Multi-statement queries"
175+
]
176+
},
177+
{
178+
"cell_type": "code",
179+
"execution_count": null,
180+
"id": "744817b1",
181+
"metadata": {},
182+
"outputs": [],
183+
"source": [
184+
"cursor.execute(\n",
185+
" \"\"\"\n",
186+
" select * from test_table where id < 4;\n",
187+
" select * from test_table where id > 2;\n",
188+
"\"\"\"\n",
189+
")\n",
190+
"print(cursor._row_sets[0][2])\n",
191+
"print(cursor._row_sets[1][2])\n",
192+
"print(cursor._rows)\n",
193+
"# print(\"First query: \", cursor.fetchall())\n",
194+
"assert cursor.nextset()\n",
195+
"print(cursor._rows)\n",
196+
"# print(\"Secont query: \", cursor.fetchall())\n",
197+
"assert cursor.nextset() is None"
198+
]
199+
},
200+
{
201+
"cell_type": "markdown",
202+
"id": "02e5db2f",
203+
"metadata": {},
204+
"source": [
205+
"### Error handling\n",
206+
"If one query fails during the execution, all remaining queries are canceled.\n",
207+
"However, you still can fetch results for successful queries"
208+
]
209+
},
210+
{
211+
"cell_type": "code",
212+
"execution_count": null,
213+
"id": "888500a9",
214+
"metadata": {},
215+
"outputs": [],
216+
"source": [
217+
"try:\n",
218+
" cursor.execute(\n",
219+
" \"\"\"\n",
220+
" select * from test_table where id < 4;\n",
221+
" select * from test_table where wrong_field > 2;\n",
222+
" select * from test_table\n",
223+
" \"\"\"\n",
224+
" )\n",
225+
"except:\n",
226+
" pass\n",
227+
"cursor.fetchall()"
228+
]
229+
},
144230
{
145231
"cell_type": "markdown",
146232
"id": "b1cd4ff2",
@@ -286,6 +372,38 @@
286372
"source": [
287373
"await print_results(async_cursor)"
288374
]
375+
},
376+
{
377+
"cell_type": "markdown",
378+
"id": "da36dd3f",
379+
"metadata": {},
380+
"source": [
381+
"### Closing connection"
382+
]
383+
},
384+
{
385+
"cell_type": "code",
386+
"execution_count": null,
387+
"id": "83fc1686",
388+
"metadata": {},
389+
"outputs": [],
390+
"source": [
391+
"# manually\n",
392+
"connection.close()\n",
393+
"\n",
394+
"# using context manager\n",
395+
"with connect(\n",
396+
" engine_url=engine_url,\n",
397+
" engine_name=engine_name,\n",
398+
" database=database_name,\n",
399+
" username=username,\n",
400+
" password=password,\n",
401+
" api_endpoint=api_endpoint,\n",
402+
") as conn:\n",
403+
" # create cursors, perform database queries\n",
404+
" pass\n",
405+
"conn.closed"
406+
]
289407
}
290408
],
291409
"metadata": {

src/firebolt/async_db/connection.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import socket
44
from json import JSONDecodeError
55
from types import TracebackType
6-
from typing import Callable, List, Optional, Type
6+
from typing import Any, Callable, List, Optional, Type
77

88
from httpcore.backends.auto import AutoBackend
99
from httpcore.backends.base import AsyncNetworkStream
@@ -207,15 +207,15 @@ def __init__(
207207
self._cursors: List[BaseCursor] = []
208208
self._is_closed = False
209209

210-
def cursor(self) -> BaseCursor:
210+
def _cursor(self, **kwargs: Any) -> BaseCursor:
211211
"""
212212
Create new cursor object.
213213
"""
214214

215215
if self.closed:
216216
raise ConnectionClosedError("Unable to create cursor: connection closed")
217217

218-
c = self.cursor_class(self._client, self)
218+
c = self.cursor_class(self._client, self, **kwargs)
219219
self._cursors.append(c)
220220
return c
221221

@@ -279,7 +279,7 @@ class Connection(BaseConnection):
279279
aclose = BaseConnection._aclose
280280

281281
def cursor(self) -> Cursor:
282-
c = super().cursor()
282+
c = super()._cursor()
283283
assert isinstance(c, Cursor) # typecheck
284284
return c
285285

src/firebolt/common/util.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
1-
from asyncio import get_event_loop, new_event_loop, set_event_loop
1+
from asyncio import (
2+
AbstractEventLoop,
3+
get_event_loop,
4+
new_event_loop,
5+
set_event_loop,
6+
)
27
from functools import lru_cache, wraps
3-
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar
8+
from threading import Thread
9+
from typing import (
10+
TYPE_CHECKING,
11+
Any,
12+
Callable,
13+
Coroutine,
14+
Optional,
15+
Type,
16+
TypeVar,
17+
)
418

519
T = TypeVar("T")
620

@@ -37,15 +51,59 @@ def fix_url_schema(url: str) -> str:
3751
return url if url.startswith("http") else f"https://{url}"
3852

3953

40-
def async_to_sync(f: Callable) -> Callable:
54+
class AsyncJobThread:
55+
"""
56+
Thread runner that allows running async tasks syncronously in a separate thread.
57+
Caches loop to be reused in all threads
58+
It allows running async functions syncronously inside a running event loop.
59+
Since nesting loops is not allowed, we create a separate thread for a new event loop
60+
"""
61+
62+
def __init__(self) -> None:
63+
self.loop: Optional[AbstractEventLoop] = None
64+
self.result: Optional[Any] = None
65+
self.exception: Optional[BaseException] = None
66+
67+
def _initialize_loop(self) -> None:
68+
if not self.loop:
69+
try:
70+
# despite the docs, this function fails if no loop is set
71+
self.loop = get_event_loop()
72+
except RuntimeError:
73+
self.loop = new_event_loop()
74+
set_event_loop(self.loop)
75+
76+
def run(self, coro: Coroutine) -> None:
77+
try:
78+
self._initialize_loop()
79+
assert self.loop is not None
80+
self.result = self.loop.run_until_complete(coro)
81+
except BaseException as e:
82+
self.exception = e
83+
84+
def execute(self, coro: Coroutine) -> Any:
85+
thread = Thread(target=self.run, args=[coro])
86+
thread.start()
87+
thread.join()
88+
if self.exception:
89+
raise self.exception
90+
return self.result
91+
92+
93+
def async_to_sync(f: Callable, async_job_thread: AsyncJobThread = None) -> Callable:
4194
@wraps(f)
4295
def sync(*args: Any, **kwargs: Any) -> Any:
4396
try:
4497
loop = get_event_loop()
4598
except RuntimeError:
4699
loop = new_event_loop()
47100
set_event_loop(loop)
48-
res = loop.run_until_complete(f(*args, **kwargs))
49-
return res
101+
# We are inside a running loop
102+
if loop.is_running():
103+
nonlocal async_job_thread
104+
if not async_job_thread:
105+
async_job_thread = AsyncJobThread()
106+
return async_job_thread.execute(f(*args, **kwargs))
107+
return loop.run_until_complete(f(*args, **kwargs))
50108

51109
return sync

src/firebolt/db/connection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from firebolt.async_db.connection import BaseConnection as AsyncBaseConnection
1111
from firebolt.async_db.connection import async_connect_factory
1212
from firebolt.common.exception import ConnectionClosedError
13-
from firebolt.common.util import async_to_sync
13+
from firebolt.common.util import AsyncJobThread, async_to_sync
1414
from firebolt.db.cursor import Cursor
1515

1616
DEFAULT_TIMEOUT_SECONDS: int = 5
@@ -33,7 +33,7 @@ class Connection(AsyncBaseConnection):
3333
are not implemented.
3434
"""
3535

36-
__slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock",)
36+
__slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock", "_async_job_thread")
3737

3838
cursor_class = Cursor
3939

@@ -42,18 +42,18 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
4242
# Holding this lock for write means that connection is closing itself.
4343
# cursor() should hold this lock for read to read/write state
4444
self._closing_lock = RWLockWrite()
45+
self._async_job_thread = AsyncJobThread()
4546

46-
@wraps(AsyncBaseConnection.cursor)
4747
def cursor(self) -> Cursor:
4848
with self._closing_lock.gen_rlock():
49-
c = super().cursor()
49+
c = super()._cursor(async_job_thread=self._async_job_thread)
5050
assert isinstance(c, Cursor) # typecheck
5151
return c
5252

5353
@wraps(AsyncBaseConnection._aclose)
5454
def close(self) -> None:
5555
with self._closing_lock.gen_wlock():
56-
async_to_sync(self._aclose)()
56+
async_to_sync(self._aclose, self._async_job_thread)()
5757

5858
# Context manager support
5959
def __enter__(self) -> Connection:

src/firebolt/db/cursor.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
check_not_closed,
1414
check_query_executed,
1515
)
16-
from firebolt.common.util import async_to_sync
16+
from firebolt.common.util import AsyncJobThread, async_to_sync
1717

1818

1919
class Cursor(AsyncBaseCursor):
@@ -31,11 +31,16 @@ class Cursor(AsyncBaseCursor):
3131
with :py:func:`fetchmany` method
3232
"""
3333

34-
__slots__ = AsyncBaseCursor.__slots__ + ("_query_lock", "_idx_lock")
34+
__slots__ = AsyncBaseCursor.__slots__ + (
35+
"_query_lock",
36+
"_idx_lock",
37+
"_async_job_thread",
38+
)
3539

3640
def __init__(self, *args: Any, **kwargs: Any) -> None:
3741
self._query_lock = RWLockWrite()
3842
self._idx_lock = Lock()
43+
self._async_job_thread: AsyncJobThread = kwargs.pop("async_job_thread")
3944
super().__init__(*args, **kwargs)
4045

4146
@wraps(AsyncBaseCursor.execute)
@@ -46,14 +51,18 @@ def execute(
4651
set_parameters: Optional[Dict] = None,
4752
) -> int:
4853
with self._query_lock.gen_wlock():
49-
return async_to_sync(super().execute)(query, parameters, set_parameters)
54+
return async_to_sync(super().execute, self._async_job_thread)(
55+
query, parameters, set_parameters
56+
)
5057

5158
@wraps(AsyncBaseCursor.executemany)
5259
def executemany(
5360
self, query: str, parameters_seq: Sequence[Sequence[ParameterType]]
5461
) -> int:
5562
with self._query_lock.gen_wlock():
56-
return async_to_sync(super().executemany)(query, parameters_seq)
63+
return async_to_sync(super().executemany, self._async_job_thread)(
64+
query, parameters_seq
65+
)
5766

5867
@wraps(AsyncBaseCursor._get_next_range)
5968
def _get_next_range(self, size: int) -> Tuple[int, int]:

0 commit comments

Comments
 (0)