Skip to content

Commit 2da3134

Browse files
authored
Add Redshift overwrite methods (#676)
1 parent dff4aa6 commit 2da3134

File tree

2 files changed

+83
-8
lines changed

2 files changed

+83
-8
lines changed

awswrangler/redshift.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Amazon Redshift Module."""
2+
# pylint: disable=too-many-lines
23

34
import logging
45
import uuid
@@ -30,13 +31,34 @@ def _validate_connection(con: redshift_connector.Connection) -> None:
3031
)
3132

3233

33-
def _drop_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None:
34+
def _begin_transaction(cursor: redshift_connector.Cursor) -> None:
35+
sql = "BEGIN TRANSACTION"
36+
_logger.debug("Begin transaction query:\n%s", sql)
37+
cursor.execute(sql)
38+
39+
40+
def _drop_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str, cascade: bool = False) -> None:
3441
schema_str = f'"{schema}".' if schema else ""
35-
sql = f'DROP TABLE IF EXISTS {schema_str}"{table}"'
42+
cascade_str = " CASCADE" if cascade else ""
43+
sql = f'DROP TABLE IF EXISTS {schema_str}"{table}"' f"{cascade_str}"
3644
_logger.debug("Drop table query:\n%s", sql)
3745
cursor.execute(sql)
3846

3947

48+
def _truncate_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None:
49+
schema_str = f'"{schema}".' if schema else ""
50+
sql = f'TRUNCATE TABLE {schema_str}"{table}"'
51+
_logger.debug("Truncate table query:\n%s", sql)
52+
cursor.execute(sql)
53+
54+
55+
def _delete_all(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None:
56+
schema_str = f'"{schema}".' if schema else ""
57+
sql = f'DELETE FROM {schema_str}"{table}"'
58+
_logger.debug("Delete query:\n%s", sql)
59+
cursor.execute(sql)
60+
61+
4062
def _get_primary_keys(cursor: redshift_connector.Cursor, schema: str, table: str) -> List[str]:
4163
cursor.execute(f"SELECT indexdef FROM pg_indexes WHERE schemaname = '{schema}' AND tablename = '{table}'")
4264
result: str = cursor.fetchall()[0][0]
@@ -214,13 +236,15 @@ def _redshift_types_from_path(
214236
return redshift_types
215237

216238

217-
def _create_table(
239+
def _create_table( # pylint: disable=too-many-locals,too-many-arguments
218240
df: Optional[pd.DataFrame],
219241
path: Optional[Union[str, List[str]]],
242+
con: redshift_connector.Connection,
220243
cursor: redshift_connector.Cursor,
221244
table: str,
222245
schema: str,
223246
mode: str,
247+
overwrite_method: str,
224248
index: bool,
225249
dtype: Optional[Dict[str, str]],
226250
diststyle: str,
@@ -238,7 +262,25 @@ def _create_table(
238262
s3_additional_kwargs: Optional[Dict[str, str]] = None,
239263
) -> Tuple[str, Optional[str]]:
240264
if mode == "overwrite":
241-
_drop_table(cursor=cursor, schema=schema, table=table)
265+
if overwrite_method == "truncate":
266+
try:
267+
# Truncate commits current transaction, if successful.
268+
# Fast, but not atomic.
269+
_truncate_table(cursor=cursor, schema=schema, table=table)
270+
except redshift_connector.error.ProgrammingError as e:
271+
# Caught "relation does not exist".
272+
if e.args[0]["C"] != "42P01": # pylint: disable=invalid-sequence-index
273+
raise e
274+
_logger.debug(str(e))
275+
con.rollback()
276+
_begin_transaction(cursor=cursor)
277+
elif overwrite_method == "delete":
278+
if _does_table_exist(cursor=cursor, schema=schema, table=table):
279+
# Atomic, but slow.
280+
_delete_all(cursor=cursor, schema=schema, table=table)
281+
else:
282+
# Fast, atomic, but either fails if there are any dependent views or, in cascade mode, deletes them.
283+
_drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade"))
242284
elif _does_table_exist(cursor=cursor, schema=schema, table=table) is True:
243285
if mode == "upsert":
244286
guid: str = uuid.uuid4().hex
@@ -649,6 +691,7 @@ def to_sql(
649691
table: str,
650692
schema: str,
651693
mode: str = "append",
694+
overwrite_method: str = "drop",
652695
index: bool = False,
653696
dtype: Optional[Dict[str, str]] = None,
654697
diststyle: str = "AUTO",
@@ -682,6 +725,14 @@ def to_sql(
682725
Schema name
683726
mode : str
684727
Append, overwrite or upsert.
728+
overwrite_method : str
729+
Drop, cascade, truncate, or delete. Only applicable in overwrite mode.
730+
731+
"drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it.
732+
"cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it.
733+
"truncate" - ``TRUNCATE ...`` - truncates the table, but immediately commits current
734+
transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic.
735+
"delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods.
685736
index : bool
686737
True to store the DataFrame index as a column in the table,
687738
otherwise False to ignore it.
@@ -744,10 +795,12 @@ def to_sql(
744795
created_table, created_schema = _create_table(
745796
df=df,
746797
path=None,
798+
con=con,
747799
cursor=cursor,
748800
table=table,
749801
schema=schema,
750802
mode=mode,
803+
overwrite_method=overwrite_method,
751804
index=index,
752805
dtype=dtype,
753806
diststyle=diststyle,
@@ -1073,6 +1126,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
10731126
aws_session_token: Optional[str] = None,
10741127
parquet_infer_sampling: float = 1.0,
10751128
mode: str = "append",
1129+
overwrite_method: str = "drop",
10761130
diststyle: str = "AUTO",
10771131
distkey: Optional[str] = None,
10781132
sortstyle: str = "COMPOUND",
@@ -1130,6 +1184,14 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
11301184
The lower, the faster.
11311185
mode : str
11321186
Append, overwrite or upsert.
1187+
overwrite_method : str
1188+
Drop, cascade, truncate, or delete. Only applicable in overwrite mode.
1189+
1190+
"drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it.
1191+
"cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it.
1192+
"truncate" - ``TRUNCATE ...`` - truncates the table, but immediately commits current
1193+
transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic.
1194+
"delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods.
11331195
diststyle : str
11341196
Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
11351197
https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
@@ -1202,10 +1264,12 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
12021264
parquet_infer_sampling=parquet_infer_sampling,
12031265
path_suffix=path_suffix,
12041266
path_ignore_suffix=path_ignore_suffix,
1267+
con=con,
12051268
cursor=cursor,
12061269
table=table,
12071270
schema=schema,
12081271
mode=mode,
1272+
overwrite_method=overwrite_method,
12091273
diststyle=diststyle,
12101274
sortstyle=sortstyle,
12111275
distkey=distkey,
@@ -1260,6 +1324,7 @@ def copy( # pylint: disable=too-many-arguments
12601324
index: bool = False,
12611325
dtype: Optional[Dict[str, str]] = None,
12621326
mode: str = "append",
1327+
overwrite_method: str = "drop",
12631328
diststyle: str = "AUTO",
12641329
distkey: Optional[str] = None,
12651330
sortstyle: str = "COMPOUND",
@@ -1327,9 +1392,17 @@ def copy( # pylint: disable=too-many-arguments
13271392
Useful when you have columns with undetermined or mixed data types.
13281393
Only takes effect if dataset=True.
13291394
(e.g. {'col name': 'bigint', 'col2 name': 'int'})
1330-
mode : str
1395+
mode : str
13311396
Append, overwrite or upsert.
1332-
diststyle : str
1397+
overwrite_method : str
1398+
Drop, cascade, truncate, or delete. Only applicable in overwrite mode.
1399+
1400+
"drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it.
1401+
"cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it.
1402+
"truncate" - ``TRUNCATE ...`` - truncates the table, but immediately commits current
1403+
transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic.
1404+
"delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods.
1405+
diststyle : str
13331406
Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
13341407
https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
13351408
distkey : str, optional
@@ -1416,6 +1489,7 @@ def copy( # pylint: disable=too-many-arguments
14161489
aws_secret_access_key=aws_secret_access_key,
14171490
aws_session_token=aws_session_token,
14181491
mode=mode,
1492+
overwrite_method=overwrite_method,
14191493
diststyle=diststyle,
14201494
distkey=distkey,
14211495
sortstyle=sortstyle,

tests/test_redshift.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,10 @@ def test_read_sql_query_simple(databases_parameters):
4242
assert df.shape == (1, 1)
4343

4444

45-
def test_to_sql_simple(redshift_table, redshift_con):
45+
@pytest.mark.parametrize("overwrite_method", [None, "drop", "cascade", "truncate", "delete"])
46+
def test_to_sql_simple(redshift_table, redshift_con, overwrite_method):
4647
df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
47-
wr.redshift.to_sql(df, redshift_con, redshift_table, "public", "overwrite", True)
48+
wr.redshift.to_sql(df, redshift_con, redshift_table, "public", "overwrite", overwrite_method, True)
4849

4950

5051
def test_sql_types(redshift_table, redshift_con):

0 commit comments

Comments
 (0)