11"""RDS Data API Connector."""
22import datetime as dt
33import logging
4+ import re
45import time
56import uuid
67from decimal import Decimal
@@ -227,6 +228,19 @@ def _get_statement_result(self, request_id: str) -> pd.DataFrame:
227228 return dataframe
228229
229230
def escape_identifier(identifier: str, sql_mode: str = "mysql") -> str:
    """Validate and quote a SQL identifier so it can be embedded in a statement.

    Identifiers (table and column names) cannot be passed as bound parameters,
    so they are interpolated into the SQL text. To keep that safe, only
    identifiers made up entirely of word characters (as matched by ``\\w``:
    letters, digits and underscore) are accepted; anything else is rejected
    rather than escaped.

    Parameters
    ----------
    identifier : str
        The bare, unquoted identifier.
    sql_mode : str
        Quoting style: ``"mysql"`` wraps the identifier in backticks (default),
        ``"ansi"`` wraps it in double quotes.

    Returns
    -------
    str
        The identifier surrounded by the quote character of ``sql_mode``,
        with no additional padding.

    Raises
    ------
    TypeError
        If ``identifier`` is not a string, is empty, or contains a character
        outside ``\\w``.
    ValueError
        If ``sql_mode`` is neither ``"mysql"`` nor ``"ansi"``.
    """
    if not isinstance(identifier, str):
        raise TypeError("SQL identifier must be a string")
    # An empty name has no non-word character, so the check below would let it
    # through and produce an empty quoted identifier; reject it explicitly.
    # TypeError matches the exception already used for unusable identifiers.
    if not identifier:
        raise TypeError("SQL identifier must not be empty")
    # Any non-word character (quote, backtick, space, semicolon, newline, ...)
    # could break out of the quoting, so refuse it instead of trying to escape.
    if re.search(r"\W", identifier):
        raise TypeError(f"SQL identifier contains invalid characters: {identifier}")
    # The quote characters must hug the identifier: padding inside the quotes
    # would become part of the name.
    if sql_mode == "mysql":
        return f"`{identifier}`"
    if sql_mode == "ansi":
        return f'"{identifier}"'
    raise ValueError(f"Unknown SQL MODE: {sql_mode}")
242+
243+
230244def connect (
231245 resource_arn : str , database : str , secret_arn : str = "" , boto3_session : Optional [boto3 .Session ] = None , ** kwargs : Any
232246) -> RdsDataApi :
@@ -271,8 +285,8 @@ def read_sql_query(sql: str, con: RdsDataApi, database: Optional[str] = None) ->
271285 return con .execute (sql , database = database )
272286
273287
274- def _drop_table (con : RdsDataApi , table : str , database : str , transaction_id : str ) -> None :
275- sql = f"DROP TABLE IF EXISTS ` { table } ` "
288+ def _drop_table (con : RdsDataApi , table : str , database : str , transaction_id : str , sql_mode : str ) -> None :
289+ sql = f"DROP TABLE IF EXISTS { escape_identifier ( table , sql_mode = sql_mode ) } "
276290 _logger .debug ("Drop table query:\n %s" , sql )
277291 con .execute (sql , database = database , transaction_id = transaction_id )
278292
@@ -292,9 +306,10 @@ def _create_table(
292306 index : bool ,
293307 dtype : Optional [Dict [str , str ]],
294308 varchar_lengths : Optional [Dict [str , int ]],
309+ sql_mode : str ,
295310) -> None :
296311 if mode == "overwrite" :
297- _drop_table (con = con , table = table , database = database , transaction_id = transaction_id )
312+ _drop_table (con = con , table = table , database = database , transaction_id = transaction_id , sql_mode = sql_mode )
298313 elif _does_table_exist (con = con , table = table , database = database , transaction_id = transaction_id ):
299314 return
300315
@@ -306,8 +321,8 @@ def _create_table(
306321 varchar_lengths = varchar_lengths ,
307322 converter_func = _data_types .pyarrow2mysql ,
308323 )
309- cols_str : str = "" .join ([f"` { k } ` { v } ,\n " for k , v in mysql_types .items ()])[:- 2 ]
310- sql = f"CREATE TABLE IF NOT EXISTS ` { table } ` (\n { cols_str } )"
324+ cols_str : str = "" .join ([f"{ escape_identifier ( k , sql_mode = sql_mode ) } { v } ,\n " for k , v in mysql_types .items ()])[:- 2 ]
325+ sql = f"CREATE TABLE IF NOT EXISTS { escape_identifier ( table , sql_mode = sql_mode ) } (\n { cols_str } )"
311326
312327 _logger .debug ("Create table query:\n %s" , sql )
313328 con .execute (sql , database = database , transaction_id = transaction_id )
@@ -388,6 +403,7 @@ def to_sql(
388403 varchar_lengths : Optional [Dict [str , int ]] = None ,
389404 use_column_names : bool = False ,
390405 chunksize : int = 200 ,
406+ sql_mode : str = "mysql" ,
391407) -> None :
392408 """
393409 Insert data using an SQL query on a Data API connection.
@@ -439,19 +455,22 @@ def to_sql(
439455 index = index ,
440456 dtype = dtype ,
441457 varchar_lengths = varchar_lengths ,
458+ sql_mode = sql_mode ,
442459 )
443460
444461 if index :
445462 df = df .reset_index (level = df .index .names )
446463
447464 if use_column_names :
448- insertion_columns = "(" + ", " .join ([f"`{ col } `" for col in df .columns ]) + ")"
465+ insertion_columns = (
466+ "(" + ", " .join ([f"{ escape_identifier (col , sql_mode = sql_mode )} " for col in df .columns ]) + ")"
467+ )
449468 else :
450469 insertion_columns = ""
451470
452471 placeholders = ", " .join ([f":{ col } " for col in df .columns ])
453472
454- sql = f""" INSERT INTO ` { table } ` { insertion_columns } VALUES ({ placeholders } )"" "
473+ sql = f"INSERT INTO { escape_identifier ( table , sql_mode = sql_mode ) } { insertion_columns } VALUES ({ placeholders } )"
455474 parameter_sets = _generate_parameter_sets (df )
456475
457476 for parameter_sets_chunk in _utils .chunkify (parameter_sets , max_length = chunksize ):
0 commit comments