Commit 5667205

fix: Add validation for table and schema params for Redshift (#2551)
1 parent 703e896 commit 5667205

2 files changed: +38 -24 lines

awswrangler/redshift/_read.py (9 additions & 2 deletions)

@@ -7,7 +7,7 @@
 
 import awswrangler.pandas as pd
 from awswrangler import _databases as _db_utils
-from awswrangler import _utils, exceptions, s3
+from awswrangler import _sql_utils, _utils, exceptions, s3
 from awswrangler._config import apply_configs
 from awswrangler._distributed import EngineEnum, engine
 
@@ -19,6 +19,10 @@
 _logger: logging.Logger = logging.getLogger(__name__)
 
 
+def _identifier(sql: str) -> str:
+    return _sql_utils.identifier(sql, sql_mode="ansi")
+
+
 def _read_parquet_iterator(
     path: str,
     keep_files: bool,
@@ -199,7 +203,10 @@ def read_sql_table(
     >>> con.close()
 
     """
-    sql: str = f'SELECT * FROM "{table}"' if schema is None else f'SELECT * FROM "{schema}"."{table}"'
+    if schema is None:
+        sql = f"SELECT * FROM {_identifier(table)}"
+    else:
+        sql = f"SELECT * FROM {_identifier(schema)}.{_identifier(table)}"
     return read_sql_query(
         sql=sql,
         con=con,
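The new `_identifier` helper routes every table and schema name through `_sql_utils.identifier(..., sql_mode="ansi")` before it reaches an f-string. The diff does not show that function's body; the sketch below is only a rough approximation of what ANSI-mode identifier validation could look like. The name `ansi_identifier`, the regex, and the ValueError are illustrative assumptions, not awswrangler's actual implementation.

import re

def ansi_identifier(name: str) -> str:
    """Illustrative stand-in for an ANSI-mode identifier check (assumption)."""
    # Accept only plain identifiers: letters, digits and underscores,
    # starting with a letter or underscore. Anything else (quotes,
    # semicolons, whitespace) is rejected before any SQL is built.
    if not name or not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name):
        raise ValueError(f"Invalid SQL identifier: {name!r}")
    # ANSI mode wraps the validated name in double quotes.
    return f'"{name}"'

print(ansi_identifier("my_table"))             # "my_table"
# ansi_identifier('t"; DROP TABLE users; --')  # raises ValueError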

awswrangler/redshift/_utils.py (29 additions & 22 deletions)

@@ -9,7 +9,7 @@
 import botocore
 import pandas as pd
 
-from awswrangler import _data_types, _utils, exceptions, s3
+from awswrangler import _data_types, _sql_utils, _utils, exceptions, s3
 
 redshift_connector = _utils.import_optional_dependency("redshift_connector")
 
@@ -20,6 +20,10 @@
 _RS_SORTSTYLES: List[str] = ["COMPOUND", "INTERLEAVED"]
 
 
+def _identifier(sql: str) -> str:
+    return _sql_utils.identifier(sql, sql_mode="ansi")
+
+
 def _make_s3_auth_string(
     aws_access_key_id: Optional[str] = None,
     aws_secret_access_key: Optional[str] = None,
@@ -66,15 +70,19 @@ def _drop_table(cursor: "redshift_connector.Cursor", schema: Optional[str], tabl
 
 
 def _truncate_table(cursor: "redshift_connector.Cursor", schema: Optional[str], table: str) -> None:
-    schema_str = f'"{schema}".' if schema else ""
-    sql = f'TRUNCATE TABLE {schema_str}"{table}"'
+    if schema:
+        sql = f"TRUNCATE TABLE {_identifier(schema)}.{_identifier(table)}"
+    else:
+        sql = f"TRUNCATE TABLE {_identifier(table)}"
     _logger.debug("Executing truncate table query:\n%s", sql)
     cursor.execute(sql)
 
 
 def _delete_all(cursor: "redshift_connector.Cursor", schema: Optional[str], table: str) -> None:
-    schema_str = f'"{schema}".' if schema else ""
-    sql = f'DELETE FROM {schema_str}"{table}"'
+    if schema:
+        sql = f"DELETE FROM {_identifier(schema)}.{_identifier(table)}"
+    else:
+        sql = f"DELETE FROM {_identifier(table)}"
     _logger.debug("Executing delete query:\n%s", sql)
     cursor.execute(sql)
 
@@ -116,8 +124,9 @@ def _lock(
     table_names: List[str],
     schema: Optional[str] = None,
 ) -> None:
-    fmt = '"{schema}"."{table}"' if schema else '"{table}"'
-    tables = ", ".join([fmt.format(schema=schema, table=table) for table in table_names])
+    tables = ", ".join(
+        [(f"{_identifier(schema)}.{_identifier(table)}" if schema else _identifier(table)) for table in table_names]
+    )
     sql: str = f"LOCK {tables};\n"
     _logger.debug("Executing lock query:\n%s", sql)
     cursor.execute(sql)
@@ -137,32 +146,30 @@ def _upsert(
     _logger.debug("primary_keys: %s", primary_keys)
     if not primary_keys:
         raise exceptions.InvalidRedshiftPrimaryKeys()
-    equals_clause: str = f'"{table}".%s = "{temp_table}".%s'
+    equals_clause: str = f"{_identifier(table)}.%s = {_identifier(temp_table)}.%s"
     join_clause: str = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
     if precombine_key:
-        delete_from_target_filter: str = f"AND {table}.{precombine_key} <= {temp_table}.{precombine_key}"
-        delete_from_temp_filter: str = f"AND {table}.{precombine_key} > {temp_table}.{precombine_key}"
-        target_del_sql: str = (
-            f'DELETE FROM "{schema}"."{table}" USING "{temp_table}" WHERE {join_clause} {delete_from_target_filter}'
+        delete_from_target_filter: str = (
+            f"AND {_identifier(table)}.{precombine_key} <= {_identifier(temp_table)}.{precombine_key}"
+        )
+        delete_from_temp_filter: str = (
+            f"AND {_identifier(table)}.{precombine_key} > {_identifier(temp_table)}.{precombine_key}"
         )
+        target_del_sql: str = f"DELETE FROM {_identifier(schema)}.{_identifier(table)} USING {_identifier(temp_table)} WHERE {join_clause} {delete_from_target_filter}"
         _logger.debug("Executing delete query:\n%s", target_del_sql)
         cursor.execute(target_del_sql)
-        source_del_sql: str = (
-            f'DELETE FROM "{temp_table}" USING "{schema}"."{table}" WHERE {join_clause} {delete_from_temp_filter}'
-        )
+        source_del_sql: str = f"DELETE FROM {_identifier(temp_table)} USING {_identifier(schema)}.{_identifier(table)} WHERE {join_clause} {delete_from_temp_filter}"
         _logger.debug("Executing delete query:\n%s", source_del_sql)
         cursor.execute(source_del_sql)
     else:
-        sql: str = f'DELETE FROM "{schema}"."{table}" USING "{temp_table}" WHERE {join_clause}'
+        sql: str = f"DELETE FROM {_identifier(schema)}.{_identifier(table)} USING {_identifier(temp_table)} WHERE {join_clause}"
         _logger.debug("Executing delete query:\n%s", sql)
         cursor.execute(sql)
     if column_names:
         column_names_str = ",".join(column_names)
-        insert_sql = (
-            f'INSERT INTO "{schema}"."{table}"({column_names_str}) SELECT {column_names_str} FROM "{temp_table}"'
-        )
+        insert_sql = f"INSERT INTO {_identifier(schema)}.{_identifier(table)}({column_names_str}) SELECT {column_names_str} FROM {_identifier(temp_table)}"
     else:
-        insert_sql = f'INSERT INTO "{schema}"."{table}" SELECT * FROM "{temp_table}"'
+        insert_sql = f"INSERT INTO {_identifier(schema)}.{_identifier(table)} SELECT * FROM {_identifier(temp_table)}"
     _logger.debug("Executing insert query:\n%s", insert_sql)
     cursor.execute(insert_sql)
     _drop_table(cursor=cursor, schema=schema, table=temp_table)
@@ -299,7 +306,7 @@ def _create_table(  # pylint: disable=too-many-locals,too-many-arguments,too-man
     if mode == "upsert":
         guid: str = uuid.uuid4().hex
         temp_table: str = f"temp_redshift_{guid}"
-        sql: str = f'CREATE TEMPORARY TABLE {temp_table} (LIKE "{schema}"."{table}")'
+        sql: str = f"CREATE TEMPORARY TABLE {temp_table} (LIKE {_identifier(schema)}.{_identifier(table)})"
        _logger.debug("Executing create temporary table query:\n%s", sql)
         cursor.execute(sql)
         return temp_table, None
@@ -355,7 +362,7 @@ def _create_table(  # pylint: disable=too-many-locals,too-many-arguments,too-man
     distkey_str: str = f"\nDISTKEY({distkey})" if distkey and diststyle == "KEY" else ""
     sortkey_str: str = f"\n{sortstyle} SORTKEY({','.join(sortkey)})" if sortkey else ""
     sql = (
-        f'CREATE TABLE IF NOT EXISTS "{schema}"."{table}" (\n'
+        f"CREATE TABLE IF NOT EXISTS {_identifier(schema)}.{_identifier(table)} (\n"
         f"{cols_str}"
         f"{primary_keys_str}"
         f")\nDISTSTYLE {diststyle}"

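From a caller's point of view the behavior for well-formed names is unchanged; the difference is that table and schema values are validated as identifiers before being spliced into the statement. A minimal usage sketch, assuming a Glue Catalog connection named "aws-sdk-pandas-redshift" (the connection name and the crafted value in the comment are illustrative):

import awswrangler as wr

con = wr.redshift.connect("aws-sdk-pandas-redshift")  # illustrative connection name

# Well-formed identifiers work as before, producing
# SELECT * FROM "public"."my_table" under the hood.
df = wr.redshift.read_sql_table(table="my_table", schema="public", con=con)

# A crafted value such as 'my_table"; DROP TABLE users; --' is no longer
# pasted verbatim into the query; the identifier check rejects it instead.
con.close()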