|
2 | 2 |
|
3 | 3 | import logging |
4 | 4 | import time |
| 5 | +import warnings |
5 | 6 | from abc import ABCMeta, abstractmethod |
6 | 7 | from functools import wraps |
7 | 8 | from types import TracebackType |
8 | | -from typing import ( |
9 | | - TYPE_CHECKING, |
10 | | - Any, |
11 | | - Dict, |
12 | | - Iterator, |
13 | | - List, |
14 | | - Optional, |
15 | | - Sequence, |
16 | | - Union, |
17 | | -) |
| 9 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union |
18 | 10 | from urllib.parse import urljoin |
19 | 11 |
|
20 | 12 | from httpx import ( |
|
28 | 20 |
|
29 | 21 | from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2 |
30 | 22 | from firebolt.common._types import ColType, ParameterType, SetParameter |
31 | | -from firebolt.common.base_cursor import ( |
| 23 | +from firebolt.common.constants import ( |
32 | 24 | JSON_OUTPUT_FORMAT, |
33 | 25 | RESET_SESSION_HEADER, |
34 | 26 | UPDATE_ENDPOINT_HEADER, |
35 | 27 | UPDATE_PARAMETERS_HEADER, |
36 | | - BaseCursor, |
37 | 28 | CursorState, |
| 29 | +) |
| 30 | +from firebolt.common.cursor.base_cursor import ( |
| 31 | + BaseCursor, |
38 | 32 | _parse_update_endpoint, |
39 | 33 | _parse_update_parameters, |
40 | 34 | _raise_if_internal_set_parameter, |
| 35 | +) |
| 36 | +from firebolt.common.cursor.decorators import ( |
41 | 37 | async_not_allowed, |
42 | 38 | check_not_closed, |
43 | 39 | check_query_executed, |
44 | 40 | ) |
| 41 | +from firebolt.common.row_set.asynchronous.base import BaseAsyncRowSet |
| 42 | +from firebolt.common.row_set.asynchronous.in_memory import InMemoryAsyncRowSet |
45 | 43 | from firebolt.common.statement_formatter import create_statement_formatter |
46 | 44 | from firebolt.utils.exception import ( |
47 | 45 | EngineNotRunningError, |
|
58 | 56 | if TYPE_CHECKING: |
59 | 57 | from firebolt.async_db.connection import Connection |
60 | 58 |
|
61 | | -from firebolt.utils.util import _print_error_body, raise_errors_from_body |
| 59 | +from firebolt.utils.async_util import async_islice |
| 60 | +from firebolt.utils.util import ( |
| 61 | + Timer, |
| 62 | + _print_error_body, |
| 63 | + raise_errors_from_body, |
| 64 | +) |
62 | 65 |
|
63 | 66 | logger = logging.getLogger(__name__) |
64 | 67 |
|
@@ -88,6 +91,7 @@ def __init__( |
88 | 91 | self._client = client |
89 | 92 | self.connection = connection |
90 | 93 | self.engine_url = connection.engine_url |
| 94 | + self._row_set: Optional[BaseAsyncRowSet] = None |
91 | 95 | if connection.init_parameters: |
92 | 96 | self._update_set_parameters(connection.init_parameters) |
93 | 97 |
|
@@ -121,13 +125,14 @@ async def _api_request( |
121 | 125 | if self.parameters: |
122 | 126 | parameters = {**self.parameters, **parameters} |
123 | 127 | try: |
124 | | - return await self._client.request( |
| 128 | + req = self._client.build_request( |
125 | 129 | url=urljoin(self.engine_url.rstrip("/") + "/", path or ""), |
126 | 130 | method="POST", |
127 | 131 | params=parameters, |
128 | 132 | content=query, |
129 | 133 | timeout=timeout if timeout is not None else USE_CLIENT_DEFAULT, |
130 | 134 | ) |
| 135 | + return await self._client.send(req, stream=True) |
131 | 136 | except TimeoutException: |
132 | 137 | raise QueryTimeoutError() |
133 | 138 |
|
@@ -170,6 +175,9 @@ async def _validate_set_parameter( |
170 | 175 | # set parameter passed validation |
171 | 176 | self._set_parameters[parameter.name] = parameter.value |
172 | 177 |
|
| 178 | + # append empty result set |
| 179 | + await self._append_row_set_from_response(None) |
| 180 | + |
173 | 181 | async def _parse_response_headers(self, headers: Headers) -> None: |
174 | 182 | if headers.get(UPDATE_ENDPOINT_HEADER): |
175 | 183 | endpoint, params = _parse_update_endpoint( |
@@ -271,8 +279,7 @@ async def _handle_query_execution( |
271 | 279 | self._parse_async_response(resp) |
272 | 280 | else: |
273 | 281 | await self._parse_response_headers(resp.headers) |
274 | | - row_set = self._row_set_from_response(resp) |
275 | | - self._append_row_set(row_set) |
| 282 | + await self._append_row_set_from_response(resp) |
276 | 283 |
|
277 | 284 | @check_not_closed |
278 | 285 | async def execute( |
@@ -353,75 +360,113 @@ async def executemany( |
353 | 360 | await self._do_execute(query, parameters_seq, timeout=timeout_seconds) |
354 | 361 | return self.rowcount |
355 | 362 |
|
356 | | - @abstractmethod |
357 | | - async def is_db_available(self, database: str) -> bool: |
358 | | - """Verify that the database exists.""" |
359 | | - ... |
| 363 | + async def _append_row_set_from_response( |
| 364 | + self, |
| 365 | + response: Optional[Response], |
| 366 | + ) -> None: |
| 367 | + """Store information about executed query.""" |
| 368 | + if self._row_set is None: |
| 369 | + self._row_set = InMemoryAsyncRowSet() |
| 370 | + if response is None: |
| 371 | + self._row_set.append_empty_response() |
| 372 | + else: |
| 373 | + await self._row_set.append_response(response) |
360 | 374 |
|
361 | | - @abstractmethod |
362 | | - async def is_engine_running(self, engine_url: str) -> bool: |
363 | | - """Verify that the engine is running.""" |
364 | | - ... |
| 375 | + _performance_log_message = ( |
| 376 | + "[PERFORMANCE] Parsing query output into native Python types" |
| 377 | + ) |
365 | 378 |
|
366 | | - @wraps(BaseCursor.fetchone) |
| 379 | + @check_not_closed |
| 380 | + @async_not_allowed |
| 381 | + @check_query_executed |
367 | 382 | async def fetchone(self) -> Optional[List[ColType]]: |
368 | 383 | """Fetch the next row of a query result set.""" |
369 | | - return super().fetchone() |
| 384 | + assert self._row_set is not None |
| 385 | + with Timer(self._performance_log_message): |
| 386 | + # anext() is only supported in Python 3.10+ |
| 387 | + # this means we cannot just do return anext(self._row_set), |
| 388 | + # we need to handle iteration manually |
| 389 | + try: |
| 390 | + return await self._row_set.__anext__() |
| 391 | + except StopAsyncIteration: |
| 392 | + return None |
370 | 393 |
|
371 | | - @wraps(BaseCursor.fetchmany) |
| 394 | + @check_not_closed |
| 395 | + @async_not_allowed |
| 396 | + @check_query_executed |
372 | 397 | async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: |
373 | 398 | """ |
374 | 399 | Fetch the next set of rows of a query result; |
375 | | - size is cursor.arraysize by default. |
| 400 | + cursor.arraysize is default size. |
376 | 401 | """ |
377 | | - return super().fetchmany(size) |
| 402 | + assert self._row_set is not None |
| 403 | + size = size if size is not None else self.arraysize |
| 404 | + with Timer(self._performance_log_message): |
| 405 | + return await async_islice(self._row_set, size) |
378 | 406 |
|
379 | | - @wraps(BaseCursor.fetchall) |
| 407 | + @check_not_closed |
| 408 | + @async_not_allowed |
| 409 | + @check_query_executed |
380 | 410 | async def fetchall(self) -> List[List[ColType]]: |
381 | 411 | """Fetch all remaining rows of a query result.""" |
382 | | - return super().fetchall() |
| 412 | + assert self._row_set is not None |
| 413 | + with Timer(self._performance_log_message): |
| 414 | + return [it async for it in self._row_set] |
383 | 415 |
|
384 | 416 | @wraps(BaseCursor.nextset) |
385 | 417 | async def nextset(self) -> None: |
386 | 418 | return super().nextset() |
387 | 419 |
|
| 420 | + async def aclose(self) -> None: |
| 421 | + super().close() |
| 422 | + if self._row_set is not None: |
| 423 | + await self._row_set.aclose() |
| 424 | + |
| 425 | + @abstractmethod |
| 426 | + async def is_db_available(self, database: str) -> bool: |
| 427 | + """Verify that the database exists.""" |
| 428 | + ... |
| 429 | + |
| 430 | + @abstractmethod |
| 431 | + async def is_engine_running(self, engine_url: str) -> bool: |
| 432 | + """Verify that the engine is running.""" |
| 433 | + ... |
| 434 | + |
388 | 435 | # Iteration support |
389 | 436 | @check_not_closed |
390 | 437 | @async_not_allowed |
391 | 438 | @check_query_executed |
392 | 439 | def __aiter__(self) -> Cursor: |
393 | 440 | return self |
394 | 441 |
|
395 | | - # TODO: figure out how to implement __aenter__ and __await__ |
396 | 442 | @check_not_closed |
397 | | - def __aenter__(self) -> Cursor: |
398 | | - return self |
399 | | - |
400 | | - @check_not_closed |
401 | | - def __enter__(self) -> Cursor: |
| 443 | + async def __aenter__(self) -> Cursor: |
402 | 444 | return self |
403 | 445 |
|
404 | | - def __exit__( |
405 | | - self, exc_type: type, exc_val: Exception, exc_tb: TracebackType |
406 | | - ) -> None: |
407 | | - self.close() |
408 | | - |
409 | | - def __await__(self) -> Iterator: |
410 | | - yield None |
411 | | - |
412 | 446 | async def __aexit__( |
413 | 447 | self, exc_type: type, exc_val: Exception, exc_tb: TracebackType |
414 | 448 | ) -> None: |
415 | | - self.close() |
| 449 | + await self.aclose() |
416 | 450 |
|
417 | 451 | @check_not_closed |
418 | 452 | @async_not_allowed |
419 | 453 | @check_query_executed |
420 | 454 | async def __anext__(self) -> List[ColType]: |
421 | | - row = await self.fetchone() |
422 | | - if row is None: |
423 | | - raise StopAsyncIteration |
424 | | - return row |
| 455 | + assert self._row_set is not None |
| 456 | + return await self._row_set.__anext__() |
| 457 | + |
| 458 | + @check_not_closed |
| 459 | + def __enter__(self) -> Cursor: |
| 460 | + warnings.warn( |
| 461 | + "Using __enter__ is deprecated, use 'async with' instead", |
| 462 | + DeprecationWarning, |
| 463 | + ) |
| 464 | + return self |
| 465 | + |
| 466 | + def __exit__( |
| 467 | + self, exc_type: type, exc_val: Exception, exc_tb: TracebackType |
| 468 | + ) -> None: |
| 469 | + return None |
425 | 470 |
|
426 | 471 |
|
427 | 472 | class CursorV2(Cursor): |
|
0 commit comments