Skip to content

Commit e24944e

Browse files
ptiurinlucianosrp
andauthored
fix(FIR-43230): string escape in query parameters (#418)
Co-authored-by: Luciano Scarpulla <[email protected]>
1 parent 89e650d commit e24944e

File tree

15 files changed

+424
-229
lines changed

15 files changed

+424
-229
lines changed

src/firebolt/async_db/cursor.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,7 @@
2727
)
2828

2929
from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
30-
from firebolt.common._types import (
31-
ColType,
32-
ParameterType,
33-
SetParameter,
34-
split_format_sql,
35-
)
30+
from firebolt.common._types import ColType, ParameterType, SetParameter
3631
from firebolt.common.base_cursor import (
3732
JSON_OUTPUT_FORMAT,
3833
RESET_SESSION_HEADER,
@@ -47,6 +42,7 @@
4742
check_not_closed,
4843
check_query_executed,
4944
)
45+
from firebolt.common.statement_formatter import create_statement_formatter
5046
from firebolt.utils.exception import (
5147
EngineNotRunningError,
5248
FireboltDatabaseError,
@@ -209,7 +205,9 @@ async def _do_execute(
209205
) -> None:
210206
self._reset()
211207
queries: List[Union[SetParameter, str]] = (
212-
[raw_query] if skip_parsing else split_format_sql(raw_query, parameters)
208+
[raw_query]
209+
if skip_parsing
210+
else self._formatter.split_format_sql(raw_query, parameters)
213211
)
214212
timeout_controller = TimeoutController(timeout)
215213

@@ -435,7 +433,13 @@ def __init__(
435433
**kwargs: Any,
436434
) -> None:
437435
assert isinstance(client, AsyncClientV2)
438-
super().__init__(*args, client=client, connection=connection, **kwargs)
436+
super().__init__(
437+
*args,
438+
client=client,
439+
connection=connection,
440+
formatter=create_statement_formatter(version=2),
441+
**kwargs,
442+
)
439443

440444
@check_not_closed
441445
async def execute_async(
@@ -512,7 +516,13 @@ def __init__(
512516
**kwargs: Any,
513517
) -> None:
514518
assert isinstance(client, AsyncClientV1)
515-
super().__init__(*args, client=client, connection=connection, **kwargs)
519+
super().__init__(
520+
*args,
521+
client=client,
522+
connection=connection,
523+
formatter=create_statement_formatter(version=1),
524+
**kwargs,
525+
)
516526

517527
async def is_db_available(self, database_name: str) -> bool:
518528
"""

src/firebolt/common/_types.py

Lines changed: 3 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,11 @@
22

33
import re
44
from collections import namedtuple
5-
from datetime import date, datetime, timezone
5+
from datetime import date, datetime
66
from decimal import Decimal
77
from enum import Enum
88
from io import StringIO
9-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
10-
11-
from sqlparse import parse as parse_sql # type: ignore
12-
from sqlparse.sql import ( # type: ignore
13-
Comment,
14-
Comparison,
15-
Statement,
16-
Token,
17-
TokenList,
18-
)
19-
from sqlparse.tokens import Comparison as ComparisonType # type: ignore
20-
from sqlparse.tokens import Newline # type: ignore
21-
from sqlparse.tokens import Whitespace # type: ignore
22-
from sqlparse.tokens import Token as TokenType # type: ignore
9+
from typing import Any, Dict, List, Sequence, Tuple, Union
2310

2411
try:
2512
from ciso8601 import parse_datetime # type: ignore
@@ -46,11 +33,7 @@ def parse_datetime(datetime_string: str) -> datetime:
4633
return datetime.fromisoformat(_fix_timezone(_fix_milliseconds(datetime_string)))
4734

4835

49-
from firebolt.utils.exception import (
50-
DataError,
51-
InterfaceError,
52-
NotSupportedError,
53-
)
36+
from firebolt.utils.exception import DataError, NotSupportedError
5437
from firebolt.utils.util import cached_property
5538

5639
_NoneType = type(None)
@@ -372,158 +355,4 @@ def parse_value(
372355
raise DataError(f"Unsupported data type returned: {ctype.__name__}")
373356

374357

375-
escape_chars = {
376-
"\0": "\\0",
377-
"\\": "\\\\",
378-
"'": "\\'",
379-
}
380-
381-
382-
def format_value(value: ParameterType) -> str:
383-
"""For Python value to be used in a SQL query."""
384-
if isinstance(value, bool):
385-
return "true" if value else "false"
386-
if isinstance(value, (int, float, Decimal)):
387-
return str(value)
388-
elif isinstance(value, str):
389-
return f"'{''.join(escape_chars.get(c, c) for c in value)}'"
390-
elif isinstance(value, datetime):
391-
if value.tzinfo is not None:
392-
value = value.astimezone(timezone.utc)
393-
return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"
394-
elif isinstance(value, date):
395-
return f"'{value.isoformat()}'"
396-
elif isinstance(value, bytes):
397-
# Encode each byte into hex
398-
return "E'" + "".join(f"\\x{b:02x}" for b in value) + "'"
399-
if value is None:
400-
return "NULL"
401-
elif isinstance(value, Sequence):
402-
return f"[{', '.join(format_value(it) for it in value)}]"
403-
404-
raise DataError(f"unsupported parameter type {type(value)}")
405-
406-
407-
def format_statement(statement: Statement, parameters: Sequence[ParameterType]) -> str:
408-
"""
409-
Substitute placeholders in a `sqlparse` statement with provided values.
410-
"""
411-
idx = 0
412-
413-
def process_token(token: Token) -> Token:
414-
nonlocal idx
415-
if token.ttype == TokenType.Name.Placeholder:
416-
# Replace placeholder with formatted parameter
417-
if idx >= len(parameters):
418-
raise DataError(
419-
"not enough parameters provided for substitution: given "
420-
f"{len(parameters)}, found one more"
421-
)
422-
formatted = format_value(parameters[idx])
423-
idx += 1
424-
return Token(TokenType.Text, formatted)
425-
if isinstance(token, TokenList):
426-
# Process all children tokens
427-
428-
return TokenList([process_token(t) for t in token.tokens])
429-
return token
430-
431-
formatted_sql = statement_to_sql(process_token(statement))
432-
433-
if idx < len(parameters):
434-
raise DataError(
435-
f"too many parameters provided for substitution: given {len(parameters)}, "
436-
f"used only {idx}"
437-
)
438-
439-
return formatted_sql
440-
441-
442358
SetParameter = namedtuple("SetParameter", ["name", "value"])
443-
444-
445-
def statement_to_set(statement: Statement) -> Optional[SetParameter]:
446-
"""
447-
Try to parse `statement` as a `SET` command.
448-
Return `None` if it's not a `SET` command.
449-
"""
450-
# Filter out meaningless tokens like Punctuation and Whitespaces
451-
skip_types = [Whitespace, Newline]
452-
tokens = [
453-
token
454-
for token in statement.tokens
455-
if token.ttype not in skip_types and not isinstance(token, Comment)
456-
]
457-
# Trim tail punctuation
458-
right_idx = len(tokens) - 1
459-
while str(tokens[right_idx]) == ";":
460-
right_idx -= 1
461-
462-
tokens = tokens[: right_idx + 1]
463-
464-
# Check if it's a SET statement by checking if it starts with set
465-
if (
466-
len(tokens) > 0
467-
and tokens[0].ttype == TokenType.Keyword
468-
and tokens[0].value.lower() == "set"
469-
):
470-
# Check if set statement has a valid format
471-
if len(tokens) == 2 and isinstance(tokens[1], Comparison):
472-
return SetParameter(
473-
statement_to_sql(tokens[1].left),
474-
statement_to_sql(tokens[1].right).strip("'"),
475-
)
476-
# Or if at least there is a comparison
477-
cmp_idx = next(
478-
(
479-
i
480-
for i, token in enumerate(tokens)
481-
if token.ttype == ComparisonType or isinstance(token, Comparison)
482-
),
483-
None,
484-
)
485-
if cmp_idx:
486-
left_tokens, right_tokens = tokens[1:cmp_idx], tokens[cmp_idx + 1 :]
487-
if isinstance(tokens[cmp_idx], Comparison):
488-
left_tokens = left_tokens + [tokens[cmp_idx].left]
489-
right_tokens = [tokens[cmp_idx].right] + right_tokens
490-
491-
if left_tokens and right_tokens:
492-
return SetParameter(
493-
"".join(statement_to_sql(t) for t in left_tokens),
494-
"".join(statement_to_sql(t) for t in right_tokens).strip("'"),
495-
)
496-
497-
raise InterfaceError(
498-
f"Invalid set statement format: {statement_to_sql(statement)},"
499-
" expected SET <param> = <value>"
500-
)
501-
return None
502-
503-
504-
def statement_to_sql(statement: Statement) -> str:
505-
return str(statement).strip().rstrip(";")
506-
507-
508-
def split_format_sql(
509-
query: str, parameters: Sequence[Sequence[ParameterType]]
510-
) -> List[Union[str, SetParameter]]:
511-
"""
512-
Multi-statement query formatting will result in `NotSupportedError`.
513-
Instead, split a query into a separate statement and format with parameters.
514-
"""
515-
statements = parse_sql(query)
516-
if not statements:
517-
return [query]
518-
519-
if parameters:
520-
if len(statements) > 1:
521-
raise NotSupportedError(
522-
"Formatting multi-statement queries is not supported."
523-
)
524-
if statement_to_set(statements[0]):
525-
raise NotSupportedError("Formatting set statements is not supported.")
526-
return [format_statement(statements[0], paramset) for paramset in parameters]
527-
528-
# Try parsing each statement as a SET, otherwise return as a plain sql string
529-
return [statement_to_set(st) or statement_to_sql(st) for st in statements]

src/firebolt/common/base_cursor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
parse_type,
1919
parse_value,
2020
)
21+
from firebolt.common.statement_formatter import StatementFormatter
2122
from firebolt.utils.exception import (
2223
ConfigurationError,
2324
CursorClosedError,
@@ -187,6 +188,7 @@ class BaseCursor:
187188
"_rows",
188189
"_idx",
189190
"_row_sets",
191+
"_formatter",
190192
"_next_set_idx",
191193
"_set_parameters",
192194
"_query_id",
@@ -196,13 +198,16 @@ class BaseCursor:
196198

197199
default_arraysize = 1
198200

199-
def __init__(self, *args: Any, **kwargs: Any) -> None:
201+
def __init__(
202+
self, *args: Any, formatter: StatementFormatter, **kwargs: Any
203+
) -> None:
200204
self._arraysize = self.default_arraysize
201205
# These fields initialized here for type annotations purpose
202206
self._rows: Optional[List[List[RawColType]]] = None
203207
self._descriptions: Optional[List[Column]] = None
204208
self._statistics: Optional[Statistics] = None
205209
self._row_sets: List[RowSet] = []
210+
self._formatter = formatter
206211
# User-defined set parameters
207212
self._set_parameters: Dict[str, Any] = dict()
208213
# Server-side parameters (user can't change them)

0 commit comments

Comments
 (0)