Skip to content

Commit 29e8939

Browse files
feat: Fir 8425 implement multi statement funct (#110)
* implement split_format_sql * add split_format_sql unit tests * add cursor error state * enable multi-statement queries * fix unit tests * add nextset unit tests * add integration tests * fix mypy issues, improve nextset * extended tests for nextset * address comments * resolve new comments * remove code smell
1 parent 21d5542 commit 29e8939

File tree

7 files changed

+388
-107
lines changed

7 files changed

+388
-107
lines changed

src/firebolt/async_db/_types.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from collections import namedtuple
44
from datetime import date, datetime, timezone
55
from enum import Enum
6-
from typing import Sequence, Union
6+
from typing import List, Sequence, Union
77

88
from sqlparse import parse as parse_sql # type: ignore
9-
from sqlparse.sql import Token, TokenList # type: ignore
9+
from sqlparse.sql import Statement, Token, TokenList # type: ignore
1010
from sqlparse.tokens import Token as TokenType # type: ignore
1111

1212
try:
@@ -224,10 +224,9 @@ def format_value(value: ParameterType) -> str:
224224
raise DataError(f"unsupported parameter type {type(value)}")
225225

226226

227-
def format_sql(query: str, parameters: Sequence[ParameterType]) -> str:
227+
def format_statement(statement: Statement, parameters: Sequence[ParameterType]) -> str:
228228
"""
229-
Substitute placeholders in queries with provided values.
230-
'?' symbol is used as a placeholder. Using '\\?' would result in a plain '?'
229+
Substitute placeholders in a sqlparse statement with provided values.
231230
"""
232231
idx = 0
233232

@@ -245,16 +244,11 @@ def process_token(token: Token) -> Token:
245244
return Token(TokenType.Text, formatted)
246245
if isinstance(token, TokenList):
247246
# Process all children tokens
248-
token.tokens = [process_token(t) for t in token.tokens]
249-
return token
250247

251-
parsed = parse_sql(query)
252-
if not parsed:
253-
return query
254-
if len(parsed) > 1:
255-
raise NotSupportedError("Multi-statement queries are not supported")
248+
return TokenList([process_token(t) for t in token.tokens])
249+
return token
256250

257-
formatted_sql = str(process_token(parsed[0]))
251+
formatted_sql = str(process_token(statement)).rstrip(";")
258252

259253
if idx < len(parameters):
260254
raise DataError(
@@ -263,3 +257,24 @@ def process_token(token: Token) -> Token:
263257
)
264258

265259
return formatted_sql
260+
261+
262+
def split_format_sql(
263+
query: str, parameters: Sequence[Sequence[ParameterType]]
264+
) -> List[str]:
265+
"""
266+
Split a query into separate statement, and format it with parameters
267+
if it's a single statement
268+
Trying to format a multi-statement query would result in NotSupportedError
269+
"""
270+
statements = parse_sql(query)
271+
if not statements:
272+
return [query]
273+
274+
if parameters:
275+
if len(statements) > 1:
276+
raise NotSupportedError(
277+
"formatting multistatement queries is not supported"
278+
)
279+
return [format_statement(statements[0], paramset) for paramset in parameters]
280+
return [str(st).strip().rstrip(";") for st in statements]

src/firebolt/async_db/cursor.py

Lines changed: 103 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from enum import Enum
77
from functools import wraps
88
from inspect import cleandoc
9-
from json import JSONDecodeError
109
from types import TracebackType
1110
from typing import (
1211
TYPE_CHECKING,
@@ -27,9 +26,9 @@
2726
Column,
2827
ParameterType,
2928
RawColType,
30-
format_sql,
3129
parse_type,
3230
parse_value,
31+
split_format_sql,
3332
)
3433
from firebolt.async_db.util import is_db_available, is_engine_running
3534
from firebolt.client import AsyncClient
@@ -38,7 +37,6 @@
3837
DataError,
3938
EngineNotRunningError,
4039
FireboltDatabaseError,
41-
NotSupportedError,
4240
OperationalError,
4341
ProgrammingError,
4442
QueryNotRunError,
@@ -55,6 +53,7 @@
5553

5654
class CursorState(Enum):
5755
NONE = 1
56+
ERROR = 2
5857
DONE = 3
5958
CLOSED = 4
6059

@@ -99,6 +98,8 @@ class BaseCursor:
9998
"_rows",
10099
"_idx",
101100
"_idx_lock",
101+
"_row_sets",
102+
"_next_set_idx",
102103
)
103104

104105
default_arraysize = 1
@@ -107,8 +108,15 @@ def __init__(self, client: AsyncClient, connection: Connection):
107108
self.connection = connection
108109
self._client = client
109110
self._arraysize = self.default_arraysize
111+
# These fields initialized here for type annotations purpose
110112
self._rows: Optional[List[List[RawColType]]] = None
111113
self._descriptions: Optional[List[Column]] = None
114+
self._row_sets: List[
115+
Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]]
116+
] = []
117+
self._rowcount = -1
118+
self._idx = 0
119+
self._next_set_idx = 0
112120
self._reset()
113121

114122
def __del__(self) -> None:
@@ -164,24 +172,58 @@ def close(self) -> None:
164172
# remove typecheck skip after connection is implemented
165173
self.connection._remove_cursor(self) # type: ignore
166174

167-
def _store_query_data(self, response: Response) -> None:
175+
def _append_query_data(self, response: Response) -> None:
168176
"""Store information about executed query from httpx response."""
169177

178+
row_set: Tuple[
179+
int, Optional[List[Column]], Optional[List[List[RawColType]]]
180+
] = (-1, None, None)
181+
170182
# Empty response is returned for insert query
171-
if response.headers.get("content-length", "") == "0":
172-
return
173-
try:
174-
query_data = response.json()
175-
self._rowcount = int(query_data["rows"])
176-
self._descriptions = [
177-
Column(d["name"], parse_type(d["type"]), None, None, None, None, None)
178-
for d in query_data["meta"]
179-
]
180-
181-
# Parse data during fetch
182-
self._rows = query_data["data"]
183-
except (KeyError, JSONDecodeError) as err:
184-
raise DataError(f"Invalid query data format: {str(err)}")
183+
if response.headers.get("content-length", "") != "0":
184+
try:
185+
query_data = response.json()
186+
rowcount = int(query_data["rows"])
187+
descriptions = [
188+
Column(
189+
d["name"], parse_type(d["type"]), None, None, None, None, None
190+
)
191+
for d in query_data["meta"]
192+
]
193+
194+
# Parse data during fetch
195+
rows = query_data["data"]
196+
row_set = (rowcount, descriptions, rows)
197+
except (KeyError, ValueError) as err:
198+
raise DataError(f"Invalid query data format: {str(err)}")
199+
200+
self._row_sets.append(row_set)
201+
if self._next_set_idx == 0:
202+
# Populate values for first set
203+
self._pop_next_set()
204+
205+
@check_not_closed
206+
@check_query_executed
207+
def nextset(self) -> Optional[bool]:
208+
"""
209+
Skip to the next available set, discarding any remaining rows
210+
from the current set.
211+
Returns True if operation was successful,
212+
None if there are no more sets to retrive
213+
"""
214+
return self._pop_next_set()
215+
216+
def _pop_next_set(self) -> Optional[bool]:
217+
"""
218+
Same functionality as .nextset, but doesn't check that query has been executed.
219+
"""
220+
if self._next_set_idx >= len(self._row_sets):
221+
return None
222+
self._rowcount, self._descriptions, self._rows = self._row_sets[
223+
self._next_set_idx
224+
]
225+
self._next_set_idx += 1
226+
return True
185227

186228
async def _raise_if_error(self, resp: Response) -> None:
187229
"""Raise a proper error if any"""
@@ -213,29 +255,52 @@ def _reset(self) -> None:
213255
self._descriptions = None
214256
self._rowcount = -1
215257
self._idx = 0
258+
self._row_sets = []
259+
self._next_set_idx = 0
216260

217261
async def _do_execute_request(
218262
self,
219263
query: str,
220-
parameters: Optional[Sequence[ParameterType]] = None,
264+
parameters: Sequence[Sequence[ParameterType]],
221265
set_parameters: Optional[Dict] = None,
222-
) -> Response:
223-
if parameters:
224-
query = format_sql(query, parameters)
225-
226-
resp = await self._client.request(
227-
url="/",
228-
method="POST",
229-
params={
230-
"database": self.connection.database,
231-
"output_format": JSON_OUTPUT_FORMAT,
232-
**(set_parameters or dict()),
233-
},
234-
content=query,
235-
)
266+
) -> None:
267+
self._reset()
268+
try:
269+
270+
queries = split_format_sql(query, parameters)
271+
272+
for query in queries:
273+
274+
start_time = time.time()
275+
# our CREATE EXTERNAL TABLE queries currently require credentials,
276+
# so we will skip logging those queries.
277+
# https://docs.firebolt.io/sql-reference/commands/ddl-commands#create-external-table
278+
if not re.search("aws_key_id|credentials", query, flags=re.IGNORECASE):
279+
logger.debug(f"Running query: {query}")
280+
281+
resp = await self._client.request(
282+
url="/",
283+
method="POST",
284+
params={
285+
"database": self.connection.database,
286+
"output_format": JSON_OUTPUT_FORMAT,
287+
**(set_parameters or dict()),
288+
},
289+
content=query,
290+
)
291+
292+
await self._raise_if_error(resp)
293+
self._append_query_data(resp)
294+
logger.info(
295+
f"Query fetched {self.rowcount} rows in"
296+
f" {time.time() - start_time} seconds"
297+
)
298+
299+
self._state = CursorState.DONE
236300

237-
await self._raise_if_error(resp)
238-
return resp
301+
except Exception:
302+
self._state = CursorState.ERROR
303+
raise
239304

240305
@check_not_closed
241306
async def execute(
@@ -245,21 +310,9 @@ async def execute(
245310
set_parameters: Optional[Dict] = None,
246311
) -> int:
247312
"""Prepare and execute a database query. Return row count."""
248-
start_time = time.time()
249313

250-
# our CREATE EXTERNAL TABLE queries currently require credentials,
251-
# so we will skip logging those queries.
252-
# https://docs.firebolt.io/sql-reference/commands/ddl-commands#create-external-table
253-
if not re.search("aws_key_id|credentials", query, flags=re.IGNORECASE):
254-
logger.debug(f"Running query: {query}")
255-
256-
self._reset()
257-
resp = await self._do_execute_request(query, parameters, set_parameters)
258-
self._store_query_data(resp)
259-
self._state = CursorState.DONE
260-
logger.info(
261-
f"Query fetched {self.rowcount} rows in {time.time() - start_time} seconds"
262-
)
314+
params_list = [parameters] if parameters else []
315+
await self._do_execute_request(query, params_list, set_parameters)
263316
return self.rowcount
264317

265318
@check_not_closed
@@ -270,19 +323,7 @@ async def executemany(
270323
Prepare and execute a database query against all parameter
271324
sequences provided. Return last query row count.
272325
"""
273-
274-
if len(parameters_seq) > 1:
275-
raise NotSupportedError(
276-
"Parameterized multi-statement queries are not supported"
277-
)
278-
279-
self._reset()
280-
resp = None
281-
for parameters in parameters_seq:
282-
resp = await self._do_execute_request(query, parameters)
283-
if resp is not None:
284-
self._store_query_data(resp)
285-
self._state = CursorState.DONE
326+
await self._do_execute_request(query, parameters_seq)
286327
return self.rowcount
287328

288329
def _parse_row(self, row: List[RawColType]) -> List[ColType]:

tests/integration/dbapi/async/test_queries_async.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,44 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
237237
[params + ["?"]],
238238
"Invalid data in table after parameterized insert",
239239
)
240+
241+
242+
@mark.asyncio
243+
async def test_multi_statement_query(connection: Connection) -> None:
244+
"""Query parameters are handled properly"""
245+
246+
with connection.cursor() as c:
247+
await c.execute("DROP TABLE IF EXISTS test_tb_multi_statement")
248+
await c.execute(
249+
"CREATE FACT TABLE test_tb_multi_statement(i int, s string) primary index i"
250+
)
251+
252+
assert (
253+
await c.execute(
254+
"INSERT INTO test_tb_multi_statement values (1, 'a'), (2, 'b');"
255+
"SELECT * FROM test_tb_multi_statement"
256+
)
257+
== -1
258+
), "Invalid row count returned for insert"
259+
assert c.rowcount == -1, "Invalid row count"
260+
assert c.description is None, "Invalid description"
261+
262+
assert c.nextset()
263+
264+
assert c.rowcount == 2, "Invalid select row count"
265+
assert_deep_eq(
266+
c.description,
267+
[
268+
Column("i", int, None, None, None, None, None),
269+
Column("s", str, None, None, None, None, None),
270+
],
271+
"Invalid select query description",
272+
)
273+
274+
assert_deep_eq(
275+
await c.fetchall(),
276+
[[1, "a"], [2, "b"]],
277+
"Invalid data in table after parameterized insert",
278+
)
279+
280+
assert c.nextset() is None

0 commit comments

Comments
 (0)