
Commit 924b0bb

Add Redshift tests
1 parent d4b27c6 commit 924b0bb

File tree: 1 file changed (+74, -0 lines)


testing/test_awswrangler/test_db.py

Lines changed: 74 additions & 0 deletions
@@ -76,6 +76,11 @@ def external_schema(cloudformation_outputs, parameters, glue_database):
     yield "aws_data_wrangler_external"


+@pytest.fixture(scope="module")
+def kms_key_id(cloudformation_outputs):
+    yield cloudformation_outputs["KmsKeyArn"].split("/", 1)[1]
+
+
 @pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
 def test_sql(parameters, db_type):
     df = get_df()
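
For context, the new kms_key_id fixture keeps only the key-ID portion of the KMS key ARN that the test CloudFormation stack exports as "KmsKeyArn". A minimal sketch of the same parsing, using a made-up ARN:

# Hypothetical ARN; the real value comes from cloudformation_outputs["KmsKeyArn"].
arn = "arn:aws:kms:us-east-1:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab"

# split("/", 1) splits on the first "/" only, so [1] is everything after it:
# the bare key ID that the tests hand to the unload calls as kms_key_id.
key_id = arn.split("/", 1)[1]
assert key_id == "1234abcd-12ab-34cd-56ef-1234567890ab"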
@@ -386,3 +391,72 @@ def test_redshift_category(bucket, parameters):
     for df2 in dfs:
         ensure_data_types_category(df2)
     wr.s3.delete_objects(path=path)
+
+
+def test_redshift_unload_extras(bucket, parameters, kms_key_id):
+    table = "test_redshift_unload_extras"
+    schema = parameters["redshift"]["schema"]
+    path = f"s3://{bucket}/{table}/"
+    wr.s3.delete_objects(path=path)
+    engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
+    df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})
+    wr.db.to_sql(df=df, con=engine, name=table, schema=schema, if_exists="replace", index=False)
+    paths = wr.db.unload_redshift_to_files(
+        sql=f"SELECT * FROM {schema}.{table}",
+        path=path,
+        con=engine,
+        iam_role=parameters["redshift"]["role"],
+        region=wr.s3.get_bucket_region(bucket),
+        max_file_size=5.0,
+        kms_key_id=kms_key_id,
+        partition_cols=["name"],
+    )
+    wr.s3.wait_objects_exist(paths=paths)
+    df = wr.s3.read_parquet(path=path, dataset=True)
+    assert len(df.index) == 2
+    assert len(df.columns) == 2
+    wr.s3.delete_objects(path=path)
+    df = wr.db.unload_redshift(
+        sql=f"SELECT * FROM {schema}.{table}",
+        con=engine,
+        iam_role=parameters["redshift"]["role"],
+        path=path,
+        keep_files=False,
+        region=wr.s3.get_bucket_region(bucket),
+        max_file_size=5.0,
+        kms_key_id=kms_key_id,
+    )
+    assert len(df.index) == 2
+    assert len(df.columns) == 2
+    wr.s3.delete_objects(path=path)
+
+
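test_redshift_unload_extras above layers the optional arguments (region, max_file_size, kms_key_id, partition_cols) onto the basic unload round trip; they correspond to Redshift UNLOAD options such as REGION, MAXFILESIZE, KMS_KEY_ID, and PARTITION BY. Stripped of the pytest fixtures, the round trip looks roughly like the sketch below; the connection name, role ARN, and bucket are placeholders, not values from the commit:

import awswrangler as wr
import pandas as pd

# Assumed setup: a Glue connection named "my-redshift", an IAM role the
# cluster can assume for UNLOAD, and a staging prefix in a bucket you own.
engine = wr.catalog.get_engine(connection="my-redshift")
wr.db.to_sql(
    df=pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]}),
    con=engine,
    name="my_table",
    schema="public",
    if_exists="replace",
    index=False,
)
df = wr.db.unload_redshift(
    sql="SELECT * FROM public.my_table",
    con=engine,
    iam_role="arn:aws:iam::123456789012:role/my-redshift-role",
    path="s3://my-bucket/staging/",
    keep_files=False,  # drop the staged Parquet files once they are read back
)
assert len(df.index) == 2
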
+@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
+def test_to_sql_cast(parameters, db_type):
+    table = "test_to_sql_cast"
+    schema = parameters[db_type]["schema"]
+    df = pd.DataFrame(
+        {
+            "col": [
+                "".join([str(i)[-1] for i in range(1_024)]),
+                "".join([str(i)[-1] for i in range(1_024)]),
+                "".join([str(i)[-1] for i in range(1_024)]),
+            ]
+        },
+        dtype="string",
+    )
+    engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-{db_type}")
+    wr.db.to_sql(
+        df=df,
+        con=engine,
+        name=table,
+        schema=schema,
+        if_exists="replace",
+        index=False,
+        index_label=None,
+        chunksize=None,
+        method=None,
+        dtype={"col": sqlalchemy.types.VARCHAR(length=1_024)},
+    )
+    df2 = wr.db.read_sql_query(sql=f"SELECT * FROM {schema}.{table}", con=engine)
+    assert df.equals(df2)
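
A note on the fixture data in test_to_sql_cast: each value is built to exactly fill the declared column, and the dtype argument passes the SQLAlchemy type straight through to the generated CREATE TABLE. A quick sketch of what the generator expression produces:

# Last digit of each i in range(1_024): "0123456789" repeating, so the value
# is exactly 1,024 characters, the full width of the VARCHAR(length=1_024)
# column declared via dtype.
s = "".join([str(i)[-1] for i in range(1_024)])
assert len(s) == 1_024
assert s[:12] == "012345678901"

Reading the table back with wr.db.read_sql_query and comparing via df.equals then checks that none of the three engines truncated or re-typed the column.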
