Skip to content

Commit 5157184

Browse files
fix: Add parameterized queries where possible to address the risk of SQL injection (#2540)
* fix 3 instances of potential SQL injection * formatting fix * fix data API rds formatting
1 parent 67efc06 commit 5157184

File tree

3 files changed

+37
-12
lines changed

3 files changed

+37
-12
lines changed

awswrangler/data_api/rds.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,15 @@ def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str,
292292

293293

294294
def _does_table_exist(con: RdsDataApi, table: str, database: str, transaction_id: str) -> bool:
295-
res = con.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{table}'")
295+
res = con.execute(
296+
"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = :table",
297+
parameters=[
298+
{
299+
"name": "table",
300+
"value": {"stringValue": table},
301+
},
302+
],
303+
)
296304
return not res.empty
297305

298306

awswrangler/mysql.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import logging
55
import uuid
6-
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union, cast, overload
6+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union, cast, overload
77

88
import boto3
99
import pyarrow as pa
@@ -13,12 +13,21 @@
1313
from awswrangler import _databases as _db_utils
1414
from awswrangler._config import apply_configs
1515

16-
pymysql = _utils.import_optional_dependency("pymysql")
16+
if TYPE_CHECKING:
17+
try:
18+
import pymysql
19+
from pymysql.connections import Connection
20+
from pymysql.cursors import Cursor
21+
except ImportError:
22+
pass
23+
else:
24+
pymysql = _utils.import_optional_dependency("pymysql")
25+
1726

1827
_logger: logging.Logger = logging.getLogger(__name__)
1928

2029

21-
def _validate_connection(con: "pymysql.connections.Connection[Any]") -> None:
30+
def _validate_connection(con: "Connection[Any]") -> None:
2231
if not isinstance(con, pymysql.connections.Connection):
2332
raise exceptions.InvalidConnection(
2433
"Invalid 'conn' argument, please pass a "
@@ -27,16 +36,16 @@ def _validate_connection(con: "pymysql.connections.Connection[Any]") -> None:
2736
)
2837

2938

30-
def _drop_table(cursor: "pymysql.cursors.Cursor", schema: Optional[str], table: str) -> None:
39+
def _drop_table(cursor: "Cursor", schema: Optional[str], table: str) -> None:
3140
schema_str = f"`{schema}`." if schema else ""
3241
sql = f"DROP TABLE IF EXISTS {schema_str}`{table}`"
3342
_logger.debug("Drop table query:\n%s", sql)
3443
cursor.execute(sql)
3544

3645

37-
def _does_table_exist(cursor: "pymysql.cursors.Cursor", schema: Optional[str], table: str) -> bool:
46+
def _does_table_exist(cursor: "Cursor", schema: Optional[str], table: str) -> bool:
3847
schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else ""
39-
cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'")
48+
cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE {schema_str} TABLE_NAME = %s", args=[table])
4049
return len(cursor.fetchall()) > 0
4150

4251

@@ -164,7 +173,7 @@ def connect(
164173
password=attrs.password,
165174
port=attrs.port,
166175
host=attrs.host,
167-
ssl=attrs.ssl_context,
176+
ssl=attrs.ssl_context, # type: ignore[arg-type]
168177
read_timeout=read_timeout,
169178
write_timeout=write_timeout,
170179
connect_timeout=connect_timeout,

awswrangler/sqlserver.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import logging
55
from typing import (
6+
TYPE_CHECKING,
67
Any,
78
Callable,
89
Dict,
@@ -26,7 +27,14 @@
2627

2728
__all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"]
2829

29-
pyodbc = _utils.import_optional_dependency("pyodbc")
30+
if TYPE_CHECKING:
31+
try:
32+
import pyodbc
33+
from pyodbc import Cursor
34+
except ImportError:
35+
pass
36+
else:
37+
pyodbc = _utils.import_optional_dependency("pyodbc")
3038

3139
_logger: logging.Logger = logging.getLogger(__name__)
3240
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
@@ -47,16 +55,16 @@ def _get_table_identifier(schema: Optional[str], table: str) -> str:
4755
return table_identifier
4856

4957

50-
def _drop_table(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> None:
58+
def _drop_table(cursor: "Cursor", schema: Optional[str], table: str) -> None:
5159
table_identifier = _get_table_identifier(schema, table)
5260
sql = f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NOT NULL DROP TABLE {table_identifier}"
5361
_logger.debug("Drop table query:\n%s", sql)
5462
cursor.execute(sql)
5563

5664

57-
def _does_table_exist(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> bool:
65+
def _does_table_exist(cursor: "Cursor", schema: Optional[str], table: str) -> bool:
5866
schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else ""
59-
cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'")
67+
cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE {schema_str} TABLE_NAME = ?", table)
6068
return len(cursor.fetchall()) > 0
6169

6270

0 commit comments

Comments
 (0)