Skip to content

Commit 335ce5a

Browse files
authored
Adding flag to skip Redshift transaction commit (#718)
* Adding flag to skip Redshift transaction commit
* Adding tests
1 parent f155691 commit 335ce5a

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

awswrangler/redshift.py

Lines changed: 11 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -685,7 +685,7 @@ def read_sql_table(
685685

686686

687687
@apply_configs
688-
def to_sql(
688+
def to_sql( # pylint: disable=too-many-locals
689689
df: pd.DataFrame,
690690
con: redshift_connector.Connection,
691691
table: str,
@@ -704,6 +704,7 @@ def to_sql(
704704
use_column_names: bool = False,
705705
lock: bool = False,
706706
chunksize: int = 200,
707+
commit_transaction: bool = True,
707708
) -> None:
708709
"""Write records stored in a DataFrame into Redshift.
709710
@@ -764,6 +765,8 @@ def to_sql(
764765
True to execute LOCK command inside the transaction to force serializable isolation.
765766
chunksize: int
766767
Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
768+
commit_transaction: bool
769+
Whether to commit the transaction. True by default.
767770
768771
Returns
769772
-------
@@ -829,7 +832,8 @@ def to_sql(
829832
if lock:
830833
_lock(cursor, [table], schema=schema)
831834
_upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
832-
con.commit()
835+
if commit_transaction:
836+
con.commit()
833837
except Exception as ex:
834838
con.rollback()
835839
_logger.error(ex)
@@ -1139,6 +1143,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
11391143
path_ignore_suffix: Optional[str] = None,
11401144
use_threads: bool = True,
11411145
lock: bool = False,
1146+
commit_transaction: bool = True,
11421147
boto3_session: Optional[boto3.Session] = None,
11431148
s3_additional_kwargs: Optional[Dict[str, str]] = None,
11441149
) -> None:
@@ -1227,6 +1232,8 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
12271232
If enabled os.cpu_count() will be used as the max number of threads.
12281233
lock : bool
12291234
True to execute LOCK command inside the transaction to force serializable isolation.
1235+
commit_transaction: bool
1236+
Whether to commit the transaction. True by default.
12301237
boto3_session : boto3.Session(), optional
12311238
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
12321239
s3_additional_kwargs:
@@ -1302,7 +1309,8 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
13021309
if lock:
13031310
_lock(cursor, [table], schema=schema)
13041311
_upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
1305-
con.commit()
1312+
if commit_transaction:
1313+
con.commit()
13061314
except Exception as ex:
13071315
con.rollback()
13081316
_logger.error(ex)

tests/test_redshift.py

Lines changed: 31 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -914,3 +914,34 @@ def test_dfs_are_equal_for_different_chunksizes(redshift_table, redshift_con, ch
914914
df["c1"] = df["c1"].astype("string")
915915

916916
assert df.equals(df2)
917+
918+
919+
def test_to_sql_multi_transaction(redshift_table, redshift_con):
920+
df = pd.DataFrame({"id": list((range(10))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(10)])})
921+
df2 = pd.DataFrame({"id": list((range(10, 15))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(5)])})
922+
923+
wr.redshift.to_sql(
924+
df=df,
925+
con=redshift_con,
926+
schema="public",
927+
table=redshift_table,
928+
mode="overwrite",
929+
index=False,
930+
primary_keys=["id"],
931+
commit_transaction=False, # Not committing
932+
)
933+
934+
wr.redshift.to_sql(
935+
df=df2,
936+
con=redshift_con,
937+
schema="public",
938+
table=redshift_table,
939+
mode="upsert",
940+
index=False,
941+
primary_keys=["id"],
942+
commit_transaction=False, # Not committing
943+
)
944+
redshift_con.commit()
945+
df3 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table} ORDER BY id", con=redshift_con)
946+
assert len(df.index) + len(df2.index) == len(df3.index)
947+
assert len(df.columns) == len(df3.columns)

0 commit comments

Comments (0)