22
33import re
44from collections import namedtuple
5- from datetime import date , datetime , timezone
5+ from datetime import date , datetime
66from decimal import Decimal
77from enum import Enum
88from io import StringIO
9- from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
10-
11- from sqlparse import parse as parse_sql # type: ignore
12- from sqlparse .sql import ( # type: ignore
13- Comment ,
14- Comparison ,
15- Statement ,
16- Token ,
17- TokenList ,
18- )
19- from sqlparse .tokens import Comparison as ComparisonType # type: ignore
20- from sqlparse .tokens import Newline # type: ignore
21- from sqlparse .tokens import Whitespace # type: ignore
22- from sqlparse .tokens import Token as TokenType # type: ignore
9+ from typing import Any , Dict , List , Sequence , Tuple , Union
2310
2411try :
2512 from ciso8601 import parse_datetime # type: ignore
@@ -46,11 +33,7 @@ def parse_datetime(datetime_string: str) -> datetime:
4633 return datetime .fromisoformat (_fix_timezone (_fix_milliseconds (datetime_string )))
4734
4835
49- from firebolt .utils .exception import (
50- DataError ,
51- InterfaceError ,
52- NotSupportedError ,
53- )
36+ from firebolt .utils .exception import DataError , NotSupportedError
5437from firebolt .utils .util import cached_property
5538
5639_NoneType = type (None )
@@ -372,158 +355,4 @@ def parse_value(
372355 raise DataError (f"Unsupported data type returned: { ctype .__name__ } " )
373356
374357
375- escape_chars = {
376- "\0 " : "\\ 0" ,
377- "\\ " : "\\ \\ " ,
378- "'" : "\\ '" ,
379- }
380-
381-
382- def format_value (value : ParameterType ) -> str :
383- """For Python value to be used in a SQL query."""
384- if isinstance (value , bool ):
385- return "true" if value else "false"
386- if isinstance (value , (int , float , Decimal )):
387- return str (value )
388- elif isinstance (value , str ):
389- return f"'{ '' .join (escape_chars .get (c , c ) for c in value )} '"
390- elif isinstance (value , datetime ):
391- if value .tzinfo is not None :
392- value = value .astimezone (timezone .utc )
393- return f"'{ value .strftime ('%Y-%m-%d %H:%M:%S' )} '"
394- elif isinstance (value , date ):
395- return f"'{ value .isoformat ()} '"
396- elif isinstance (value , bytes ):
397- # Encode each byte into hex
398- return "E'" + "" .join (f"\\ x{ b :02x} " for b in value ) + "'"
399- if value is None :
400- return "NULL"
401- elif isinstance (value , Sequence ):
402- return f"[{ ', ' .join (format_value (it ) for it in value )} ]"
403-
404- raise DataError (f"unsupported parameter type { type (value )} " )
405-
406-
407- def format_statement (statement : Statement , parameters : Sequence [ParameterType ]) -> str :
408- """
409- Substitute placeholders in a `sqlparse` statement with provided values.
410- """
411- idx = 0
412-
413- def process_token (token : Token ) -> Token :
414- nonlocal idx
415- if token .ttype == TokenType .Name .Placeholder :
416- # Replace placeholder with formatted parameter
417- if idx >= len (parameters ):
418- raise DataError (
419- "not enough parameters provided for substitution: given "
420- f"{ len (parameters )} , found one more"
421- )
422- formatted = format_value (parameters [idx ])
423- idx += 1
424- return Token (TokenType .Text , formatted )
425- if isinstance (token , TokenList ):
426- # Process all children tokens
427-
428- return TokenList ([process_token (t ) for t in token .tokens ])
429- return token
430-
431- formatted_sql = statement_to_sql (process_token (statement ))
432-
433- if idx < len (parameters ):
434- raise DataError (
435- f"too many parameters provided for substitution: given { len (parameters )} , "
436- f"used only { idx } "
437- )
438-
439- return formatted_sql
440-
441-
442358SetParameter = namedtuple ("SetParameter" , ["name" , "value" ])
443-
444-
445- def statement_to_set (statement : Statement ) -> Optional [SetParameter ]:
446- """
447- Try to parse `statement` as a `SET` command.
448- Return `None` if it's not a `SET` command.
449- """
450- # Filter out meaningless tokens like Punctuation and Whitespaces
451- skip_types = [Whitespace , Newline ]
452- tokens = [
453- token
454- for token in statement .tokens
455- if token .ttype not in skip_types and not isinstance (token , Comment )
456- ]
457- # Trim tail punctuation
458- right_idx = len (tokens ) - 1
459- while str (tokens [right_idx ]) == ";" :
460- right_idx -= 1
461-
462- tokens = tokens [: right_idx + 1 ]
463-
464- # Check if it's a SET statement by checking if it starts with set
465- if (
466- len (tokens ) > 0
467- and tokens [0 ].ttype == TokenType .Keyword
468- and tokens [0 ].value .lower () == "set"
469- ):
470- # Check if set statement has a valid format
471- if len (tokens ) == 2 and isinstance (tokens [1 ], Comparison ):
472- return SetParameter (
473- statement_to_sql (tokens [1 ].left ),
474- statement_to_sql (tokens [1 ].right ).strip ("'" ),
475- )
476- # Or if at least there is a comparison
477- cmp_idx = next (
478- (
479- i
480- for i , token in enumerate (tokens )
481- if token .ttype == ComparisonType or isinstance (token , Comparison )
482- ),
483- None ,
484- )
485- if cmp_idx :
486- left_tokens , right_tokens = tokens [1 :cmp_idx ], tokens [cmp_idx + 1 :]
487- if isinstance (tokens [cmp_idx ], Comparison ):
488- left_tokens = left_tokens + [tokens [cmp_idx ].left ]
489- right_tokens = [tokens [cmp_idx ].right ] + right_tokens
490-
491- if left_tokens and right_tokens :
492- return SetParameter (
493- "" .join (statement_to_sql (t ) for t in left_tokens ),
494- "" .join (statement_to_sql (t ) for t in right_tokens ).strip ("'" ),
495- )
496-
497- raise InterfaceError (
498- f"Invalid set statement format: { statement_to_sql (statement )} ,"
499- " expected SET <param> = <value>"
500- )
501- return None
502-
503-
504- def statement_to_sql (statement : Statement ) -> str :
505- return str (statement ).strip ().rstrip (";" )
506-
507-
508- def split_format_sql (
509- query : str , parameters : Sequence [Sequence [ParameterType ]]
510- ) -> List [Union [str , SetParameter ]]:
511- """
512- Multi-statement query formatting will result in `NotSupportedError`.
513- Instead, split a query into a separate statement and format with parameters.
514- """
515- statements = parse_sql (query )
516- if not statements :
517- return [query ]
518-
519- if parameters :
520- if len (statements ) > 1 :
521- raise NotSupportedError (
522- "Formatting multi-statement queries is not supported."
523- )
524- if statement_to_set (statements [0 ]):
525- raise NotSupportedError ("Formatting set statements is not supported." )
526- return [format_statement (statements [0 ], paramset ) for paramset in parameters ]
527-
528- # Try parsing each statement as a SET, otherwise return as a plain sql string
529- return [statement_to_set (st ) or statement_to_sql (st ) for st in statements ]
0 commit comments