Skip to content

Commit 32b27e6

Browse files
committed
Add generate_connection() to Aurora
1 parent 434a0c8 commit 32b27e6

17 files changed

+272
-46
lines changed

awswrangler/aurora.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,103 @@
1+
from typing import Union
12
import logging
23

4+
import pg8000 # type: ignore
5+
import pymysql # type: ignore
6+
7+
from awswrangler.exceptions import InvalidEngine
8+
39
logger = logging.getLogger(__name__)
410

511

612
class Aurora:
    """
    Helpers to open and validate connections against Aurora clusters
    (MySQL- and PostgreSQL-compatible engines).
    """
    def __init__(self, session):
        # The wrangler Session carrying AWS credentials/configuration.
        self._session = session

    @staticmethod
    def _connect(database: str,
                 host: str,
                 port: Union[str, int],
                 user: str,
                 password: str,
                 engine: str,
                 tcp_keepalive: bool,
                 application_name: str,
                 timeout: int):
        """
        Open a raw PEP 249 connection for the requested engine.

        Single source of truth for driver dispatch — previously this logic was
        duplicated between _validate_connection() and generate_connection().

        :param timeout: Timeout (forwarded as pg8000 ``timeout`` or
                        pymysql ``connect_timeout``)
        :return: PEP 249 compatible connection
        :raises InvalidEngine: If engine is neither "mysql" nor "postgres"
        """
        engine_normalized: str = engine.lower()
        if "postgres" in engine_normalized:
            return pg8000.connect(database=database,
                                  host=host,
                                  port=int(port),
                                  user=user,
                                  password=password,
                                  ssl=True,
                                  application_name=application_name,
                                  tcp_keepalive=tcp_keepalive,
                                  timeout=timeout)
        if "mysql" in engine_normalized:
            # pymysql exposes no tcp_keepalive option; the application name is
            # reported through program_name instead.
            return pymysql.connect(database=database,
                                   host=host,
                                   port=int(port),
                                   user=user,
                                   password=password,
                                   program_name=application_name,
                                   connect_timeout=timeout)
        raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")

    @staticmethod
    def _validate_connection(database: str,
                             host: str,
                             port: Union[str, int],
                             user: str,
                             password: str,
                             engine: str = "mysql",
                             tcp_keepalive: bool = True,
                             application_name: str = "aws-data-wrangler-validation",
                             validation_timeout: int = 5) -> None:
        """
        Prove the target is reachable by opening (with a short timeout) and
        immediately closing a connection; raises the driver error otherwise.
        """
        conn = Aurora._connect(database=database,
                               host=host,
                               port=port,
                               user=user,
                               password=password,
                               engine=engine,
                               tcp_keepalive=tcp_keepalive,
                               application_name=application_name,
                               timeout=validation_timeout)
        conn.close()

    @staticmethod
    def generate_connection(database: str,
                            host: str,
                            port: Union[str, int],
                            user: str,
                            password: str,
                            engine: str = "mysql",
                            tcp_keepalive: bool = True,
                            application_name: str = "aws-data-wrangler",
                            connection_timeout: int = 1_200_000,
                            validation_timeout: int = 5):
        """
        Generates a valid connection object

        :param database: The name of the database instance to connect with.
        :param host: The hostname of the Aurora server to connect with.
        :param port: The TCP/IP port of the Aurora server instance.
        :param user: The username to connect to the Aurora database with.
        :param password: The user password to connect to the server with.
        :param engine: "mysql" or "postgres"
        :param tcp_keepalive: If True then use TCP keepalive
        :param application_name: Application name
        :param connection_timeout: Connection Timeout
        :param validation_timeout: Timeout to try to validate the connection
        :return: PEP 249 compatible connection
        :raises InvalidEngine: If engine is neither "mysql" nor "postgres"
        """
        # Fail fast with a short timeout before handing back a connection
        # configured with the (much longer) connection_timeout.
        Aurora._validate_connection(database=database,
                                    host=host,
                                    port=port,
                                    user=user,
                                    password=password,
                                    engine=engine,
                                    tcp_keepalive=tcp_keepalive,
                                    application_name=application_name,
                                    validation_timeout=validation_timeout)
        return Aurora._connect(database=database,
                               host=host,
                               port=port,
                               user=user,
                               password=password,
                               engine=engine,
                               tcp_keepalive=tcp_keepalive,
                               application_name=application_name,
                               timeout=connection_timeout)

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,7 @@ class InvalidParameters(Exception):
9696

9797
class AWSCredentialsNotFound(Exception):
9898
pass
99+
100+
101+
class InvalidEngine(Exception):
    """Raised when a requested database engine is not 'mysql' or 'postgres'."""
    pass

awswrangler/pandas.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,3 +1433,43 @@ def read_sql_redshift(self,
14331433
else:
14341434
self._session.s3.delete_objects(path=temp_s3_path)
14351435
raise e
1436+
1437+
def read_sql_aurora(self,
1438+
sql: str,
1439+
iam_role: str,
1440+
connection: Any,
1441+
temp_s3_path: Optional[str] = None) -> pd.DataFrame:
1442+
"""
1443+
Convert a query result in a Pandas Dataframe.
1444+
1445+
:param sql: SQL Query
1446+
:param iam_role: AWS IAM role with the related permissions
1447+
:param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
1448+
:param temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) (Default uses the Athena's results bucket)
1449+
"""
1450+
guid: str = pa.compat.guid()
1451+
name: str = f"temp_redshift_{guid}"
1452+
if temp_s3_path is None:
1453+
if self._session.athena_s3_output is not None:
1454+
temp_s3_path = self._session.redshift_temp_s3_path
1455+
else:
1456+
temp_s3_path = self._session.athena.create_athena_bucket()
1457+
temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path
1458+
temp_s3_path = f"{temp_s3_path}/{name}"
1459+
logger.debug(f"temp_s3_path: {temp_s3_path}")
1460+
paths: Optional[List[str]] = None
1461+
try:
1462+
paths = self._session.redshift.to_parquet(sql=sql,
1463+
path=temp_s3_path,
1464+
iam_role=iam_role,
1465+
connection=connection)
1466+
logger.debug(f"paths: {paths}")
1467+
df: pd.DataFrame = self.read_parquet(path=paths) # type: ignore
1468+
self._session.s3.delete_listed_objects(objects_paths=paths)
1469+
return df
1470+
except Exception as e:
1471+
if paths is not None:
1472+
self._session.s3.delete_listed_objects(objects_paths=paths)
1473+
else:
1474+
self._session.s3.delete_objects(path=temp_s3_path)
1475+
raise e

awswrangler/redshift.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,16 @@ def _validate_connection(database,
3838
tcp_keepalive=True,
3939
application_name="aws-data-wrangler-validation",
4040
validation_timeout=5):
41-
try:
42-
conn = pg8000.connect(database=database,
43-
host=host,
44-
port=int(port),
45-
user=user,
46-
password=password,
47-
ssl=True,
48-
application_name=application_name,
49-
tcp_keepalive=tcp_keepalive,
50-
timeout=validation_timeout)
51-
conn.close()
52-
except pg8000.core.InterfaceError as e:
53-
raise e
41+
conn = pg8000.connect(database=database,
42+
host=host,
43+
port=int(port),
44+
user=user,
45+
password=password,
46+
ssl=True,
47+
application_name=application_name,
48+
tcp_keepalive=tcp_keepalive,
49+
timeout=validation_timeout)
50+
conn.close()
5451

5552
@staticmethod
5653
def generate_connection(database,
@@ -86,8 +83,6 @@ def generate_connection(database,
8683
tcp_keepalive=tcp_keepalive,
8784
application_name=application_name,
8885
validation_timeout=validation_timeout)
89-
if isinstance(type(port), str) or isinstance(type(port), float):
90-
port = int(port)
9186
conn = pg8000.connect(database=database,
9287
host=host,
9388
port=int(port),

awswrangler/sagemaker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ class SageMaker:
1212
    def __init__(self, session):
        # The wrangler Session carrying the boto3 session and botocore config.
        self._session = session
        # Dedicated service clients; use_ssl=True forces TLS for both.
        self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
        self._client_sagemaker = session.boto3_session.client(service_name="sagemaker",
                                                              use_ssl=True,
                                                              config=session.botocore_config)
1618

1719
@staticmethod
1820
def _parse_path(path):

testing/parameters.properties

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
VpcId=VPC_ID
22
SubnetId=SUBNET_ID
3+
SubnetId2=SUBNET_ID2
34
Password=REDSHIFT_PASSWORD
45
TestUser=AWS_USER_THAT_WILL_RUN_THE_TESTS_ON_CLI

testing/template.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,9 @@ Outputs:
234234
RedshiftPort:
235235
Value: !GetAtt Redshift.Endpoint.Port
236236
Description: Redshift Endpoint Port.
237-
RedshiftPassword:
237+
Password:
238238
Value: !Ref Password
239-
Description: Redshift Password.
239+
Description: Password.
240240
RedshiftRole:
241241
Value: !GetAtt RedshiftRole.Arn
242242
Description: Redshift IAM role.

testing/test_awswrangler/test_athena.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@pytest.fixture(scope="module")
1717
def cloudformation_outputs():
18-
response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena")
18+
response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test")
1919
outputs = {}
2020
for output in response.get("Stacks")[0].get("Outputs"):
2121
outputs[output.get("OutputKey")] = output.get("OutputValue")
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import logging
2+
3+
import pytest
4+
import boto3
5+
6+
from awswrangler import Aurora
7+
from awswrangler.exceptions import InvalidEngine
8+
9+
logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s")
10+
logging.getLogger("awswrangler").setLevel(logging.DEBUG)
11+
12+
13+
@pytest.fixture(scope="module")
def cloudformation_outputs():
    """Collect the outputs of the deployed test stack as a key/value dict."""
    stacks = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test")
    outputs = {
        entry.get("OutputKey"): entry.get("OutputValue")
        for entry in stacks.get("Stacks")[0].get("Outputs")
    }
    yield outputs
20+
21+
22+
@pytest.fixture(scope="module")
def postgres_parameters(cloudformation_outputs):
    """Extract the Postgres endpoint and password from the stack outputs."""
    parameters = {}
    for key in ("PostgresAddress", "Password"):
        if key not in cloudformation_outputs:
            raise Exception("You must deploy the test infrastructure using SAM!")
        parameters[key] = cloudformation_outputs.get(key)
    yield parameters
34+
35+
36+
@pytest.fixture(scope="module")
def mysql_parameters(cloudformation_outputs):
    """Extract the MySQL endpoint and password from the stack outputs."""
    parameters = {}
    for key in ("MysqlAddress", "Password"):
        if key not in cloudformation_outputs:
            raise Exception("You must deploy the test infrastructure using SAM!")
        parameters[key] = cloudformation_outputs.get(key)
    yield parameters
48+
49+
50+
def test_postgres_connection(postgres_parameters):
    """Round-trip a trivial query through a Postgres-engine Aurora connection."""
    # NOTE(review): 3306 is the MySQL default port; Aurora PostgreSQL normally
    # listens on 5432 — confirm against the deployed cluster configuration.
    connection = Aurora.generate_connection(database="postgres",
                                            host=postgres_parameters["PostgresAddress"],
                                            port=3306,
                                            user="test",
                                            password=postgres_parameters["Password"],
                                            engine="postgres")
    cursor = connection.cursor()
    cursor.execute("SELECT 1 + 2, 3 + 4")
    row = cursor.fetchall()[0]
    assert (row[0], row[1]) == (3, 7)
    cursor.close()
    connection.close()
64+
65+
66+
def test_mysql_connection(mysql_parameters):
    """Round-trip a trivial query through a MySQL-engine Aurora connection."""
    connection = Aurora.generate_connection(database="mysql",
                                            host=mysql_parameters["MysqlAddress"],
                                            port=3306,
                                            user="test",
                                            password=mysql_parameters["Password"],
                                            engine="mysql")
    cursor = connection.cursor()
    cursor.execute("SELECT 1 + 2, 3 + 4")
    row = cursor.fetchall()[0]
    assert (row[0], row[1]) == (3, 7)
    cursor.close()
    connection.close()
80+
81+
82+
def test_invalid_engine(mysql_parameters):
    """An unknown engine name must raise InvalidEngine."""
    bogus_kwargs = dict(database="mysql",
                        host=mysql_parameters["MysqlAddress"],
                        port=3306,
                        user="test",
                        password=mysql_parameters["Password"],
                        engine="foo")
    with pytest.raises(InvalidEngine):
        Aurora.generate_connection(**bogus_kwargs)

testing/test_awswrangler/test_cloudwatchlogs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
@pytest.fixture(scope="module")
1515
def cloudformation_outputs():
16-
response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena")
16+
response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test")
1717
outputs = {}
1818
for output in response.get("Stacks")[0].get("Outputs"):
1919
outputs[output.get("OutputKey")] = output.get("OutputValue")

0 commit comments

Comments
 (0)