Skip to content

Commit 9903c9e

Browse files
committed
Add Redshift.to_parquet()
1 parent ed87ba6 commit 9903c9e

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

awswrangler/redshift.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Union, Optional
1+
from typing import Dict, List, Union, Optional, Any
22
import json
33
import logging
44

@@ -346,3 +346,49 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c
346346
else:
347347
raise InvalidDataframeType(dataframe_type)
348348
return schema_built
349+
350+
@staticmethod
def to_parquet(sql: str,
               path: str,
               iam_role: str,
               redshift_conn: Any,
               partition_cols: Optional[List] = None) -> List[str]:
    """
    Write a query result as Parquet files on S3 using a Redshift UNLOAD.

    :param sql: SQL Query whose result will be unloaded
    :param path: AWS S3 path to write the data (e.g. s3://...)
    :param iam_role: AWS IAM role with the related permissions
    :param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
    :param partition_cols: Specifies the partition keys for the unload operation.
    :return: List of the S3 object paths written by the UNLOAD
    """
    # The query is embedded in UNLOAD's single-quoted string literal, so single
    # quotes must be escaped by doubling them (SQL standard). The previous
    # "\'" replacement was a no-op ("\'" == "'"). Semicolons are stripped
    # because UNLOAD rejects them inside the quoted query.
    sql = sql.replace("'", "''").replace(";", "")
    # UNLOAD expects an S3 prefix ending in "/". endswith() also avoids an
    # IndexError on an empty path.
    path = path if path.endswith("/") else path + "/"
    partition_str: str = ""
    if partition_cols is not None:
        partition_str = f"PARTITION BY ({','.join(partition_cols)})\n"
    cursor: Any = redshift_conn.cursor()
    try:
        query: str = f"-- AWS DATA WRANGLER\n" \
                     f"UNLOAD ('{sql}')\n" \
                     f"TO '{path}'\n" \
                     f"IAM_ROLE '{iam_role}'\n" \
                     f"ALLOWOVERWRITE\n" \
                     f"PARALLEL ON\n" \
                     f"ENCRYPTED \n" \
                     f"{partition_str}" \
                     f"FORMAT PARQUET;"
        logger.debug(f"query:\n{query}")
        cursor.execute(query)
        # Ask Redshift for the id of the UNLOAD statement we just ran...
        query = "-- AWS DATA WRANGLER\nSELECT pg_last_query_id() AS query_id"
        logger.debug(f"query:\n{query}")
        cursor.execute(query)
        query_id = cursor.fetchall()[0][0]
        # ...then look up the S3 objects it produced in the unload log.
        query = f"-- AWS DATA WRANGLER\n" \
                f"SELECT path FROM STL_UNLOAD_LOG WHERE query={query_id};"
        logger.debug(f"query:\n{query}")
        cursor.execute(query)
        # STL_UNLOAD_LOG.path is a fixed-width column padded with blanks,
        # so spaces are removed from each returned path.
        paths: List[str] = [row[0].replace(" ", "") for row in cursor.fetchall()]
        logger.debug(f"paths: {paths}")
        redshift_conn.commit()
    finally:
        # Close the cursor even if an execute fails, so it is never leaked.
        cursor.close()
    return paths

testing/test_awswrangler/test_pandas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,6 +1474,7 @@ def test_read_table2(session, bucket, database):
14741474
preserve_index=False,
14751475
procs_cpu_bound=4,
14761476
partition_cols=["partition"])
1477+
sleep(5)
14771478
df2 = session.pandas.read_table(database=database, table="test")
14781479
assert len(list(df.columns)) == len(list(df2.columns))
14791480
assert len(df.index) == len(df2.index)

testing/test_awswrangler/test_redshift.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,3 +508,22 @@ def test_to_redshift_spark_decimal(session, bucket, redshift_parameters):
508508
elif row[2] == 3:
509509
assert row[1] == Decimal((0, (1, 9, 0), -2))
510510
assert row[2] == Decimal((0, (1, 9, 0, 0, 0, 0), -5))
511+
512+
513+
def test_to_parquet(bucket, redshift_parameters):
    """Unload a test table to S3 as Parquet and check the produced path count."""
    connection = Redshift.generate_connection(
        database="test",
        host=redshift_parameters.get("RedshiftAddress"),
        port=redshift_parameters.get("RedshiftPort"),
        user="test",
        password=redshift_parameters.get("RedshiftPassword"),
    )
    s3_prefix = f"s3://{bucket}/test_to_parquet/"
    written_paths = Redshift.to_parquet(sql="SELECT * FROM public.test",
                                        path=s3_prefix,
                                        iam_role=redshift_parameters.get("RedshiftRole"),
                                        redshift_conn=connection,
                                        partition_cols=["name"])
    assert len(written_paths) == 20

0 commit comments

Comments
 (0)