
Commit c703f45

Fix keep_files behaviour for failed redshift COPY. #505
1 parent: cc5618c

File tree

5 files changed: 61 additions, 50 deletions

awswrangler/redshift.py
Lines changed: 39 additions & 37 deletions

@@ -952,7 +952,7 @@ def unload(
         List of columns names that should be returned as pandas.Categorical.
         Recommended for memory restricted environments.
     keep_files : bool
-        Should keep the stage files?
+        Should keep stage files?
     chunked : Union[int, bool]
         If passed will split the data in a Iterable of DataFrames (Memory friendly).
         If `True` wrangler will iterate on the data by files in the most efficient way without guarantee of chunksize.
@@ -1290,7 +1290,7 @@ def copy( # pylint: disable=too-many-arguments
     varchar_lengths : Dict[str, int], optional
         Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
     keep_files : bool
-        Should keep the stage files?
+        Should keep stage files?
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -1334,38 +1334,40 @@ def copy( # pylint: disable=too-many-arguments
             f"The received S3 path ({path}) is not empty. "
             "Please, provide a different path or use wr.s3.delete_objects() to clean up the current one."
         )
-    s3.to_parquet(
-        df=df,
-        path=path,
-        index=index,
-        dataset=True,
-        mode="append",
-        dtype=dtype,
-        use_threads=use_threads,
-        boto3_session=session,
-        s3_additional_kwargs=s3_additional_kwargs,
-        max_rows_by_file=max_rows_by_file,
-    )
-    copy_from_files(
-        path=path,
-        con=con,
-        table=table,
-        schema=schema,
-        iam_role=iam_role,
-        aws_access_key_id=aws_access_key_id,
-        aws_secret_access_key=aws_secret_access_key,
-        aws_session_token=aws_session_token,
-        mode=mode,
-        diststyle=diststyle,
-        distkey=distkey,
-        sortstyle=sortstyle,
-        sortkey=sortkey,
-        primary_keys=primary_keys,
-        varchar_lengths_default=varchar_lengths_default,
-        varchar_lengths=varchar_lengths,
-        use_threads=use_threads,
-        boto3_session=session,
-        s3_additional_kwargs=s3_additional_kwargs,
-    )
-    if keep_files is False:
-        s3.delete_objects(path=path, use_threads=use_threads, boto3_session=session)
+    try:
+        s3.to_parquet(
+            df=df,
+            path=path,
+            index=index,
+            dataset=True,
+            mode="append",
+            dtype=dtype,
+            use_threads=use_threads,
+            boto3_session=session,
+            s3_additional_kwargs=s3_additional_kwargs,
+            max_rows_by_file=max_rows_by_file,
+        )
+        copy_from_files(
+            path=path,
+            con=con,
+            table=table,
+            schema=schema,
+            iam_role=iam_role,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
+            mode=mode,
+            diststyle=diststyle,
+            distkey=distkey,
+            sortstyle=sortstyle,
+            sortkey=sortkey,
+            primary_keys=primary_keys,
+            varchar_lengths_default=varchar_lengths_default,
+            varchar_lengths=varchar_lengths,
+            use_threads=use_threads,
+            boto3_session=session,
+            s3_additional_kwargs=s3_additional_kwargs,
+        )
+    finally:
+        if keep_files is False:
+            s3.delete_objects(path=path, use_threads=use_threads, boto3_session=session)
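The substance of the change above is that s3.to_parquet() and copy_from_files() now run inside a try block whose finally clause performs the keep_files cleanup, so a COPY that fails part-way no longer strands stage files in S3 when keep_files is False. Below is a minimal, self-contained sketch of that try/finally pattern; the in-memory "stage" set and the raised error are stand-ins for the S3 staging prefix and the Redshift COPY, not awswrangler code.

# Minimal sketch of the try/finally cleanup introduced by this commit.
# An in-memory set stands in for the S3 staging prefix; the raised error
# stands in for copy_from_files() failing (e.g. a VARCHAR length overflow).
stage: set = set()


def copy_sketch(keep_files: bool = False, fail_copy: bool = False) -> None:
    try:
        stage.add("stage/0.parquet")  # upload step, analogous to s3.to_parquet()
        if fail_copy:
            raise RuntimeError("COPY failed")  # analogous to copy_from_files() raising
    finally:
        # Runs on success and on failure alike, mirroring the new behaviour.
        if keep_files is False:
            stage.clear()  # analogous to s3.delete_objects()


try:
    copy_sketch(keep_files=False, fail_copy=True)
except RuntimeError:
    pass
assert len(stage) == 0  # cleanup happened even though the copy step failed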

tests/test_athena.py
Lines changed: 0 additions & 4 deletions

@@ -669,7 +669,6 @@ def test_glue_database():

     # Round 1 - Create Database
     glue_database_name = f"database_{get_time_str_with_random_suffix()}"
-    print(f"Database Name: {glue_database_name}")
     wr.catalog.create_database(name=glue_database_name, description="Database Description")
     databases = wr.catalog.get_databases()
     test_database_name = ""
@@ -684,7 +683,6 @@ def test_glue_database():
     assert test_database_description == "Database Description"

     # Round 2 - Delete Database
-    print(f"Glue Database Name: {glue_database_name}")
     wr.catalog.delete_database(name=glue_database_name)
     databases = wr.catalog.get_databases()
     test_database_name = ""
@@ -786,8 +784,6 @@ def test_describe_table(path, glue_database, glue_table):
 def test_athena_nan_inf(glue_database, ctas_approach, data_source):
     sql = "SELECT nan() AS nan, infinity() as inf, -infinity() as inf_n, 1.2 as regular"
     df = wr.athena.read_sql_query(sql, glue_database, ctas_approach, data_source=data_source)
-    print(df)
-    print(df.dtypes)
     assert df.shape == (1, 4)
     assert df.dtypes.to_list() == ["float64", "float64", "float64", "float64"]
     assert np.isnan(df.nan.iloc[0])

tests/test_athena_csv.py
Lines changed: 5 additions & 6 deletions

@@ -330,12 +330,11 @@ def test_athena_csv_types(path, glue_database, glue_table):
     wr.athena.repair_table(glue_table, glue_database)
     assert len(wr.catalog.get_csv_partitions(glue_database, glue_table)) == 3
     df2 = wr.athena.read_sql_table(glue_table, glue_database)
-    print(df2)
-    # assert len(df2.index) == 3
-    # assert len(df2.columns) == 10
-    # assert df2["id"].sum() == 6
-    # ensure_data_types_csv(df2)
-    # assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is True
+    assert len(df2.index) == 3
+    assert len(df2.columns) == 10
+    assert df2["id"].sum() == 6
+    ensure_data_types_csv(df2)
+    assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is True


 @pytest.mark.parametrize("use_threads", [True, False])

tests/test_athena_projection.py
Lines changed: 0 additions & 1 deletion

@@ -74,7 +74,6 @@ def test_to_parquet_projection_date(glue_database, glue_table, path):
         projection_ranges={"c1": "2020-01-01,2020-01-03", "c2": "2020-01-01 01:01:00,2020-01-01 01:01:03"},
     )
     df2 = wr.athena.read_sql_table(glue_table, glue_database)
-    print(df2)
     assert df.shape == df2.shape
     assert df.c0.sum() == df2.c0.sum()

tests/test_redshift.py
Lines changed: 17 additions & 2 deletions

@@ -8,6 +8,7 @@
 import pyarrow as pa
 import pytest
 import redshift_connector
+from redshift_connector.error import ProgrammingError

 import awswrangler as wr
 from awswrangler import _utils
@@ -888,6 +889,20 @@ def test_column_length(path, redshift_table, databases_parameters):
     )
     df2 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table}", con=con)
     con.close()
-    print(df.dtypes)
-    print(df2.dtypes)
     assert df2.equals(df)
+
+
+def test_failed_keep_files(path, redshift_table, databases_parameters):
+    df = pd.DataFrame({"c0": [1], "c1": ["foo"]}, dtype="string")
+    con = wr.redshift.connect("aws-data-wrangler-redshift")
+    with pytest.raises(ProgrammingError):
+        wr.redshift.copy(
+            df=df,
+            path=path,
+            con=con,
+            table=redshift_table,
+            schema="public",
+            iam_role=databases_parameters["redshift"]["role"],
+            varchar_lengths={"c1": 2},
+        )
+    assert len(wr.s3.list_objects(path)) == 0
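The new test triggers a failing COPY (varchar_lengths={"c1": 2} presumably overflows on the three-character "foo") and then asserts the staging path is empty, pinning down the fixed cleanup behaviour. For orientation only, here is a rough usage sketch of the keep_files parameter; the Glue connection name, S3 prefix, table, and IAM role are placeholders, not values from this repository.

# Rough usage sketch of the keep_files semantics the test above exercises.
# Connection name, S3 prefix, table, and IAM role are placeholders.
import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"c0": [1], "c1": ["foo"]}, dtype="string")
con = wr.redshift.connect("my-redshift-connection")  # placeholder connection name
try:
    wr.redshift.copy(
        df=df,
        path="s3://my-bucket/stage/",  # placeholder staging prefix
        con=con,
        table="my_table",              # placeholder table
        schema="public",
        iam_role="arn:aws:iam::123456789012:role/my-redshift-role",  # placeholder
        keep_files=False,  # with this fix, stage files are removed even if the COPY raises
    )
finally:
    con.close()

# After a failed COPY with keep_files=False, the staging prefix is expected to be empty:
# len(wr.s3.list_objects("s3://my-bucket/stage/")) == 0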
