Skip to content

Commit 2d53aee

Browse files
feat: Fir 11986 full support for set statement (#149)
* implement set statements parsing * set statements support for cursor * add deprecation warning * add integration tests * remove redundant type ignore * extend unit tests * Set -> SetParameter * use token caching for integration tests * fix long query tests * fix long test Co-authored-by: Stepan Burlakov <[email protected]>
1 parent 9a3e76f commit 2d53aee

File tree

12 files changed

+686
-524
lines changed

12 files changed

+686
-524
lines changed

src/firebolt/async_db/_types.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,15 @@
44
from datetime import date, datetime, timezone
55
from decimal import Decimal
66
from enum import Enum
7-
from typing import List, Sequence, Union
7+
from typing import List, Optional, Sequence, Union
88

99
from sqlparse import parse as parse_sql # type: ignore
10-
from sqlparse.sql import Statement, Token, TokenList # type: ignore
10+
from sqlparse.sql import ( # type: ignore
11+
Comparison,
12+
Statement,
13+
Token,
14+
TokenList,
15+
)
1116
from sqlparse.tokens import Token as TokenType # type: ignore
1217

1318
try:
@@ -22,7 +27,11 @@ def parse_datetime(date_string: str) -> datetime: # type: ignore
2227
return datetime.fromisoformat(date_string)
2328

2429

25-
from firebolt.common.exception import DataError, NotSupportedError
30+
from firebolt.common.exception import (
31+
DataError,
32+
InterfaceError,
33+
NotSupportedError,
34+
)
2635
from firebolt.common.util import cached_property
2736

2837
_NoneType = type(None)
@@ -312,7 +321,7 @@ def process_token(token: Token) -> Token:
312321
return TokenList([process_token(t) for t in token.tokens])
313322
return token
314323

315-
formatted_sql = str(process_token(statement)).rstrip(";")
324+
formatted_sql = statement_to_sql(process_token(statement))
316325

317326
if idx < len(parameters):
318327
raise DataError(
@@ -323,9 +332,43 @@ def process_token(token: Token) -> Token:
323332
return formatted_sql
324333

325334

335+
SetParameter = namedtuple("SetParameter", ["name", "value"])
336+
337+
338+
def statement_to_set(statement: Statement) -> Optional[SetParameter]:
339+
"""Try to parse statement as a SET command. Return None if it's not a SET command"""
340+
# Filter out meaningless tokens like Punctuation and Whitespaces
341+
tokens = [
342+
token
343+
for token in statement.tokens
344+
if token.ttype == TokenType.Keyword or isinstance(token, Comparison)
345+
]
346+
347+
# Check if it's a SET statement by checking if it starts with set
348+
if (
349+
len(tokens) > 0
350+
and tokens[0].ttype == TokenType.Keyword
351+
and tokens[0].value.lower() == "set"
352+
):
353+
# Check if set statement has a valid format
354+
if len(tokens) != 2 or not isinstance(tokens[1], Comparison):
355+
raise InterfaceError(
356+
f"Invalid set statement format: {statement_to_sql(statement)},"
357+
" expected SET <param> = <value>"
358+
)
359+
return SetParameter(
360+
statement_to_sql(tokens[1].left), statement_to_sql(tokens[1].right)
361+
)
362+
return None
363+
364+
365+
def statement_to_sql(statement: Statement) -> str:
366+
return str(statement).strip().rstrip(";")
367+
368+
326369
def split_format_sql(
327370
query: str, parameters: Sequence[Sequence[ParameterType]]
328-
) -> List[str]:
371+
) -> List[Union[str, SetParameter]]:
329372
"""
330373
Split a query into separate statement, and format it with parameters
331374
if it's a single statement
@@ -340,5 +383,9 @@ def split_format_sql(
340383
raise NotSupportedError(
341384
"formatting multistatement queries is not supported"
342385
)
386+
if statement_to_set(statements[0]):
387+
raise NotSupportedError("formatting set statements is not supported")
343388
return [format_statement(statements[0], paramset) for paramset in parameters]
344-
return [str(st).strip().rstrip(";") for st in statements]
389+
390+
# Try parsing each statement as a SET, otherwise return as a plain sql string
391+
return [statement_to_set(st) or statement_to_sql(st) for st in statements]

src/firebolt/async_db/cursor.py

Lines changed: 90 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Column,
2727
ParameterType,
2828
RawColType,
29+
SetParameter,
2930
parse_type,
3031
parse_value,
3132
split_format_sql,
@@ -100,6 +101,7 @@ class BaseCursor:
100101
"_idx_lock",
101102
"_row_sets",
102103
"_next_set_idx",
104+
"_set_parameters",
103105
)
104106

105107
default_arraysize = 1
@@ -114,6 +116,7 @@ def __init__(self, client: AsyncClient, connection: Connection):
114116
self._row_sets: List[
115117
Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]]
116118
] = []
119+
self._set_parameters: Dict[str, Any] = dict()
117120
self._rowcount = -1
118121
self._idx = 0
119122
self._next_set_idx = 0
@@ -172,37 +175,6 @@ def close(self) -> None:
172175
# remove typecheck skip after connection is implemented
173176
self.connection._remove_cursor(self) # type: ignore
174177

175-
def _append_query_data(self, response: Response) -> None:
176-
"""Store information about executed query from httpx response."""
177-
178-
row_set: Tuple[
179-
int, Optional[List[Column]], Optional[List[List[RawColType]]]
180-
] = (-1, None, None)
181-
182-
# Empty response is returned for insert query
183-
if response.headers.get("content-length", "") != "0":
184-
try:
185-
# Skip parsing floats to properly parse them later
186-
query_data = response.json(parse_float=str)
187-
rowcount = int(query_data["rows"])
188-
descriptions = [
189-
Column(
190-
d["name"], parse_type(d["type"]), None, None, None, None, None
191-
)
192-
for d in query_data["meta"]
193-
]
194-
195-
# Parse data during fetch
196-
rows = query_data["data"]
197-
row_set = (rowcount, descriptions, rows)
198-
except (KeyError, ValueError) as err:
199-
raise DataError(f"Invalid query data format: {str(err)}")
200-
201-
self._row_sets.append(row_set)
202-
if self._next_set_idx == 0:
203-
# Populate values for first set
204-
self._pop_next_set()
205-
206178
@check_not_closed
207179
@check_query_executed
208180
def nextset(self) -> Optional[bool]:
@@ -227,6 +199,9 @@ def _pop_next_set(self) -> Optional[bool]:
227199
self._next_set_idx += 1
228200
return True
229201

202+
def flush_parameters(self) -> None:
203+
self._set_parameters = dict()
204+
230205
async def _raise_if_error(self, resp: Response) -> None:
231206
"""Raise a proper error if any"""
232207
if resp.status_code == codes.INTERNAL_SERVER_ERROR:
@@ -260,39 +235,105 @@ def _reset(self) -> None:
260235
self._row_sets = []
261236
self._next_set_idx = 0
262237

263-
async def _do_execute_request(
238+
def _row_set_from_response(
239+
self, response: Response
240+
) -> Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]]:
241+
"""Fetch information about executed query from http response"""
242+
243+
# Empty response is returned for insert query
244+
if response.headers.get("content-length", "") == "0":
245+
return (-1, None, None)
246+
247+
try:
248+
# Skip parsing floats to properly parse them later
249+
query_data = response.json(parse_float=str)
250+
rowcount = int(query_data["rows"])
251+
descriptions = [
252+
Column(d["name"], parse_type(d["type"]), None, None, None, None, None)
253+
for d in query_data["meta"]
254+
]
255+
256+
# Parse data during fetch
257+
rows = query_data["data"]
258+
return (rowcount, descriptions, rows)
259+
except (KeyError, ValueError) as err:
260+
raise DataError(f"Invalid query data format: {str(err)}")
261+
262+
def _append_row_set(
264263
self,
265-
query: str,
264+
row_set: Tuple[int, Optional[List[Column]], Optional[List[List[RawColType]]]],
265+
) -> None:
266+
"""Store information about executed query."""
267+
self._row_sets.append(row_set)
268+
if self._next_set_idx == 0:
269+
# Populate values for first set
270+
self._pop_next_set()
271+
272+
async def _api_request(
273+
self, query: str, set_parameters: Optional[dict]
274+
) -> Response:
275+
return await self._client.request(
276+
url="/",
277+
method="POST",
278+
params={
279+
"database": self.connection.database,
280+
"output_format": JSON_OUTPUT_FORMAT,
281+
**self._set_parameters,
282+
**(set_parameters or dict()),
283+
},
284+
content=query,
285+
)
286+
287+
async def _do_execute(
288+
self,
289+
raw_query: str,
266290
parameters: Sequence[Sequence[ParameterType]],
267291
set_parameters: Optional[Dict] = None,
268292
) -> None:
269293
self._reset()
294+
if set_parameters is not None:
295+
logger.warning(
296+
"Passing set parameters as an argument is deprecated. Please run "
297+
"a query 'SET <param> = <value>'"
298+
)
270299
try:
271300

272-
queries = split_format_sql(query, parameters)
301+
queries = split_format_sql(raw_query, parameters)
273302

274303
for query in queries:
275304

276305
start_time = time.time()
277306
# our CREATE EXTERNAL TABLE queries currently require credentials,
278307
# so we will skip logging those queries.
279308
# https://docs.firebolt.io/sql-reference/commands/ddl-commands#create-external-table
280-
if not re.search("aws_key_id|credentials", query, flags=re.IGNORECASE):
309+
if isinstance(query, SetParameter) or not re.search(
310+
"aws_key_id|credentials", query, flags=re.IGNORECASE
311+
):
281312
logger.debug(f"Running query: {query}")
282313

283-
resp = await self._client.request(
284-
url="/",
285-
method="POST",
286-
params={
287-
"database": self.connection.database,
288-
"output_format": JSON_OUTPUT_FORMAT,
289-
**(set_parameters or dict()),
290-
},
291-
content=query,
292-
)
314+
# Define type for mypy
315+
row_set: Tuple[
316+
int, Optional[List[Column]], Optional[List[List[RawColType]]]
317+
] = (-1, None, None)
318+
if isinstance(query, SetParameter):
319+
# Validate parameter by executing simple query with it
320+
resp = await self._api_request(
321+
"select 1", {query.name: query.value}
322+
)
323+
# Handle invalid set parameter
324+
if resp.status_code == codes.BAD_REQUEST:
325+
raise OperationalError(resp.text)
326+
await self._raise_if_error(resp)
327+
328+
# set parameter passed validation
329+
self._set_parameters[query.name] = query.value
330+
else:
331+
resp = await self._api_request(query, set_parameters)
332+
await self._raise_if_error(resp)
333+
row_set = self._row_set_from_response(resp)
334+
335+
self._append_row_set(row_set)
293336

294-
await self._raise_if_error(resp)
295-
self._append_query_data(resp)
296337
logger.info(
297338
f"Query fetched {self.rowcount} rows in"
298339
f" {time.time() - start_time} seconds"
@@ -314,7 +355,7 @@ async def execute(
314355
"""Prepare and execute a database query. Return row count."""
315356

316357
params_list = [parameters] if parameters else []
317-
await self._do_execute_request(query, params_list, set_parameters)
358+
await self._do_execute(query, params_list, set_parameters)
318359
return self.rowcount
319360

320361
@check_not_closed
@@ -325,7 +366,7 @@ async def executemany(
325366
Prepare and execute a database query against all parameter
326367
sequences provided. Return last query row count.
327368
"""
328-
await self._do_execute_request(query, parameters_seq)
369+
await self._do_execute(query, parameters_seq)
329370
return self.rowcount
330371

331372
def _parse_row(self, row: List[RawColType]) -> List[ColType]:

tests/conftest.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)