Skip to content

Commit 28a90ac

Browse files
refactor: FIR-43722 refactor our result set logic from cursors (#423)
Co-authored-by: Petro Tiurin <[email protected]>
1 parent ed91701 commit 28a90ac

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed: +1062 -699 lines changed

src/firebolt/async_db/connection.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ async def aclose(self) -> None:
179179
for c in cursors:
180180
# Here c can already be closed by another thread,
181181
# but it shouldn't raise an error in this case
182-
c.close()
182+
await c.aclose()
183183
await self._client.aclose()
184184
self._is_closed = True
185185

src/firebolt/async_db/cursor.py

Lines changed: 95 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -2,19 +2,11 @@
22

33
import logging
44
import time
5+
import warnings
56
from abc import ABCMeta, abstractmethod
67
from functools import wraps
78
from types import TracebackType
8-
from typing import (
9-
TYPE_CHECKING,
10-
Any,
11-
Dict,
12-
Iterator,
13-
List,
14-
Optional,
15-
Sequence,
16-
Union,
17-
)
9+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
1810
from urllib.parse import urljoin
1911

2012
from httpx import (
@@ -28,20 +20,26 @@
2820

2921
from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
3022
from firebolt.common._types import ColType, ParameterType, SetParameter
31-
from firebolt.common.base_cursor import (
23+
from firebolt.common.constants import (
3224
JSON_OUTPUT_FORMAT,
3325
RESET_SESSION_HEADER,
3426
UPDATE_ENDPOINT_HEADER,
3527
UPDATE_PARAMETERS_HEADER,
36-
BaseCursor,
3728
CursorState,
29+
)
30+
from firebolt.common.cursor.base_cursor import (
31+
BaseCursor,
3832
_parse_update_endpoint,
3933
_parse_update_parameters,
4034
_raise_if_internal_set_parameter,
35+
)
36+
from firebolt.common.cursor.decorators import (
4137
async_not_allowed,
4238
check_not_closed,
4339
check_query_executed,
4440
)
41+
from firebolt.common.row_set.asynchronous.base import BaseAsyncRowSet
42+
from firebolt.common.row_set.asynchronous.in_memory import InMemoryAsyncRowSet
4543
from firebolt.common.statement_formatter import create_statement_formatter
4644
from firebolt.utils.exception import (
4745
EngineNotRunningError,
@@ -58,7 +56,12 @@
5856
if TYPE_CHECKING:
5957
from firebolt.async_db.connection import Connection
6058

61-
from firebolt.utils.util import _print_error_body, raise_errors_from_body
59+
from firebolt.utils.async_util import async_islice
60+
from firebolt.utils.util import (
61+
Timer,
62+
_print_error_body,
63+
raise_errors_from_body,
64+
)
6265

6366
logger = logging.getLogger(__name__)
6467

@@ -88,6 +91,7 @@ def __init__(
8891
self._client = client
8992
self.connection = connection
9093
self.engine_url = connection.engine_url
94+
self._row_set: Optional[BaseAsyncRowSet] = None
9195
if connection.init_parameters:
9296
self._update_set_parameters(connection.init_parameters)
9397

@@ -121,13 +125,14 @@ async def _api_request(
121125
if self.parameters:
122126
parameters = {**self.parameters, **parameters}
123127
try:
124-
return await self._client.request(
128+
req = self._client.build_request(
125129
url=urljoin(self.engine_url.rstrip("/") + "/", path or ""),
126130
method="POST",
127131
params=parameters,
128132
content=query,
129133
timeout=timeout if timeout is not None else USE_CLIENT_DEFAULT,
130134
)
135+
return await self._client.send(req, stream=True)
131136
except TimeoutException:
132137
raise QueryTimeoutError()
133138

@@ -170,6 +175,9 @@ async def _validate_set_parameter(
170175
# set parameter passed validation
171176
self._set_parameters[parameter.name] = parameter.value
172177

178+
# append empty result set
179+
await self._append_row_set_from_response(None)
180+
173181
async def _parse_response_headers(self, headers: Headers) -> None:
174182
if headers.get(UPDATE_ENDPOINT_HEADER):
175183
endpoint, params = _parse_update_endpoint(
@@ -271,8 +279,7 @@ async def _handle_query_execution(
271279
self._parse_async_response(resp)
272280
else:
273281
await self._parse_response_headers(resp.headers)
274-
row_set = self._row_set_from_response(resp)
275-
self._append_row_set(row_set)
282+
await self._append_row_set_from_response(resp)
276283

277284
@check_not_closed
278285
async def execute(
@@ -353,75 +360,113 @@ async def executemany(
353360
await self._do_execute(query, parameters_seq, timeout=timeout_seconds)
354361
return self.rowcount
355362

356-
@abstractmethod
357-
async def is_db_available(self, database: str) -> bool:
358-
"""Verify that the database exists."""
359-
...
363+
async def _append_row_set_from_response(
364+
self,
365+
response: Optional[Response],
366+
) -> None:
367+
"""Store information about executed query."""
368+
if self._row_set is None:
369+
self._row_set = InMemoryAsyncRowSet()
370+
if response is None:
371+
self._row_set.append_empty_response()
372+
else:
373+
await self._row_set.append_response(response)
360374

361-
@abstractmethod
362-
async def is_engine_running(self, engine_url: str) -> bool:
363-
"""Verify that the engine is running."""
364-
...
375+
_performance_log_message = (
376+
"[PERFORMANCE] Parsing query output into native Python types"
377+
)
365378

366-
@wraps(BaseCursor.fetchone)
379+
@check_not_closed
380+
@async_not_allowed
381+
@check_query_executed
367382
async def fetchone(self) -> Optional[List[ColType]]:
368383
"""Fetch the next row of a query result set."""
369-
return super().fetchone()
384+
assert self._row_set is not None
385+
with Timer(self._performance_log_message):
386+
# anext() is only supported in Python 3.10+
387+
# this means we cannot just do return anext(self._row_set),
388+
# we need to handle iteration manually
389+
try:
390+
return await self._row_set.__anext__()
391+
except StopAsyncIteration:
392+
return None
370393

371-
@wraps(BaseCursor.fetchmany)
394+
@check_not_closed
395+
@async_not_allowed
396+
@check_query_executed
372397
async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]:
373398
"""
374399
Fetch the next set of rows of a query result;
375-
size is cursor.arraysize by default.
400+
cursor.arraysize is default size.
376401
"""
377-
return super().fetchmany(size)
402+
assert self._row_set is not None
403+
size = size if size is not None else self.arraysize
404+
with Timer(self._performance_log_message):
405+
return await async_islice(self._row_set, size)
378406

379-
@wraps(BaseCursor.fetchall)
407+
@check_not_closed
408+
@async_not_allowed
409+
@check_query_executed
380410
async def fetchall(self) -> List[List[ColType]]:
381411
"""Fetch all remaining rows of a query result."""
382-
return super().fetchall()
412+
assert self._row_set is not None
413+
with Timer(self._performance_log_message):
414+
return [it async for it in self._row_set]
383415

384416
@wraps(BaseCursor.nextset)
385417
async def nextset(self) -> None:
386418
return super().nextset()
387419

420+
async def aclose(self) -> None:
421+
super().close()
422+
if self._row_set is not None:
423+
await self._row_set.aclose()
424+
425+
@abstractmethod
426+
async def is_db_available(self, database: str) -> bool:
427+
"""Verify that the database exists."""
428+
...
429+
430+
@abstractmethod
431+
async def is_engine_running(self, engine_url: str) -> bool:
432+
"""Verify that the engine is running."""
433+
...
434+
388435
# Iteration support
389436
@check_not_closed
390437
@async_not_allowed
391438
@check_query_executed
392439
def __aiter__(self) -> Cursor:
393440
return self
394441

395-
# TODO: figure out how to implement __aenter__ and __await__
396442
@check_not_closed
397-
def __aenter__(self) -> Cursor:
398-
return self
399-
400-
@check_not_closed
401-
def __enter__(self) -> Cursor:
443+
async def __aenter__(self) -> Cursor:
402444
return self
403445

404-
def __exit__(
405-
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
406-
) -> None:
407-
self.close()
408-
409-
def __await__(self) -> Iterator:
410-
yield None
411-
412446
async def __aexit__(
413447
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
414448
) -> None:
415-
self.close()
449+
await self.aclose()
416450

417451
@check_not_closed
418452
@async_not_allowed
419453
@check_query_executed
420454
async def __anext__(self) -> List[ColType]:
421-
row = await self.fetchone()
422-
if row is None:
423-
raise StopAsyncIteration
424-
return row
455+
assert self._row_set is not None
456+
return await self._row_set.__anext__()
457+
458+
@check_not_closed
459+
def __enter__(self) -> Cursor:
460+
warnings.warn(
461+
"Using __enter__ is deprecated, use 'async with' instead",
462+
DeprecationWarning,
463+
)
464+
return self
465+
466+
def __exit__(
467+
self, exc_type: type, exc_val: Exception, exc_tb: TracebackType
468+
) -> None:
469+
return None
425470

426471

427472
class CursorV2(Cursor):

src/firebolt/common/_types.py

Lines changed: 1 addition & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -77,19 +77,6 @@ def Binary(value: str) -> bytes: # NOSONAR
7777
DATETIME = datetime
7878
ROWID = int
7979

80-
Column = namedtuple(
81-
"Column",
82-
(
83-
"name",
84-
"type_code",
85-
"display_size",
86-
"internal_size",
87-
"precision",
88-
"scale",
89-
"null_ok",
90-
),
91-
)
92-
9380

9481
class ExtendedType:
9582
"""Base type for all extended types in Firebolt (array, decimal, struct, etc.)."""
@@ -338,7 +325,7 @@ def parse_value(
338325
raise DataError(f"Invalid bytea value {value}: str expected")
339326
return _parse_bytea(value)
340327
if isinstance(ctype, DECIMAL):
341-
if not isinstance(value, (str, int)):
328+
if not isinstance(value, (str, int, float)):
342329
raise DataError(f"Invalid decimal value {value}: str or int expected")
343330
return Decimal(value)
344331
if isinstance(ctype, ARRAY):

0 commit comments

Comments
 (0)