Commit e93fbfa

Merge pull request #214 from awslabs/redshift-unload
Add kms_key_id, max_file_size and region to Redshift Unload
2 parents: a54a578 + ca133a0

6 files changed: 169 additions & 31 deletions


awswrangler/_data_types.py

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ def sqlalchemy_types_from_pandas(
     df: pd.DataFrame, db_type: str, dtype: Optional[Dict[str, VisitableType]] = None
 ) -> Dict[str, VisitableType]:
     """Extract the related SQLAlchemy data types from any Pandas DataFrame."""
-    casts: Dict[str, VisitableType] = dtype if dtype else {}
+    casts: Dict[str, VisitableType] = dtype if dtype is not None else {}
     pa_columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(
         df=df, index=False, ignore_cols=list(casts.keys())
     )

awswrangler/db.py

Lines changed: 62 additions & 5 deletions
@@ -2,6 +2,7 @@

 import json
 import logging
+import time
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 from urllib.parse import quote_plus

@@ -91,7 +92,16 @@ def to_sql(df: pd.DataFrame, con: sqlalchemy.engine.Engine, **pandas_kwargs) ->
     )
     pandas_kwargs["dtype"] = dtypes
     pandas_kwargs["con"] = con
-    df.to_sql(**pandas_kwargs)
+    max_attempts: int = 3
+    for attempt in range(max_attempts):
+        try:
+            df.to_sql(**pandas_kwargs)
+        except sqlalchemy.exc.InternalError as ex:  # pragma: no cover
+            if attempt == (max_attempts - 1):
+                raise ex
+            time.sleep(1)
+        else:
+            break


 def read_sql_query(
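
The hunk above wraps df.to_sql() in a bounded retry so that a transient InternalError does not fail the whole write. A minimal standalone sketch of the same pattern, outside of awswrangler (the engine URL and table name below are placeholders, not values from this repository):

import time

import pandas as pd
import sqlalchemy


def write_with_retry(df: pd.DataFrame, engine: sqlalchemy.engine.Engine, name: str, max_attempts: int = 3) -> None:
    """Retry DataFrame.to_sql on transient sqlalchemy.exc.InternalError, mirroring the patch above."""
    for attempt in range(max_attempts):
        try:
            df.to_sql(name=name, con=engine, if_exists="replace", index=False)
        except sqlalchemy.exc.InternalError:
            if attempt == max_attempts - 1:
                raise  # attempts exhausted, surface the error
            time.sleep(1)  # brief pause before retrying
        else:
            break  # write succeeded


# engine = sqlalchemy.create_engine("postgresql+psycopg2://user:password@host:5432/db")  # placeholder URL
# write_with_retry(pd.DataFrame({"id": [1, 2]}), engine, "my_table")
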
@@ -887,6 +897,9 @@ def unload_redshift(
     path: str,
     con: sqlalchemy.engine.Engine,
     iam_role: str,
+    region: Optional[str] = None,
+    max_file_size: Optional[float] = None,
+    kms_key_id: Optional[str] = None,
     categories: List[str] = None,
     chunked: Union[bool, int] = False,
     keep_files: bool = False,
@@ -937,6 +950,19 @@ def unload_redshift(
         wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()
     iam_role : str
         AWS IAM role with the related permissions.
+    region : str, optional
+        Specifies the AWS Region where the target Amazon S3 bucket is located.
+        REGION is required for UNLOAD to an Amazon S3 bucket that isn't in the
+        same AWS Region as the Amazon Redshift cluster. By default, UNLOAD
+        assumes that the target Amazon S3 bucket is located in the same AWS
+        Region as the Amazon Redshift cluster.
+    max_file_size : float, optional
+        Specifies the maximum size (MB) of files that UNLOAD creates in Amazon S3.
+        Specify a decimal value between 5.0 MB and 6200.0 MB. If None, the default
+        maximum file size is 6200.0 MB.
+    kms_key_id : str, optional
+        Specifies the key ID for an AWS Key Management Service (AWS KMS) key to be
+        used to encrypt data files on Amazon S3.
     categories: List[str], optional
         List of columns names that should be returned as pandas.Categorical.
         Recommended for memory restricted environments.
@@ -973,7 +999,15 @@ def unload_redshift(
     """
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
     paths: List[str] = unload_redshift_to_files(
-        sql=sql, path=path, con=con, iam_role=iam_role, use_threads=use_threads, boto3_session=session
+        sql=sql,
+        path=path,
+        con=con,
+        iam_role=iam_role,
+        region=region,
+        max_file_size=max_file_size,
+        kms_key_id=kms_key_id,
+        use_threads=use_threads,
+        boto3_session=session,
     )
     s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
     if chunked is False:
@@ -1032,6 +1066,9 @@ def unload_redshift_to_files(
     path: str,
     con: sqlalchemy.engine.Engine,
     iam_role: str,
+    region: Optional[str] = None,
+    max_file_size: Optional[float] = None,
+    kms_key_id: Optional[str] = None,
     use_threads: bool = True,
     manifest: bool = False,
     partition_cols: Optional[List] = None,
@@ -1056,6 +1093,19 @@ def unload_redshift_to_files(
         wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()
     iam_role : str
         AWS IAM role with the related permissions.
+    region : str, optional
+        Specifies the AWS Region where the target Amazon S3 bucket is located.
+        REGION is required for UNLOAD to an Amazon S3 bucket that isn't in the
+        same AWS Region as the Amazon Redshift cluster. By default, UNLOAD
+        assumes that the target Amazon S3 bucket is located in the same AWS
+        Region as the Amazon Redshift cluster.
+    max_file_size : float, optional
+        Specifies the maximum size (MB) of files that UNLOAD creates in Amazon S3.
+        Specify a decimal value between 5.0 MB and 6200.0 MB. If None, the default
+        maximum file size is 6200.0 MB.
+    kms_key_id : str, optional
+        Specifies the key ID for an AWS Key Management Service (AWS KMS) key to be
+        used to encrypt data files on Amazon S3.
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -1086,19 +1136,26 @@ def unload_redshift_to_files(
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
     s3.delete_objects(path=path, use_threads=use_threads, boto3_session=session)
     with con.connect() as _con:
-        partition_str: str = f"PARTITION BY ({','.join(partition_cols)})\n" if partition_cols else ""
+        partition_str: str = f"\nPARTITION BY ({','.join(partition_cols)})" if partition_cols else ""
         manifest_str: str = "\nmanifest" if manifest is True else ""
+        region_str: str = f"\nREGION AS '{region}'" if region is not None else ""
+        max_file_size_str: str = f"\nMAXFILESIZE AS {max_file_size} MB" if max_file_size is not None else ""
+        kms_key_id_str: str = f"\nKMS_KEY_ID '{kms_key_id}'" if kms_key_id is not None else ""
         sql = (
             f"UNLOAD ('{sql}')\n"
             f"TO '{path}'\n"
             f"IAM_ROLE '{iam_role}'\n"
             "ALLOWOVERWRITE\n"
             "PARALLEL ON\n"
-            "ENCRYPTED\n"
+            "FORMAT PARQUET\n"
+            "ENCRYPTED"
+            f"{kms_key_id_str}"
             f"{partition_str}"
-            "FORMAT PARQUET"
+            f"{region_str}"
+            f"{max_file_size_str}"
             f"{manifest_str};"
         )
+        _logger.debug("sql: \n%s", sql)
         _con.execute(sql)
         sql = "SELECT pg_last_query_id() AS query_id"
         query_id: int = _con.execute(sql).fetchall()[0][0]
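
Taken together, the db.py changes let callers send UNLOAD output to a bucket in another Region, cap the size of the generated Parquet files, and encrypt them with a specific KMS key. A minimal usage sketch (the bucket, schema/table, IAM role ARN, and KMS key ID below are placeholders), mirroring the call pattern exercised in test_redshift_unload_extras further down:

import awswrangler as wr

# Placeholders only: point these at a real Glue connection, bucket, role, and key.
engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
df = wr.db.unload_redshift(
    sql="SELECT * FROM public.my_table",
    path="s3://my-bucket/unload/",
    con=engine,
    iam_role="arn:aws:iam::123456789012:role/my-redshift-role",
    region="us-east-1",  # UNLOAD ... REGION AS: Region of the target bucket
    max_file_size=5.0,   # UNLOAD ... MAXFILESIZE: MB per file (5.0-6200.0)
    kms_key_id="1234abcd-12ab-34cd-56ef-1234567890ab",  # UNLOAD ... KMS_KEY_ID
    keep_files=False,
)
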

awswrangler/s3.py

Lines changed: 21 additions & 23 deletions
@@ -1132,7 +1132,7 @@ def _to_parquet_dataset(
     schema: pa.Schema = _data_types.pyarrow_schema_from_pandas(
         df=df, index=index, ignore_cols=partition_cols, dtype=dtype
     )
-    _logger.debug("schema: %s", schema)
+    _logger.debug("schema: \n%s", schema)
     if not partition_cols:
         file_path: str = f"{path}{uuid.uuid4().hex}{compression_ext}.parquet"
         _to_parquet_file(
@@ -1688,12 +1688,7 @@ def read_parquet(
             data=data, columns=columns, categories=categories, use_threads=use_threads, validate_schema=validate_schema
         )
     return _read_parquet_chunked(
-        data=data,
-        columns=columns,
-        categories=categories,
-        chunked=chunked,
-        use_threads=use_threads,
-        validate_schema=validate_schema,
+        data=data, columns=columns, categories=categories, chunked=chunked, use_threads=use_threads
     )


@@ -1728,29 +1723,32 @@ def _read_parquet_chunked(
     data: pyarrow.parquet.ParquetDataset,
     columns: Optional[List[str]] = None,
     categories: List[str] = None,
-    validate_schema: bool = True,
     chunked: Union[bool, int] = True,
     use_threads: bool = True,
 ) -> Iterator[pd.DataFrame]:
-    promote: bool = not validate_schema
-    next_slice: Optional[pa.Table] = None
+    next_slice: Optional[pd.DataFrame] = None
     for piece in data.pieces:
-        table: pa.Table = piece.read(
-            columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
+        df: pd.DataFrame = _table2df(
+            table=piece.read(
+                columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
+            ),
+            categories=categories,
+            use_threads=use_threads,
         )
         if chunked is True:
-            yield _table2df(table=table, categories=categories, use_threads=use_threads)
+            yield df
         else:
-            if next_slice:
-                table = pa.lib.concat_tables([next_slice, table], promote=promote)
-            while len(table) >= chunked:
-                yield _table2df(
-                    table=table.slice(offset=0, length=chunked), categories=categories, use_threads=use_threads
-                )
-                table = table.slice(offset=chunked, length=None)
-            next_slice = table
-    if next_slice:
-        yield _table2df(table=next_slice, categories=categories, use_threads=use_threads)
+            if next_slice is not None:
+                df = pd.concat(objs=[next_slice, df], ignore_index=True, sort=False)
+            while len(df.index) >= chunked:
+                yield df.iloc[:chunked]
+                df = df.iloc[chunked:]
+            if df.empty:
+                next_slice = None
+            else:
+                next_slice = df
+    if next_slice is not None:
+        yield next_slice


 def _table2df(table: pa.Table, categories: List[str] = None, use_threads: bool = True) -> pd.DataFrame:
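
With this rewrite, the integer-chunked reader buffers leftovers across Parquet pieces as pandas DataFrames and yields slices of exactly "chunked" rows, plus a smaller final remainder. A consumption sketch, assuming a Parquet dataset already exists under the placeholder path:

import awswrangler as wr

# Placeholder path: any existing Parquet dataset in S3 works here.
for chunk in wr.s3.read_parquet(path="s3://my-bucket/my-dataset/", dataset=True, chunked=64):
    # Every chunk is a pandas DataFrame of exactly 64 rows,
    # except possibly the last one, which holds the remainder.
    print(len(chunk.index))
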

testing/cloudformation.yaml

Lines changed: 9 additions & 0 deletions
@@ -96,6 +96,15 @@ Resources:
       PolicyDocument:
         Version: 2012-10-17
         Statement:
+          - Effect: Allow
+            Action:
+              - kms:Encrypt
+              - kms:Decrypt
+              - kms:GenerateDataKey
+            Resource:
+              - Fn::GetAtt:
+                  - KmsKey
+                  - Arn
           - Effect: Allow
             Action:
               - s3:Get*

testing/test_awswrangler/test_db.py

Lines changed: 74 additions & 0 deletions
@@ -76,6 +76,11 @@ def external_schema(cloudformation_outputs, parameters, glue_database):
     yield "aws_data_wrangler_external"


+@pytest.fixture(scope="module")
+def kms_key_id(cloudformation_outputs):
+    yield cloudformation_outputs["KmsKeyArn"].split("/", 1)[1]
+
+
 @pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
 def test_sql(parameters, db_type):
     df = get_df()

@@ -386,3 +391,72 @@ def test_redshift_category(bucket, parameters):
     for df2 in dfs:
         ensure_data_types_category(df2)
     wr.s3.delete_objects(path=path)
+
+
+def test_redshift_unload_extras(bucket, parameters, kms_key_id):
+    table = "test_redshift_unload_extras"
+    schema = parameters["redshift"]["schema"]
+    path = f"s3://{bucket}/{table}/"
+    wr.s3.delete_objects(path=path)
+    engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-redshift")
+    df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})
+    wr.db.to_sql(df=df, con=engine, name=table, schema=schema, if_exists="replace", index=False)
+    paths = wr.db.unload_redshift_to_files(
+        sql=f"SELECT * FROM {schema}.{table}",
+        path=path,
+        con=engine,
+        iam_role=parameters["redshift"]["role"],
+        region=wr.s3.get_bucket_region(bucket),
+        max_file_size=5.0,
+        kms_key_id=kms_key_id,
+        partition_cols=["name"],
+    )
+    wr.s3.wait_objects_exist(paths=paths)
+    df = wr.s3.read_parquet(path=path, dataset=True)
+    assert len(df.index) == 2
+    assert len(df.columns) == 2
+    wr.s3.delete_objects(path=path)
+    df = wr.db.unload_redshift(
+        sql=f"SELECT * FROM {schema}.{table}",
+        con=engine,
+        iam_role=parameters["redshift"]["role"],
+        path=path,
+        keep_files=False,
+        region=wr.s3.get_bucket_region(bucket),
+        max_file_size=5.0,
+        kms_key_id=kms_key_id,
+    )
+    assert len(df.index) == 2
+    assert len(df.columns) == 2
+    wr.s3.delete_objects(path=path)
+
+
+@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
+def test_to_sql_cast(parameters, db_type):
+    table = "test_to_sql_cast"
+    schema = parameters[db_type]["schema"]
+    df = pd.DataFrame(
+        {
+            "col": [
+                "".join([str(i)[-1] for i in range(1_024)]),
+                "".join([str(i)[-1] for i in range(1_024)]),
+                "".join([str(i)[-1] for i in range(1_024)]),
+            ]
+        },
+        dtype="string",
+    )
+    engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-{db_type}")
+    wr.db.to_sql(
+        df=df,
+        con=engine,
+        name=table,
+        schema=schema,
+        if_exists="replace",
+        index=False,
+        index_label=None,
+        chunksize=None,
+        method=None,
+        dtype={"col": sqlalchemy.types.VARCHAR(length=1_024)},
+    )
+    df2 = wr.db.read_sql_query(sql=f"SELECT * FROM {schema}.{table}", con=engine)
+    assert df.equals(df2)

tox.ini

Lines changed: 2 additions & 2 deletions
@@ -8,11 +8,11 @@ deps =
     moto
     -rrequirements-torch.txt
 commands =
-    pytest -n 16 testing/test_awswrangler
+    pytest -n 8 testing/test_awswrangler

 [testenv:py36]
 deps =
     {[testenv]deps}
     pytest-cov
 commands =
-    pytest --cov=awswrangler -n 16 testing/test_awswrangler
+    pytest --cov=awswrangler -n 8 testing/test_awswrangler

0 commit comments