Skip to content

Commit 6c0f65b

Browse files
authored
fix: RDS Data API - allow ANSI-compatible identifiers. (#2391)
1 parent f8590a1 commit 6c0f65b

File tree

2 files changed

+83
-7
lines changed

2 files changed

+83
-7
lines changed

awswrangler/data_api/rds.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""RDS Data API Connector."""
22
import datetime as dt
33
import logging
4+
import re
45
import time
56
import uuid
67
from decimal import Decimal
@@ -227,6 +228,19 @@ def _get_statement_result(self, request_id: str) -> pd.DataFrame:
227228
return dataframe
228229

229230

231+
def escape_identifier(identifier: str, sql_mode: str = "mysql") -> str:
    """Quote a SQL identifier (e.g. a table or column name) for the given SQL mode.

    The identifier is validated rather than escaped: anything other than
    regex "word" characters is rejected, so the quoted result can be safely
    interpolated into a statement.

    Parameters
    ----------
    identifier : str
        Name to quote. Must be a non-empty string made up only of word
        characters (letters, digits and underscore).
    sql_mode : str
        ``"mysql"`` (default) wraps the name in backticks; ``"ansi"`` wraps it
        in double quotes.

    Returns
    -------
    str
        The identifier surrounded by the quote character of ``sql_mode``.

    Raises
    ------
    TypeError
        If ``identifier`` is not a string, or contains a non-word character.
        (The exception type for invalid characters is kept for backward
        compatibility with existing callers.)
    ValueError
        If ``identifier`` is empty, or ``sql_mode`` is not a known mode.
    """
    if not isinstance(identifier, str):
        raise TypeError("SQL identifier must be a string")
    # An empty name would pass the character check below and produce an empty
    # quoted identifier, which is not a valid name in either quoting style.
    if not identifier:
        raise ValueError("SQL identifier cannot be empty")
    if re.search(r"\W", identifier):
        raise TypeError(f"SQL identifier contains invalid characters: {identifier}")
    if sql_mode == "mysql":
        return f"`{identifier}`"
    elif sql_mode == "ansi":
        return f'"{identifier}"'
    raise ValueError(f"Unknown SQL MODE: {sql_mode}")
242+
243+
230244
def connect(
231245
resource_arn: str, database: str, secret_arn: str = "", boto3_session: Optional[boto3.Session] = None, **kwargs: Any
232246
) -> RdsDataApi:
@@ -271,8 +285,8 @@ def read_sql_query(sql: str, con: RdsDataApi, database: Optional[str] = None) ->
271285
return con.execute(sql, database=database)
272286

273287

274-
def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str) -> None:
275-
sql = f"DROP TABLE IF EXISTS `{table}`"
288+
def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str, sql_mode: str) -> None:
289+
sql = f"DROP TABLE IF EXISTS {escape_identifier(table, sql_mode=sql_mode)}"
276290
_logger.debug("Drop table query:\n%s", sql)
277291
con.execute(sql, database=database, transaction_id=transaction_id)
278292

@@ -292,9 +306,10 @@ def _create_table(
292306
index: bool,
293307
dtype: Optional[Dict[str, str]],
294308
varchar_lengths: Optional[Dict[str, int]],
309+
sql_mode: str,
295310
) -> None:
296311
if mode == "overwrite":
297-
_drop_table(con=con, table=table, database=database, transaction_id=transaction_id)
312+
_drop_table(con=con, table=table, database=database, transaction_id=transaction_id, sql_mode=sql_mode)
298313
elif _does_table_exist(con=con, table=table, database=database, transaction_id=transaction_id):
299314
return
300315

@@ -306,8 +321,8 @@ def _create_table(
306321
varchar_lengths=varchar_lengths,
307322
converter_func=_data_types.pyarrow2mysql,
308323
)
309-
cols_str: str = "".join([f"`{k}` {v},\n" for k, v in mysql_types.items()])[:-2]
310-
sql = f"CREATE TABLE IF NOT EXISTS `{table}` (\n{cols_str})"
324+
cols_str: str = "".join([f"{escape_identifier(k, sql_mode=sql_mode)} {v},\n" for k, v in mysql_types.items()])[:-2]
325+
sql = f"CREATE TABLE IF NOT EXISTS {escape_identifier(table, sql_mode=sql_mode)} (\n{cols_str})"
311326

312327
_logger.debug("Create table query:\n%s", sql)
313328
con.execute(sql, database=database, transaction_id=transaction_id)
@@ -388,6 +403,7 @@ def to_sql(
388403
varchar_lengths: Optional[Dict[str, int]] = None,
389404
use_column_names: bool = False,
390405
chunksize: int = 200,
406+
sql_mode: str = "mysql",
391407
) -> None:
392408
"""
393409
Insert data using an SQL query on a Data API connection.
@@ -439,19 +455,22 @@ def to_sql(
439455
index=index,
440456
dtype=dtype,
441457
varchar_lengths=varchar_lengths,
458+
sql_mode=sql_mode,
442459
)
443460

444461
if index:
445462
df = df.reset_index(level=df.index.names)
446463

447464
if use_column_names:
448-
insertion_columns = "(" + ", ".join([f"`{col}`" for col in df.columns]) + ")"
465+
insertion_columns = (
466+
"(" + ", ".join([f"{escape_identifier(col, sql_mode=sql_mode)}" for col in df.columns]) + ")"
467+
)
449468
else:
450469
insertion_columns = ""
451470

452471
placeholders = ", ".join([f":{col}" for col in df.columns])
453472

454-
sql = f"""INSERT INTO `{table}` {insertion_columns} VALUES ({placeholders})"""
473+
sql = f"INSERT INTO {escape_identifier(table, sql_mode=sql_mode)} {insertion_columns} VALUES ({placeholders})"
455474
parameter_sets = _generate_parameter_sets(df)
456475

457476
for parameter_sets_chunk in _utils.chunkify(parameter_sets, max_length=chunksize):

tests/unit/test_data_api.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ def mysql_serverless_connector(databases_parameters: Dict[str, Any]) -> "RdsData
3939
yield con
4040

4141

42+
@pytest.fixture
def postgresql_serverless_connector(databases_parameters: Dict[str, Any]) -> "RdsDataApi":
    """Yield an RDS Data API connector built for the ``postgresql_serverless`` entry.

    The connector is used as a context manager for the duration of the test.
    """
    connector = create_rds_connector("postgresql_serverless", databases_parameters)
    with connector:
        yield connector
47+
48+
4249
def test_connect_redshift_serverless_iam_role(databases_parameters: Dict[str, Any]) -> None:
4350
workgroup_name = databases_parameters["redshift_serverless"]["workgroup"]
4451
database = databases_parameters["redshift_serverless"]["database"]
@@ -68,6 +75,16 @@ def mysql_serverless_table(mysql_serverless_connector: "RdsDataApi") -> Iterator
6875
mysql_serverless_connector.execute(f"DROP TABLE IF EXISTS test.{name}")
6976

7077

78+
@pytest.fixture(scope="function")
def postgresql_serverless_table(postgresql_serverless_connector: "RdsDataApi") -> Iterator[str]:
    """Yield a freshly generated table name and drop ``test.<name>`` on teardown."""
    table_name = f"tbl_{get_time_str_with_random_suffix()}"
    print(f"Table name: {table_name}")
    drop_statement = f"DROP TABLE IF EXISTS test.{table_name}"
    try:
        yield table_name
    finally:
        # Runs whether the test passed or failed.
        postgresql_serverless_connector.execute(drop_statement)
86+
87+
7188
def test_data_api_redshift_columnless_query(redshift_connector: "RedshiftDataApi") -> None:
7289
dataframe = wr.data_api.redshift.read_sql_query("SELECT 1", con=redshift_connector)
7390
unknown_column_indicator = "?column?"
@@ -223,3 +240,43 @@ def test_data_api_mysql_to_sql_mode(
223240
def test_data_api_exception(mysql_serverless_connector: "RdsDataApi", mysql_serverless_table: str) -> None:
224241
with pytest.raises(boto3.client("rds-data").exceptions.BadRequestException):
225242
wr.data_api.rds.read_sql_query("CUPCAKE", con=mysql_serverless_connector)
243+
244+
245+
def test_data_api_mysql_ansi(mysql_serverless_connector: "RdsDataApi", mysql_serverless_table: str) -> None:
    """Write one row with ``sql_mode="ansi"`` on a MySQL session set to ANSI_QUOTES and read it back."""
    source_df = pd.DataFrame([[42, "test"]], columns=["id", "name"])

    # Enable ANSI_QUOTES on the session before writing with double-quoted identifiers.
    mysql_serverless_connector.execute("SET SESSION sql_mode='ANSI_QUOTES';")

    wr.data_api.rds.to_sql(
        df=source_df,
        con=mysql_serverless_connector,
        table=mysql_serverless_table,
        database="test",
        sql_mode="ansi",
    )

    query = f"SELECT name FROM {mysql_serverless_table} WHERE id = 42"
    result_df = wr.data_api.rds.read_sql_query(query, con=mysql_serverless_connector)

    assert_pandas_equals(result_df, pd.DataFrame([["test"]], columns=["name"]))
264+
265+
266+
def test_data_api_postgresql(postgresql_serverless_connector: "RdsDataApi", postgresql_serverless_table: str) -> None:
    """Write one row to PostgreSQL with ``sql_mode="ansi"`` and read it back."""
    source_df = pd.DataFrame([[42, "test"]], columns=["id", "name"])

    wr.data_api.rds.to_sql(
        df=source_df,
        con=postgresql_serverless_connector,
        table=postgresql_serverless_table,
        database="test",
        sql_mode="ansi",
    )

    query = f"SELECT name FROM {postgresql_serverless_table} WHERE id = 42"
    result_df = wr.data_api.rds.read_sql_query(query, con=postgresql_serverless_connector)

    assert_pandas_equals(result_df, pd.DataFrame([["test"]], columns=["name"]))

0 commit comments

Comments
 (0)