Commit eac8362

Add Pandas.read_sql_redshift()
1 parent c99dca2 commit eac8362

File tree

awswrangler/pandas.py
awswrangler/redshift.py
awswrangler/session.py
requirements.txt
testing/test_awswrangler/test_redshift.py

5 files changed: +137 −12 lines changed

awswrangler/pandas.py

Lines changed: 43 additions & 1 deletion
@@ -523,7 +523,7 @@ def read_sql_athena(self,
         :param workgroup: The name of the workgroup in which the query is being started. (By default uses de Session() workgroup)
         :param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
-        :param ctas_approach: Wraps the query with a CTAS (Session's deafult is False)
+        :param ctas_approach: Wraps the query with a CTAS (Session's default is False)
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :param max_result_size: Max number of bytes on each request to S3 (VALID ONLY FOR ctas_approach=False)
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size was passed
@@ -1376,3 +1376,45 @@ def read_table(self,
         """
         path: str = self._session.glue.get_table_location(database=database, table=table)
         return self.read_parquet(path=path, columns=columns, filters=filters, procs_cpu_bound=procs_cpu_bound)
+
+    def read_sql_redshift(self,
+                          sql: str,
+                          iam_role: str,
+                          connection: Any,
+                          temp_s3_path: Optional[str] = None,
+                          procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
+        """
+        Convert a query result in a Pandas Dataframe.
+
+        :param sql: SQL Query
+        :param iam_role: AWS IAM role with the related permissions
+        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) (Default uses the Athena's results bucket)
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        """
+        guid: str = pa.compat.guid()
+        name: str = f"temp_redshift_{guid}"
+        if temp_s3_path is None:
+            if self._session.athena_s3_output is not None:
+                temp_s3_path = self._session.redshift_temp_s3_path
+            else:
+                temp_s3_path = self._session.athena.create_athena_bucket()
+        temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path
+        temp_s3_path = f"{temp_s3_path}/{name}"
+        logger.debug(f"temp_s3_path: {temp_s3_path}")
+        paths: Optional[List[str]] = None
+        try:
+            paths = self._session.redshift.to_parquet(sql=sql,
+                                                      path=temp_s3_path,
+                                                      iam_role=iam_role,
+                                                      connection=connection)
+            logger.debug(f"paths: {paths}")
+            df: pd.DataFrame = self.read_parquet(path=paths, procs_cpu_bound=procs_cpu_bound)  # type: ignore
+            self._session.s3.delete_listed_objects(objects_paths=paths)
+            return df
+        except Exception as e:
+            if paths is not None:
+                self._session.s3.delete_listed_objects(objects_paths=paths)
+            else:
+                self._session.s3.delete_objects(path=temp_s3_path)
+            raise e
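
For reference, a minimal usage sketch of the new method. The cluster endpoint, credentials, IAM role ARN, and bucket below are placeholders, and the top-level Session/Redshift imports are assumed to mirror the test suite; this is not part of the commit itself.

from awswrangler import Redshift, Session  # assumed top-level exports

# Placeholders: substitute your own cluster endpoint, credentials, role ARN and bucket.
con = Redshift.generate_connection(
    database="test",
    host="my-cluster.xxxxxxxx.us-east-1.redshift.amazonaws.com",
    port=5439,
    user="test",
    password="***",
)
session = Session()
df = session.pandas.read_sql_redshift(
    sql="SELECT * FROM public.test",
    iam_role="arn:aws:iam::123456789012:role/my-redshift-role",
    connection=con,
    temp_s3_path="s3://my-bucket/temp/",  # optional; a session/Athena default is used otherwise
)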

awswrangler/redshift.py

Lines changed: 4 additions & 4 deletions
@@ -351,20 +351,20 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c
     def to_parquet(sql: str,
                    path: str,
                    iam_role: str,
-                   redshift_conn: Any,
+                   connection: Any,
                    partition_cols: Optional[List] = None) -> List[str]:
         """
         Write a query result as parquet files on S3

         :param sql: SQL Query
         :param path: AWS S3 path to write the data (e.g. s3://...)
         :param iam_role: AWS IAM role with the related permissions
-        :param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param partition_cols: Specifies the partition keys for the unload operation.
         """
         sql = sql.replace("'", "\'").replace(";", "")  # escaping single quote
         path = path if path[-1] == "/" else path + "/"
-        cursor: Any = redshift_conn.cursor()
+        cursor: Any = connection.cursor()
         partition_str: str = ""
         if partition_cols is not None:
             partition_str = f"PARTITION BY ({','.join([x for x in partition_cols])})\n"
@@ -389,6 +389,6 @@ def to_parquet(sql: str,
         cursor.execute(query)
         paths: List[str] = [row[0].replace(" ", "") for row in cursor.fetchall()]
         logger.debug(f"paths: {paths}")
-        redshift_conn.commit()
+        connection.commit()
         cursor.close()
         return paths
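
A sketch of a direct call against the renamed keyword, following the test suite; the bucket, endpoint, credentials, and role ARN are placeholders, and the top-level Redshift import is assumed.

from awswrangler import Redshift  # assumed top-level export

# Placeholders throughout: endpoint, credentials, role ARN and bucket.
con = Redshift.generate_connection(
    database="test",
    host="my-cluster.xxxxxxxx.us-east-1.redshift.amazonaws.com",
    port=5439,
    user="test",
    password="***",
)
# Unload a query result to S3 as Parquet; `connection` replaces the old `redshift_conn` keyword.
paths = Redshift.to_parquet(
    sql="SELECT * FROM public.test",
    path="s3://my-bucket/unload/",
    iam_role="arn:aws:iam::123456789012:role/my-redshift-role",
    connection=con,
    partition_cols=["name"],
)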

awswrangler/session.py

Lines changed: 20 additions & 4 deletions
@@ -49,7 +49,8 @@ def __init__(self,
                  athena_encryption: Optional[str] = "SSE_S3",
                  athena_kms_key: Optional[str] = None,
                  athena_database: str = "default",
-                 athena_ctas_approach: bool = False):
+                 athena_ctas_approach: bool = False,
+                 redshift_temp_s3_path: Optional[str] = None):
         """
         Most parameters inherit from Boto3 or Pyspark.
         https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
@@ -73,6 +74,7 @@ def __init__(self,
         :param athena_s3_output: AWS S3 path
         :param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
+        :param redshift_temp_s3_path: redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...)
         """
         self._profile_name: Optional[str] = (boto3_session.profile_name if boto3_session else profile_name)
         self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key
@@ -95,6 +97,7 @@ def __init__(self,
         self._athena_kms_key: Optional[str] = athena_kms_key
         self._athena_database: str = athena_database
         self._athena_ctas_approach: bool = athena_ctas_approach
+        self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path
         self._primitives = None
         self._load_new_primitives()
         if boto3_session:
@@ -149,7 +152,8 @@ def _load_new_primitives(self):
                                              athena_encryption=self._athena_encryption,
                                              athena_kms_key=self._athena_kms_key,
                                              athena_database=self._athena_database,
-                                             athena_ctas_approach=self._athena_ctas_approach)
+                                             athena_ctas_approach=self._athena_ctas_approach,
+                                             redshift_temp_s3_path=self._redshift_temp_s3_path)

     @property
     def profile_name(self):
@@ -223,6 +227,10 @@ def athena_database(self) -> str:
     def athena_ctas_approach(self) -> bool:
         return self._athena_ctas_approach

+    @property
+    def redshift_temp_s3_path(self) -> Optional[str]:
+        return self._redshift_temp_s3_path
+
     @property
     def boto3_session(self):
         return self._boto3_session
@@ -304,7 +312,8 @@ def __init__(self,
                  athena_encryption: Optional[str] = None,
                  athena_kms_key: Optional[str] = None,
                  athena_database: Optional[str] = None,
-                 athena_ctas_approach: bool = False):
+                 athena_ctas_approach: bool = False,
+                 redshift_temp_s3_path: Optional[str] = None):
         """
         Most parameters inherit from Boto3.
         https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
@@ -325,6 +334,7 @@ def __init__(self,
         :param athena_s3_output: AWS S3 path
         :param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
+        :param redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...)
         """
         self._profile_name: Optional[str] = profile_name
         self._aws_access_key_id: Optional[str] = aws_access_key_id
@@ -342,6 +352,7 @@ def __init__(self,
         self._athena_kms_key: Optional[str] = athena_kms_key
         self._athena_database: Optional[str] = athena_database
         self._athena_ctas_approach: bool = athena_ctas_approach
+        self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path

     @property
     def profile_name(self):
@@ -407,6 +418,10 @@ def athena_database(self) -> Optional[str]:
     def athena_ctas_approach(self) -> bool:
         return self._athena_ctas_approach

+    @property
+    def redshift_temp_s3_path(self) -> Optional[str]:
+        return self._redshift_temp_s3_path
+
     @property
     def session(self):
         """
@@ -427,4 +442,5 @@ def session(self):
                                  athena_encryption=self._athena_encryption,
                                  athena_kms_key=self._athena_kms_key,
                                  athena_database=self._athena_database,
-                                 athena_ctas_approach=self._athena_ctas_approach)
+                                 athena_ctas_approach=self._athena_ctas_approach,
+                                 redshift_temp_s3_path=self._redshift_temp_s3_path)
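
A sketch of the new session-level default; bucket, endpoint, credentials, and role ARN are placeholders, and the top-level imports are assumed. The intent, per the docstring, is that read_sql_redshift() picks up this path when temp_s3_path is not passed explicitly.

from awswrangler import Redshift, Session  # assumed top-level exports

# Placeholders: endpoint, credentials, role ARN and buckets.
con = Redshift.generate_connection(database="test",
                                   host="my-cluster.xxxxxxxx.us-east-1.redshift.amazonaws.com",
                                   port=5439,
                                   user="test",
                                   password="***")
# redshift_temp_s3_path provides a session-level default for the temporary unload location.
session = Session(redshift_temp_s3_path="s3://my-bucket/redshift-temp/")
df = session.pandas.read_sql_redshift(sql="SELECT * FROM public.test",
                                      iam_role="arn:aws:iam::123456789012:role/my-redshift-role",
                                      connection=con)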

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 numpy~=1.17.4
 pandas~=0.25.3
 pyarrow~=0.15.1
-botocore~=1.13.36
-boto3~=1.10.36
+botocore~=1.13.37
+boto3~=1.10.37
 s3fs~=0.4.0
 tenacity~=6.0.0
 pg8000~=1.13.2

testing/test_awswrangler/test_redshift.py

Lines changed: 68 additions & 1 deletion
@@ -522,6 +522,73 @@ def test_to_parquet(bucket, redshift_parameters):
     paths = Redshift.to_parquet(sql="SELECT * FROM public.test",
                                 path=path,
                                 iam_role=redshift_parameters.get("RedshiftRole"),
-                                redshift_conn=con,
+                                connection=con,
                                 partition_cols=["name"])
     assert len(paths) == 20
+
+
+@pytest.mark.parametrize("sample_name", ["micro", "small", "nano"])
+def test_read_sql_redshift_pandas(session, bucket, redshift_parameters, sample_name):
+    if sample_name == "micro":
+        dates = ["date"]
+    elif sample_name == "small":
+        dates = ["date"]
+    else:
+        dates = ["date", "time"]
+    df = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
+    df["date"] = df["date"].dt.date
+    con = Redshift.generate_connection(
+        database="test",
+        host=redshift_parameters.get("RedshiftAddress"),
+        port=redshift_parameters.get("RedshiftPort"),
+        user="test",
+        password=redshift_parameters.get("RedshiftPassword"),
+    )
+    path = f"s3://{bucket}/test_read_sql_redshift_pandas/"
+    session.pandas.to_redshift(
+        dataframe=df,
+        path=path,
+        schema="public",
+        table="test",
+        connection=con,
+        iam_role=redshift_parameters.get("RedshiftRole"),
+        mode="overwrite",
+        preserve_index=True,
+    )
+    path2 = f"s3://{bucket}/test_read_sql_redshift_pandas2/"
+    df2 = session.pandas.read_sql_redshift(sql="select * from public.test",
+                                           iam_role=redshift_parameters.get("RedshiftRole"),
+                                           connection=con,
+                                           temp_s3_path=path2)
+    assert len(df.index) == len(df2.index)
+    assert len(df.columns) + 1 == len(df2.columns)
+
+
+def test_read_sql_redshift_pandas2(session, bucket, redshift_parameters):
+    n: int = 1_000_000
+    df = pd.DataFrame({"id": list((range(n))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])})
+    con = Redshift.generate_connection(
+        database="test",
+        host=redshift_parameters.get("RedshiftAddress"),
+        port=redshift_parameters.get("RedshiftPort"),
+        user="test",
+        password=redshift_parameters.get("RedshiftPassword"),
+    )
+    path = f"s3://{bucket}/test_read_sql_redshift_pandas2/"
+    session.pandas.to_redshift(
+        dataframe=df,
+        path=path,
+        schema="public",
+        table="test",
+        connection=con,
+        iam_role=redshift_parameters.get("RedshiftRole"),
+        mode="overwrite",
+        preserve_index=True,
+    )
+    path2 = f"s3://{bucket}/test_read_sql_redshift_pandas22/"
+    df2 = session.pandas.read_sql_redshift(sql="select * from public.test",
+                                           iam_role=redshift_parameters.get("RedshiftRole"),
+                                           connection=con,
+                                           temp_s3_path=path2)
+    assert len(df.index) == len(df2.index)
+    assert len(df.columns) + 1 == len(df2.columns)
