Skip to content

Commit 32f7914

Browse files
pwithams and kukushking authored
Add basic support for Redshift and RDS Data APIs (#828)
* Add basic outline of redshift data api connector
* Add outline of support for pagination
* Restructure data_api module and add basic RDS support
* Remove unused variable
* Add secret ARNs to CDK outputs and databases_parameters fixture
* Add RDS resource identifiers to CDK outputs and databases_parameters fixture
* Pylint fixes
* Add basic outline for full infrastructure data api tests
* Fix CDK typo
* Add Aurora serverless to test resources to allow RDS Data API testing
* Update data api integration tests based on new serverless Aurora instance
* Fix issues based on initial integration tests
* Finish first working integration tests for RDS and Redshift data api
* Refactor code and tests, and add logging
* Remove unused mysql CDK exports
* Fix documentation style
* Remove use of postponed type evaluations

Co-authored-by: kukushking <[email protected]>
1 parent ebe6883 commit 32f7914

File tree

9 files changed

+694
-2
lines changed

9 files changed

+694
-2
lines changed

awswrangler/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
catalog,
1313
chime,
1414
cloudwatch,
15+
data_api,
1516
dynamodb,
1617
emr,
1718
exceptions,
@@ -34,6 +35,7 @@
3435
"chime",
3536
"cloudwatch",
3637
"emr",
38+
"data_api",
3739
"dynamodb",
3840
"exceptions",
3941
"quicksight",

awswrangler/data_api/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Data API Service Module for RDS and Redshift."""
from awswrangler.data_api import rds, redshift

# Public submodules re-exported by this package.
__all__ = ["redshift", "rds"]

awswrangler/data_api/connector.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Data API Connector base class."""
2+
import logging
3+
from typing import Any, Dict, Optional
4+
5+
import pandas as pd
6+
7+
8+
class DataApiConnector:
    """Base class for Data API (RDS, Redshift, etc.) connectors."""

    def __init__(self, client: Any, logger: logging.Logger):
        # Boto3 client for the concrete Data API service; supplied by subclasses.
        self.client = client
        self.logger: logging.Logger = logger

    def execute(self, sql: str, database: Optional[str] = None) -> pd.DataFrame:
        """Execute SQL statement against a Data API Service.

        Parameters
        ----------
        sql: str
            SQL statement to execute.
        database: str, optional
            Database to run the statement against - defaults to the connector's
            configured database when None.

        Returns
        -------
        A Pandas DataFrame containing the execution results.
        """
        request_id: str = self._execute_statement(sql, database=database)
        return self._get_statement_result(request_id)

    def _execute_statement(self, sql: str, database: Optional[str] = None) -> str:
        # Subclasses submit the statement and return a request id for later retrieval.
        raise NotImplementedError()

    def _get_statement_result(self, request_id: str) -> pd.DataFrame:
        # Subclasses convert the service response for `request_id` into a DataFrame.
        raise NotImplementedError()

    @staticmethod
    def _get_column_value(column_value: Dict[str, Any]) -> Any:
        """Return the first non-null key value for a given dictionary.

        The key names for a given record depend on the column type: stringValue, longValue, etc.

        Therefore, a record in the response does not have consistent key names. The ColumnMetadata
        typeName information could be used to infer the key, but there is no direct mapping here
        that could be easily parsed with creating a static dictionary:
            varchar -> stringValue
            int2 -> longValue
            timestamp -> stringValue

        What has been observed is that each record appears to have a single key, so this function
        iterates over the keys and returns the first non-null value. If none are found, None is
        returned.

        Documentation:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-data.html#RedshiftDataAPIService.Client.get_statement_result
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds-data.html#RDSDataService.Client.execute_statement
        """
        # The Data APIs encode SQL NULL as {"isNull": True}; without this check the raw
        # boolean True would be returned for NULL columns instead of None.
        if column_value.get("isNull"):
            return None
        for key in column_value:
            if column_value[key] is not None:
                if key == "arrayValue":
                    raise ValueError(f"arrayValue not supported yet - could not extract {column_value[key]}")
                return column_value[key]
        return None
65+
class WaitConfig:
    """Holds standard wait configuration values."""

    def __init__(self, sleep: float, backoff: float, retries: int) -> None:
        # Initial pause (in seconds) between attempts.
        self.sleep: float = sleep
        # Multiplier applied to the pause after each failed attempt.
        self.backoff: float = backoff
        # Maximum number of attempts before giving up.
        self.retries: int = retries

awswrangler/data_api/rds.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""RDS Data API Connector."""
2+
import logging
3+
import time
4+
import uuid
5+
from typing import Any, Dict, List, Optional
6+
7+
import boto3
8+
import pandas as pd
9+
10+
from awswrangler.data_api import connector
11+
12+
13+
class RdsDataApi(connector.DataApiConnector):
    """Provides access to the RDS Data API.

    Parameters
    ----------
    resource_arn: str
        ARN for the RDS resource.
    database: str
        Target database name.
    secret_arn: str
        The ARN for the secret to be used for authentication.
    sleep: float
        Number of seconds to sleep between connection attempts to paused clusters - defaults to 0.5.
    backoff: float
        Factor by which to increase the sleep between connection attempts to paused clusters - defaults to 1.0.
    retries: int
        Maximum number of connection attempts to paused clusters - defaults to 30.
    """

    def __init__(
        self,
        resource_arn: str,
        database: str,
        secret_arn: str = "",
        sleep: float = 0.5,
        backoff: float = 1.0,
        retries: int = 30,
    ) -> None:
        self.resource_arn = resource_arn
        self.database = database
        self.secret_arn = secret_arn
        self.wait_config = connector.WaitConfig(sleep, backoff, retries)
        self.client = boto3.client("rds-data")
        # Maps a generated request id to the raw Data API response until it is
        # consumed by _get_statement_result.
        self.results: Dict[str, Dict[str, Any]] = {}
        logger: logging.Logger = logging.getLogger(__name__)
        super().__init__(self.client, logger)

    def _execute_statement(self, sql: str, database: Optional[str] = None) -> str:
        """Submit the statement, retrying while the serverless cluster resumes.

        Returns a request id under which the raw response is stored; re-raises the
        last BadRequestException when all retries are exhausted.
        """
        if database is None:
            database = self.database

        sleep: float = self.wait_config.sleep
        total_tries: int = 0
        total_sleep: float = 0
        response: Optional[Dict[str, Any]] = None
        last_exception: Optional[Exception] = None
        while total_tries < self.wait_config.retries:
            try:
                response = self.client.execute_statement(
                    resourceArn=self.resource_arn,
                    database=database,
                    sql=sql,
                    secretArn=self.secret_arn,
                    includeResultMetadata=True,
                )
                self.logger.debug(
                    "Response received after %s tries and sleeping for a total of %s seconds", total_tries, total_sleep
                )
                break
            except self.client.exceptions.BadRequestException as exception:
                # A paused Aurora Serverless cluster responds with BadRequestException
                # while resuming - back off and retry.
                last_exception = exception
                total_sleep += sleep
                self.logger.debug("BadRequestException occurred: %s", exception)
                self.logger.debug(
                    "Cluster may be paused - sleeping for %s seconds for a total of %s before retrying",
                    sleep,
                    total_sleep,
                )
                time.sleep(sleep)
                total_tries += 1
                sleep *= self.wait_config.backoff

        if response is None:
            # Not inside an except block here, so logger.exception would have no
            # active exception info - use logger.error instead.
            self.logger.error(
                "Maximum BadRequestException retries reached for query %s after %s tries and sleeping %ss",
                sql,
                total_tries,
                total_sleep,
            )
            # Botocore modeled exceptions cannot be constructed from a bare message
            # string (they require an error response and operation name), so re-raise
            # the exception captured on the final attempt - callers still see a
            # BadRequestException.
            if last_exception is not None:
                raise last_exception
            raise RuntimeError(f"Query failed with no response and no exception after {total_tries} tries")

        request_id: str = uuid.uuid4().hex
        self.results[request_id] = response
        return request_id

    def _get_statement_result(self, request_id: str) -> pd.DataFrame:
        """Convert the stored raw response for ``request_id`` into a DataFrame.

        Raises KeyError when the request id is unknown or was already consumed.
        """
        try:
            result = self.results.pop(request_id)
        except KeyError as exception:
            raise KeyError(f"Request {request_id} not found in results {self.results}") from exception

        if "records" not in result:
            # DDL/DML statements return no result set.
            return pd.DataFrame()

        rows: List[List[Any]] = []
        for record in result["records"]:
            row: List[Any] = [connector.DataApiConnector._get_column_value(column) for column in record]
            rows.append(row)

        column_names: List[str] = [column["name"] for column in result["columnMetadata"]]
        dataframe = pd.DataFrame(rows, columns=column_names)
        return dataframe
def connect(resource_arn: str, database: str, secret_arn: str = "", **kwargs: Any) -> RdsDataApi:
    """Create a RDS Data API connection.

    Parameters
    ----------
    resource_arn: str
        ARN for the RDS resource.
    database: str
        Target database name.
    secret_arn: str
        The ARN for the secret to be used for authentication.
    **kwargs
        Any additional kwargs are passed to the underlying RdsDataApi class.

    Returns
    -------
    A RdsDataApi connection instance that can be used with `wr.data_api.rds.read_sql_query`.
    """
    return RdsDataApi(resource_arn, database, secret_arn=secret_arn, **kwargs)
135+
def read_sql_query(sql: str, con: RdsDataApi, database: Optional[str] = None) -> pd.DataFrame:
    """Run an SQL query on an RdsDataApi connection and return the result as a dataframe.

    Parameters
    ----------
    sql: str
        SQL query to run.
    con: RdsDataApi
        A RdsDataApi connection instance as returned by `connect`.
    database: str
        Database to run query on - defaults to the database specified by `con`.

    Returns
    -------
    A Pandas dataframe containing the query results.
    """
    return con.execute(sql, database=database)

0 commit comments

Comments
 (0)