
Commit 9715c36

Merge pull request #81 from awslabs/read_sql_redshift
Add Pandas.read_sql_redshift()
2 parents c99dca2 + f2709a0 commit 9715c36

6 files changed (+157, -17 lines)

awswrangler/pandas.py

Lines changed: 44 additions & 2 deletions
@@ -523,12 +523,12 @@ def read_sql_athena(self,
         :param workgroup: The name of the workgroup in which the query is being started. (By default uses de Session() workgroup)
         :param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
-        :param ctas_approach: Wraps the query with a CTAS (Session's deafult is False)
+        :param ctas_approach: Wraps the query with a CTAS (Session's default is False)
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :param max_result_size: Max number of bytes on each request to S3 (VALID ONLY FOR ctas_approach=False)
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size was passed
         """
-        ctas_approach = ctas_approach if ctas_approach is not None else self._session.ctas_approach if self._session.ctas_approach is not None else False
+        ctas_approach = ctas_approach if ctas_approach is not None else self._session.athena_ctas_approach if self._session.athena_ctas_approach is not None else False
         if ctas_approach is True and max_result_size is not None:
             raise InvalidParameters("ctas_approach can't use max_result_size!")
         if s3_output is None:
@@ -1376,3 +1376,45 @@ def read_table(self,
         """
         path: str = self._session.glue.get_table_location(database=database, table=table)
         return self.read_parquet(path=path, columns=columns, filters=filters, procs_cpu_bound=procs_cpu_bound)
+
+    def read_sql_redshift(self,
+                          sql: str,
+                          iam_role: str,
+                          connection: Any,
+                          temp_s3_path: Optional[str] = None,
+                          procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
+        """
+        Convert a query result in a Pandas Dataframe.
+
+        :param sql: SQL Query
+        :param iam_role: AWS IAM role with the related permissions
+        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) (Default uses the Athena's results bucket)
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        """
+        guid: str = pa.compat.guid()
+        name: str = f"temp_redshift_{guid}"
+        if temp_s3_path is None:
+            if self._session.athena_s3_output is not None:
+                temp_s3_path = self._session.redshift_temp_s3_path
+            else:
+                temp_s3_path = self._session.athena.create_athena_bucket()
+        temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path
+        temp_s3_path = f"{temp_s3_path}/{name}"
+        logger.debug(f"temp_s3_path: {temp_s3_path}")
+        paths: Optional[List[str]] = None
+        try:
+            paths = self._session.redshift.to_parquet(sql=sql,
+                                                      path=temp_s3_path,
+                                                      iam_role=iam_role,
+                                                      connection=connection)
+            logger.debug(f"paths: {paths}")
+            df: pd.DataFrame = self.read_parquet(path=paths, procs_cpu_bound=procs_cpu_bound)  # type: ignore
+            self._session.s3.delete_listed_objects(objects_paths=paths)
+            return df
+        except Exception as e:
+            if paths is not None:
+                self._session.s3.delete_listed_objects(objects_paths=paths)
+            else:
+                self._session.s3.delete_objects(path=temp_s3_path)
+            raise e
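For reference, a minimal usage sketch of the new method (not part of the commit). It assumes a reachable Redshift cluster and an IAM role that Redshift can assume for UNLOAD; the endpoint, credentials, role ARN and bucket below are placeholders.

import awswrangler
from awswrangler import Redshift

session = awswrangler.Session()
con = Redshift.generate_connection(
    database="test",
    host="my-cluster.xxxxxxxxxxxx.us-east-1.redshift.amazonaws.com",  # placeholder endpoint
    port=5439,
    user="test",
    password="my-password",  # placeholder credential
)
df = session.pandas.read_sql_redshift(
    sql="SELECT * FROM public.test",
    iam_role="arn:aws:iam::123456789012:role/my-unload-role",  # placeholder role ARN
    connection=con,
    temp_s3_path="s3://my-bucket/temp/",  # optional; omit to fall back to the Session defaults
)
print(len(df.index))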

awswrangler/redshift.py

Lines changed: 4 additions & 4 deletions
@@ -351,20 +351,20 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c
     def to_parquet(sql: str,
                    path: str,
                    iam_role: str,
-                   redshift_conn: Any,
+                   connection: Any,
                    partition_cols: Optional[List] = None) -> List[str]:
         """
         Write a query result as parquet files on S3

         :param sql: SQL Query
         :param path: AWS S3 path to write the data (e.g. s3://...)
         :param iam_role: AWS IAM role with the related permissions
-        :param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param partition_cols: Specifies the partition keys for the unload operation.
         """
         sql = sql.replace("'", "\'").replace(";", "")  # escaping single quote
         path = path if path[-1] == "/" else path + "/"
-        cursor: Any = redshift_conn.cursor()
+        cursor: Any = connection.cursor()
         partition_str: str = ""
         if partition_cols is not None:
             partition_str = f"PARTITION BY ({','.join([x for x in partition_cols])})\n"
@@ -389,6 +389,6 @@ def to_parquet(sql: str,
         cursor.execute(query)
         paths: List[str] = [row[0].replace(" ", "") for row in cursor.fetchall()]
         logger.debug(f"paths: {paths}")
-        redshift_conn.commit()
+        connection.commit()
         cursor.close()
         return paths
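A short direct-call sketch of the renamed API (not part of the commit; the connection values, bucket and role ARN are placeholders):

from awswrangler import Redshift

con = Redshift.generate_connection(
    database="test",
    host="my-cluster.xxxxxxxxxxxx.us-east-1.redshift.amazonaws.com",  # placeholder
    port=5439,
    user="test",
    password="my-password",  # placeholder
)
# Unloads the query result to S3 as Parquet and returns the list of written object paths.
paths = Redshift.to_parquet(sql="SELECT * FROM public.test",
                            path="s3://my-bucket/unload/",  # placeholder bucket
                            iam_role="arn:aws:iam::123456789012:role/my-unload-role",  # placeholder
                            connection=con,  # renamed from redshift_conn in this commit
                            partition_cols=["name"])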

awswrangler/session.py

Lines changed: 20 additions & 4 deletions
@@ -49,7 +49,8 @@ def __init__(self,
                  athena_encryption: Optional[str] = "SSE_S3",
                  athena_kms_key: Optional[str] = None,
                  athena_database: str = "default",
-                 athena_ctas_approach: bool = False):
+                 athena_ctas_approach: bool = False,
+                 redshift_temp_s3_path: Optional[str] = None):
         """
         Most parameters inherit from Boto3 or Pyspark.
         https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
@@ -73,6 +74,7 @@ def __init__(self,
         :param athena_s3_output: AWS S3 path
         :param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
+        :param redshift_temp_s3_path: redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...)
         """
         self._profile_name: Optional[str] = (boto3_session.profile_name if boto3_session else profile_name)
         self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key
@@ -95,6 +97,7 @@ def __init__(self,
         self._athena_kms_key: Optional[str] = athena_kms_key
         self._athena_database: str = athena_database
         self._athena_ctas_approach: bool = athena_ctas_approach
+        self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path
         self._primitives = None
         self._load_new_primitives()
         if boto3_session:
@@ -149,7 +152,8 @@ def _load_new_primitives(self):
             athena_encryption=self._athena_encryption,
             athena_kms_key=self._athena_kms_key,
             athena_database=self._athena_database,
-            athena_ctas_approach=self._athena_ctas_approach)
+            athena_ctas_approach=self._athena_ctas_approach,
+            redshift_temp_s3_path=self._redshift_temp_s3_path)

     @property
     def profile_name(self):
@@ -223,6 +227,10 @@ def athena_database(self) -> str:
     def athena_ctas_approach(self) -> bool:
         return self._athena_ctas_approach

+    @property
+    def redshift_temp_s3_path(self) -> Optional[str]:
+        return self._redshift_temp_s3_path
+
     @property
     def boto3_session(self):
         return self._boto3_session
@@ -304,7 +312,8 @@ def __init__(self,
                  athena_encryption: Optional[str] = None,
                  athena_kms_key: Optional[str] = None,
                  athena_database: Optional[str] = None,
-                 athena_ctas_approach: bool = False):
+                 athena_ctas_approach: bool = False,
+                 redshift_temp_s3_path: Optional[str] = None):
         """
         Most parameters inherit from Boto3.
         https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
@@ -325,6 +334,7 @@ def __init__(self,
         :param athena_s3_output: AWS S3 path
         :param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
+        :param redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...)
         """
         self._profile_name: Optional[str] = profile_name
         self._aws_access_key_id: Optional[str] = aws_access_key_id
@@ -342,6 +352,7 @@ def __init__(self,
         self._athena_kms_key: Optional[str] = athena_kms_key
         self._athena_database: Optional[str] = athena_database
         self._athena_ctas_approach: bool = athena_ctas_approach
+        self._redshift_temp_s3_path: Optional[str] = redshift_temp_s3_path

     @property
     def profile_name(self):
@@ -407,6 +418,10 @@ def athena_database(self) -> Optional[str]:
     def athena_ctas_approach(self) -> bool:
         return self._athena_ctas_approach

+    @property
+    def redshift_temp_s3_path(self) -> Optional[str]:
+        return self._redshift_temp_s3_path
+
     @property
     def session(self):
         """
@@ -427,4 +442,5 @@ def session(self):
             athena_encryption=self._athena_encryption,
             athena_kms_key=self._athena_kms_key,
             athena_database=self._athena_database,
-            athena_ctas_approach=self._athena_ctas_approach)
+            athena_ctas_approach=self._athena_ctas_approach,
+            redshift_temp_s3_path=self._redshift_temp_s3_path)
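A sketch of wiring the new Session-level default (not part of the commit; the bucket path is a placeholder), so read_sql_redshift() can be called without an explicit temp_s3_path:

import awswrangler

# Temporary Parquet files produced by Redshift UNLOADs will land under this prefix.
session = awswrangler.Session(redshift_temp_s3_path="s3://my-bucket/redshift-temp/")  # placeholder path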

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 numpy~=1.17.4
 pandas~=0.25.3
 pyarrow~=0.15.1
-botocore~=1.13.36
-boto3~=1.10.36
+botocore~=1.13.37
+boto3~=1.10.37
 s3fs~=0.4.0
 tenacity~=6.0.0
 pg8000~=1.13.2

testing/test_awswrangler/test_pandas.py

Lines changed: 4 additions & 2 deletions
@@ -1442,6 +1442,7 @@ def test_read_table(session, bucket, database):
                               preserve_index=False,
                               procs_cpu_bound=1)
     df2 = session.pandas.read_table(database=database, table="test")
+    session.s3.delete_objects(path=path)
     assert len(list(df.columns)) == len(list(df2.columns))
     assert len(df.index) == len(df2.index)

@@ -1465,7 +1466,7 @@ def test_read_table2(session, bucket, database):
                                                3)]],
         "partition": [0, 0, 1]
     })
-    path = f"s3://{bucket}/test_read_table/"
+    path = f"s3://{bucket}/test_read_table2/"
     session.pandas.to_parquet(dataframe=df,
                               database=database,
                               table="test",
@@ -1474,8 +1475,9 @@ def test_read_table2(session, bucket, database):
                               preserve_index=False,
                               procs_cpu_bound=4,
                               partition_cols=["partition"])
-    sleep(5)
+    sleep(15)
     df2 = session.pandas.read_table(database=database, table="test")
+    session.s3.delete_objects(path=path)
     assert len(list(df.columns)) == len(list(df2.columns))
     assert len(df.index) == len(df2.index)


testing/test_awswrangler/test_redshift.py

Lines changed: 83 additions & 3 deletions
@@ -510,7 +510,9 @@ def test_to_redshift_spark_decimal(session, bucket, redshift_parameters):
     assert row[2] == Decimal((0, (1, 9, 0, 0, 0, 0), -5))


-def test_to_parquet(bucket, redshift_parameters):
+def test_to_parquet(session, bucket, redshift_parameters):
+    n: int = 1_000_000
+    df = pd.DataFrame({"id": list((range(n))), "name": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])})
     con = Redshift.generate_connection(
         database="test",
         host=redshift_parameters.get("RedshiftAddress"),
@@ -519,9 +521,87 @@ def test_to_parquet(bucket, redshift_parameters):
         password=redshift_parameters.get("RedshiftPassword"),
     )
     path = f"s3://{bucket}/test_to_parquet/"
+    session.pandas.to_redshift(
+        dataframe=df,
+        path=path,
+        schema="public",
+        table="test",
+        connection=con,
+        iam_role=redshift_parameters.get("RedshiftRole"),
+        mode="overwrite",
+        preserve_index=True,
+    )
+    path = f"s3://{bucket}/test_to_parquet2/"
     paths = Redshift.to_parquet(sql="SELECT * FROM public.test",
                                 path=path,
                                 iam_role=redshift_parameters.get("RedshiftRole"),
-                                redshift_conn=con,
+                                connection=con,
                                 partition_cols=["name"])
-    assert len(paths) == 20
+    assert len(paths) == 4
+
+
+@pytest.mark.parametrize("sample_name", ["micro", "small", "nano"])
+def test_read_sql_redshift_pandas(session, bucket, redshift_parameters, sample_name):
+    if sample_name == "micro":
+        dates = ["date"]
+    elif sample_name == "small":
+        dates = ["date"]
+    else:
+        dates = ["date", "time"]
+    df = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
+    df["date"] = df["date"].dt.date
+    con = Redshift.generate_connection(
+        database="test",
+        host=redshift_parameters.get("RedshiftAddress"),
+        port=redshift_parameters.get("RedshiftPort"),
+        user="test",
+        password=redshift_parameters.get("RedshiftPassword"),
+    )
+    path = f"s3://{bucket}/test_read_sql_redshift_pandas/"
+    session.pandas.to_redshift(
+        dataframe=df,
+        path=path,
+        schema="public",
+        table="test",
+        connection=con,
+        iam_role=redshift_parameters.get("RedshiftRole"),
+        mode="overwrite",
+        preserve_index=True,
+    )
+    path2 = f"s3://{bucket}/test_read_sql_redshift_pandas2/"
+    df2 = session.pandas.read_sql_redshift(sql="select * from public.test",
+                                           iam_role=redshift_parameters.get("RedshiftRole"),
+                                           connection=con,
+                                           temp_s3_path=path2)
+    assert len(df.index) == len(df2.index)
+    assert len(df.columns) + 1 == len(df2.columns)
+
+
+def test_read_sql_redshift_pandas2(session, bucket, redshift_parameters):
+    n: int = 1_000_000
+    df = pd.DataFrame({"id": list((range(n))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])})
+    con = Redshift.generate_connection(
+        database="test",
+        host=redshift_parameters.get("RedshiftAddress"),
+        port=redshift_parameters.get("RedshiftPort"),
+        user="test",
+        password=redshift_parameters.get("RedshiftPassword"),
+    )
+    path = f"s3://{bucket}/test_read_sql_redshift_pandas2/"
+    session.pandas.to_redshift(
+        dataframe=df,
+        path=path,
+        schema="public",
+        table="test",
+        connection=con,
+        iam_role=redshift_parameters.get("RedshiftRole"),
+        mode="overwrite",
+        preserve_index=True,
+    )
+    path2 = f"s3://{bucket}/test_read_sql_redshift_pandas22/"
+    df2 = session.pandas.read_sql_redshift(sql="select * from public.test",
+                                           iam_role=redshift_parameters.get("RedshiftRole"),
+                                           connection=con,
+                                           temp_s3_path=path2)
+    assert len(df.index) == len(df2.index)
+    assert len(df.columns) + 1 == len(df2.columns)
