Skip to content

Commit 326ea8d

Browse files
authored
fix: Refactor SQL identifiers to mitigate injection risks (#2543)
1 parent e748d7e commit 326ea8d

File tree

6 files changed

+116
-56
lines changed

6 files changed

+116
-56
lines changed

awswrangler/_sql_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""SQL utilities."""
2+
import re
3+
4+
from awswrangler import exceptions
5+
6+
7+
def identifier(sql: str, sql_mode: str = "mysql") -> str:
    """
    Turn the input into an escaped SQL identifier, such as the name of a table or column.

    Parameters
    ----------
    sql: str
        Identifier to use in SQL.
    sql_mode: str
        "mysql" for default MySQL identifiers (backticks), "ansi" for ANSI-compatible identifiers (double quotes), or
        "mssql" for MSSQL identifiers (square brackets).

    Returns
    -------
    str
        Escaped SQL identifier.

    Raises
    ------
    awswrangler.exceptions.InvalidArgumentValue
        If ``sql`` is not a non-empty string of alphanumerics, spaces, underscores, or hyphens.
    ValueError
        If ``sql_mode`` is not one of "mysql", "ansi", or "mssql".
    """
    if not isinstance(sql, str):
        raise exceptions.InvalidArgumentValue("identifier must be a str")

    if len(sql) == 0:
        raise exceptions.InvalidArgumentValue("identifier must be > 0 characters in length")

    # Allow-list validation: anything outside alphanumerics, underscores, spaces, and hyphens is
    # rejected, so the quoting characters of every supported mode (`, ", [, ]) can never appear
    # inside the identifier. The hyphen sits at the end of the class so it cannot be misread as
    # (or accidentally become) a character range; the accepted set is unchanged.
    if re.search(r"[^a-zA-Z0-9_ -]", sql):
        raise exceptions.InvalidArgumentValue(
            "identifier must contain only alphanumeric characters, spaces, underscores, or hyphens"
        )

    if sql_mode == "mysql":
        return f"`{sql}`"
    elif sql_mode == "ansi":
        return f'"{sql}"'
    elif sql_mode == "mssql":
        return f"[{sql}]"

    raise ValueError(f"Unknown SQL MODE: {sql_mode}")

awswrangler/data_api/rds.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""RDS Data API Connector."""
22
import datetime as dt
33
import logging
4-
import re
54
import time
65
import uuid
76
from decimal import Decimal
@@ -12,6 +11,7 @@
1211

1312
import awswrangler.pandas as pd
1413
from awswrangler import _data_types, _databases, _utils, exceptions
14+
from awswrangler._sql_utils import identifier
1515
from awswrangler.data_api import _connector
1616

1717
if TYPE_CHECKING:
@@ -228,19 +228,6 @@ def _get_statement_result(self, request_id: str) -> pd.DataFrame:
228228
return dataframe
229229

230230

231-
def escape_identifier(identifier: str, sql_mode: str = "mysql") -> str:
232-
"""Escape identifiers. Uses MySQL-compatible backticks by default."""
233-
if not isinstance(identifier, str):
234-
raise TypeError("SQL identifier must be a string")
235-
if re.search(r"\W", identifier):
236-
raise TypeError(f"SQL identifier contains invalid characters: {identifier}")
237-
if sql_mode == "mysql":
238-
return f"`{identifier}`"
239-
elif sql_mode == "ansi":
240-
return f'"{identifier}"'
241-
raise ValueError(f"Unknown SQL MODE: {sql_mode}")
242-
243-
244231
def connect(
245232
resource_arn: str, database: str, secret_arn: str = "", boto3_session: Optional[boto3.Session] = None, **kwargs: Any
246233
) -> RdsDataApi:
@@ -286,7 +273,7 @@ def read_sql_query(sql: str, con: RdsDataApi, database: Optional[str] = None) ->
286273

287274

288275
def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str, sql_mode: str) -> None:
    """Drop ``table`` if it exists; the table name is escaped to guard against SQL injection."""
    escaped_table = identifier(table, sql_mode=sql_mode)
    drop_sql = f"DROP TABLE IF EXISTS {escaped_table}"
    _logger.debug("Drop table query:\n%s", drop_sql)
    con.execute(drop_sql, database=database, transaction_id=transaction_id)
292279

@@ -329,8 +316,8 @@ def _create_table(
329316
varchar_lengths=varchar_lengths,
330317
converter_func=_data_types.pyarrow2mysql,
331318
)
332-
cols_str: str = "".join([f"{escape_identifier(k, sql_mode=sql_mode)} {v},\n" for k, v in mysql_types.items()])[:-2]
333-
sql = f"CREATE TABLE IF NOT EXISTS {escape_identifier(table, sql_mode=sql_mode)} (\n{cols_str})"
319+
cols_str: str = "".join([f"{identifier(k, sql_mode=sql_mode)} {v},\n" for k, v in mysql_types.items()])[:-2]
320+
sql = f"CREATE TABLE IF NOT EXISTS {identifier(table, sql_mode=sql_mode)} (\n{cols_str})"
334321

335322
_logger.debug("Create table query:\n%s", sql)
336323
con.execute(sql, database=database, transaction_id=transaction_id)
@@ -443,6 +430,8 @@ def to_sql(
443430
inserted into the database columns `col1` and `col3`.
444431
chunksize: int
445432
Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
433+
sql_mode: str
434+
"mysql" for default MySQL identifiers (backticks) or "ansi" for ANSI-compatible identifiers (double quotes).
446435
"""
447436
if df.empty is True:
448437
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
@@ -470,15 +459,13 @@ def to_sql(
470459
df = df.reset_index(level=df.index.names)
471460

472461
if use_column_names:
473-
insertion_columns = (
474-
"(" + ", ".join([f"{escape_identifier(col, sql_mode=sql_mode)}" for col in df.columns]) + ")"
475-
)
462+
insertion_columns = "(" + ", ".join([f"{identifier(col, sql_mode=sql_mode)}" for col in df.columns]) + ")"
476463
else:
477464
insertion_columns = ""
478465

479466
placeholders = ", ".join([f":{col}" for col in df.columns])
480467

481-
sql = f"INSERT INTO {escape_identifier(table, sql_mode=sql_mode)} {insertion_columns} VALUES ({placeholders})"
468+
sql = f"INSERT INTO {identifier(table, sql_mode=sql_mode)} {insertion_columns} VALUES ({placeholders})"
482469
parameter_sets = _generate_parameter_sets(df)
483470

484471
for parameter_sets_chunk in _utils.chunkify(parameter_sets, max_length=chunksize):

awswrangler/mysql.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from awswrangler import _data_types, _utils, exceptions
1313
from awswrangler import _databases as _db_utils
1414
from awswrangler._config import apply_configs
15+
from awswrangler._sql_utils import identifier
1516

1617
if TYPE_CHECKING:
1718
try:
@@ -37,15 +38,19 @@ def _validate_connection(con: "Connection[Any]") -> None:
3738

3839

3940
def _drop_table(cursor: "Cursor", schema: Optional[str], table: str) -> None:
    """Drop ``table`` (optionally schema-qualified) if it exists, escaping every identifier."""
    if schema:
        qualified = f"{identifier(schema)}.{identifier(table)}"
    else:
        qualified = identifier(table)
    sql = f"DROP TABLE IF EXISTS {qualified}"
    _logger.debug("Drop table query:\n%s", sql)
    cursor.execute(sql)
4445

4546

4647
def _does_table_exist(cursor: "Cursor", schema: Optional[str], table: str) -> bool:
    """Return True when ``table`` exists; parameterized queries keep the lookup injection-safe."""
    if schema:
        query = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s"
        params = [schema, table]
    else:
        query = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = %s"
        params = [table]
    cursor.execute(query, args=params)
    return len(cursor.fetchall()) > 0
5055

5156

@@ -71,8 +76,8 @@ def _create_table(
7176
varchar_lengths=varchar_lengths,
7277
converter_func=_data_types.pyarrow2mysql,
7378
)
74-
cols_str: str = "".join([f"`{k}` {v},\n" for k, v in mysql_types.items()])[:-2]
75-
sql = f"CREATE TABLE IF NOT EXISTS `{schema}`.`{table}` (\n{cols_str})"
79+
cols_str: str = "".join([f"{identifier(k)} {v},\n" for k, v in mysql_types.items()])[:-2]
80+
sql = f"CREATE TABLE IF NOT EXISTS {identifier(schema)}.{identifier(table)} (\n{cols_str})"
7681
_logger.debug("Create table query:\n%s", sql)
7782
cursor.execute(sql)
7883

@@ -419,7 +424,11 @@ def read_sql_table(
419424
>>> con.close()
420425
421426
"""
422-
sql: str = f"SELECT * FROM `{table}`" if schema is None else f"SELECT * FROM `{schema}`.`{table}`"
427+
sql: str = (
428+
f"SELECT * FROM {identifier(table)}"
429+
if schema is None
430+
else f"SELECT * FROM {identifier(schema)}.{identifier(table)}"
431+
)
423432
return read_sql_query(
424433
sql=sql,
425434
con=con,
@@ -551,29 +560,35 @@ def to_sql(
551560
upsert_str = ""
552561
ignore_str = " IGNORE" if mode == "ignore" else ""
553562
if use_column_names:
554-
insertion_columns = f"(`{'`, `'.join(df.columns)}`)"
563+
insertion_columns = f"({', '.join([identifier(col) for col in df.columns])})"
555564
if mode == "upsert_duplicate_key":
556-
upsert_columns = ", ".join(df.columns.map(lambda column: f"`{column}`=VALUES(`{column}`)"))
565+
upsert_columns = ", ".join(df.columns.map(lambda col: f"{identifier(col)}=VALUES({identifier(col)})"))
557566
upsert_str = f" ON DUPLICATE KEY UPDATE {upsert_columns}"
558567
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
559568
df=df, column_placeholders=column_placeholders, chunksize=chunksize
560569
)
561570
sql: str
562571
for placeholders, parameters in placeholder_parameter_pair_generator:
563572
if mode == "upsert_replace_into":
564-
sql = f"REPLACE INTO `{schema}`.`{table}` {insertion_columns} VALUES {placeholders}"
573+
sql = f"REPLACE INTO {identifier(schema)}.{identifier(table)} {insertion_columns} VALUES {placeholders}"
565574
else:
566-
sql = f"""INSERT{ignore_str} INTO `{schema}`.`{table}` {insertion_columns}
575+
sql = f"""INSERT{ignore_str} INTO {identifier(schema)}.{identifier(table)} {insertion_columns}
567576
VALUES {placeholders}{upsert_str}"""
568577
_logger.debug("sql: %s", sql)
569578
cursor.executemany(sql, (parameters,))
570579
con.commit()
571580
if mode == "upsert_distinct":
572581
temp_table = f"{table}_{uuid.uuid4().hex}"
573-
cursor.execute(f"CREATE TABLE `{schema}`.`{temp_table}` LIKE `{schema}`.`{table}`")
574-
cursor.execute(f"INSERT INTO `{schema}`.`{temp_table}` SELECT DISTINCT * FROM `{schema}`.`{table}`")
575-
cursor.execute(f"DROP TABLE IF EXISTS `{schema}`.`{table}`")
576-
cursor.execute(f"ALTER TABLE `{schema}`.`{temp_table}` RENAME TO `{table}`")
582+
cursor.execute(
583+
f"CREATE TABLE {identifier(schema)}.{identifier(temp_table)} LIKE {identifier(schema)}.{identifier(table)}"
584+
)
585+
cursor.execute(
586+
f"INSERT INTO {identifier(schema)}.{identifier(temp_table)} SELECT DISTINCT * FROM {identifier(schema)}.{identifier(table)}"
587+
)
588+
cursor.execute(f"DROP TABLE IF EXISTS {identifier(schema)}.{identifier(table)}")
589+
cursor.execute(
590+
f"ALTER TABLE {identifier(schema)}.{identifier(temp_table)} RENAME TO {identifier(table)}"
591+
)
577592
con.commit()
578593

579594
except Exception as ex:

awswrangler/oracle.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from awswrangler import _data_types, _utils, exceptions
2525
from awswrangler import _databases as _db_utils
2626
from awswrangler._config import apply_configs
27+
from awswrangler._sql_utils import identifier
2728

2829
__all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"]
2930

@@ -43,8 +44,8 @@ def _validate_connection(con: "oracledb.Connection") -> None:
4344

4445

4546
def _get_table_identifier(schema: Optional[str], table: str) -> str:
    """Build an escaped, optionally schema-qualified table identifier (ANSI double quotes)."""
    if schema:
        return f'{identifier(schema, sql_mode="ansi")}.{identifier(table, sql_mode="ansi")}'
    return identifier(table, sql_mode="ansi")
4950

5051

@@ -65,8 +66,14 @@ def _drop_table(cursor: "oracledb.Cursor", schema: Optional[str], table: str) ->
6566

6667

6768
def _does_table_exist(cursor: "oracledb.Cursor", schema: Optional[str], table: str) -> bool:
    """Return True when ``table`` exists; Oracle bind variables keep the lookup injection-safe."""
    if not schema:
        cursor.execute("SELECT * FROM ALL_TABLES WHERE TABLE_NAME = :tbl", tbl=table)
    else:
        cursor.execute(
            "SELECT * FROM ALL_TABLES WHERE OWNER = :db_schema AND TABLE_NAME = :db_table",
            db_schema=schema,
            db_table=table,
        )
    rows = cursor.fetchall()
    return len(rows) > 0
7178

7279

@@ -93,10 +100,10 @@ def _create_table(
93100
varchar_lengths=varchar_lengths,
94101
converter_func=_data_types.pyarrow2oracle,
95102
)
96-
cols_str: str = "".join([f'"{k}" {v},\n' for k, v in oracle_types.items()])[:-2]
103+
cols_str: str = "".join([f'{identifier(k, sql_mode="ansi")} {v},\n' for k, v in oracle_types.items()])[:-2]
97104

98105
if primary_keys:
99-
primary_keys_str = ", ".join([f'"{k}"' for k in primary_keys])
106+
primary_keys_str = ", ".join([f'{identifier(k, sql_mode="ansi")}' for k in primary_keys])
100107
else:
101108
primary_keys_str = None
102109

@@ -450,7 +457,7 @@ def _generate_insert_statement(
450457
column_placeholders: str = f"({', '.join([':' + str(i + 1) for i in range(len(df.columns))])})"
451458

452459
if use_column_names:
453-
insertion_columns = "(" + ", ".join('"' + column + '"' for column in df.columns) + ")"
460+
insertion_columns = "(" + ", ".join(identifier(column, sql_mode="ansi") for column in df.columns) + ")"
454461
else:
455462
insertion_columns = ""
456463

@@ -470,14 +477,19 @@ def _generate_upsert_statement(
470477

471478
non_primary_key_columns = [key for key in df.columns if key not in set(primary_keys)]
472479

473-
primary_keys_str = ", ".join([f'"{key}"' for key in primary_keys])
474-
columns_str = ", ".join([f'"{key}"' for key in non_primary_key_columns])
480+
primary_keys_str = ", ".join([f'{identifier(key, sql_mode="ansi")}' for key in primary_keys])
481+
columns_str = ", ".join([f'{identifier(key, sql_mode="ansi")}' for key in non_primary_key_columns])
475482

476483
column_placeholders: str = f"({', '.join([':' + str(i + 1) for i in range(len(df.columns))])})"
477484

478-
primary_key_condition_str = " AND ".join([f'"{key}" = :{i+1}' for i, key in enumerate(primary_keys)])
485+
primary_key_condition_str = " AND ".join(
486+
[f'{identifier(key, sql_mode="ansi")} = :{i+1}' for i, key in enumerate(primary_keys)]
487+
)
479488
assignment_str = ", ".join(
480-
[f'"{col}" = :{i + len(primary_keys) + 1}' for i, col in enumerate(non_primary_key_columns)]
489+
[
490+
f'{identifier(col, sql_mode="ansi")} = :{i + len(primary_keys) + 1}'
491+
for i, col in enumerate(non_primary_key_columns)
492+
]
481493
)
482494

483495
return f"""

awswrangler/postgresql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _create_table(
6969
varchar_lengths=varchar_lengths,
7070
converter_func=_data_types.pyarrow2postgresql,
7171
)
72-
cols_str: str = "".join([f'"{k}" {v},\n' for k, v in postgresql_types.items()])[:-2]
72+
cols_str: str = "".join([f"{pg8000_native.identifier(k)} {v},\n" for k, v in postgresql_types.items()])[:-2]
7373
sql = f"CREATE TABLE IF NOT EXISTS {pg8000_native.identifier(schema)}.{pg8000_native.identifier(table)} (\n{cols_str})"
7474
_logger.debug("Create table query:\n%s", sql)
7575
cursor.execute(sql)
@@ -584,7 +584,7 @@ def to_sql(
584584
if index:
585585
df.reset_index(level=df.index.names, inplace=True)
586586
column_placeholders: str = ", ".join(["%s"] * len(df.columns))
587-
column_names = [f'"{column}"' for column in df.columns]
587+
column_names = [pg8000_native.identifier(column) for column in df.columns]
588588
insertion_columns = ""
589589
upsert_str = ""
590590
if use_column_names:

awswrangler/sqlserver.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from awswrangler import _data_types, _utils, exceptions
2525
from awswrangler import _databases as _db_utils
2626
from awswrangler._config import apply_configs
27+
from awswrangler._sql_utils import identifier
2728

2829
__all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"]
2930

@@ -50,9 +51,10 @@ def _validate_connection(con: "pyodbc.Connection") -> None:
5051

5152

5253
def _get_table_identifier(schema: Optional[str], table: str) -> str:
    """Build an escaped, optionally schema-qualified table identifier (MSSQL square brackets)."""
    if schema:
        return f"{identifier(schema, sql_mode='mssql')}.{identifier(table, sql_mode='mssql')}"
    return identifier(table, sql_mode="mssql")
5658

5759

5860
def _drop_table(cursor: "Cursor", schema: Optional[str], table: str) -> None:
@@ -63,8 +65,12 @@ def _drop_table(cursor: "Cursor", schema: Optional[str], table: str) -> None:
6365

6466

6567
def _does_table_exist(cursor: "Cursor", schema: Optional[str], table: str) -> bool:
    """Return True when ``table`` exists; ODBC placeholders keep the lookup injection-safe."""
    if not schema:
        cursor.execute("SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = ?", table)
    else:
        cursor.execute(
            "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?", (schema, table)
        )
    return len(cursor.fetchall()) > 0
6975

7076

@@ -90,7 +96,7 @@ def _create_table(
9096
varchar_lengths=varchar_lengths,
9197
converter_func=_data_types.pyarrow2sqlserver,
9298
)
93-
cols_str: str = "".join([f'"{k}" {v},\n' for k, v in sqlserver_types.items()])[:-2]
99+
cols_str: str = "".join([f"{identifier(k, sql_mode='mssql')} {v},\n" for k, v in sqlserver_types.items()])[:-2]
94100
table_identifier = _get_table_identifier(schema, table)
95101
sql = (
96102
f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NULL BEGIN CREATE TABLE {table_identifier} (\n{cols_str}); END;"
@@ -529,7 +535,7 @@ def to_sql(
529535
table_identifier = _get_table_identifier(schema, table)
530536
insertion_columns = ""
531537
if use_column_names:
532-
quoted_columns = ", ".join(f'"{col}"' for col in df.columns)
538+
quoted_columns = ", ".join(f"{identifier(col, sql_mode='mssql')}" for col in df.columns)
533539
insertion_columns = f"({quoted_columns})"
534540
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
535541
df=df, column_placeholders=column_placeholders, chunksize=chunksize

0 commit comments

Comments
 (0)