
Commit 027d3dd

Merge pull request #91 from awslabs/aurora
Aurora
2 parents 584bc15 + 80b4653 commit 027d3dd

29 files changed: +1423 -270 lines changed

README.md

Lines changed: 48 additions & 5 deletions
@@ -27,13 +27,14 @@
 * Pandas -> Glue Catalog Table
 * Pandas -> Athena (Parallel)
 * Pandas -> Redshift (Append/Overwrite/Upsert) (Parallel)
-* Parquet (S3) -> Pandas (Parallel) (NEW :star:)
+* Pandas -> Aurora (MySQL/PostgreSQL) (Append/Overwrite) (Via S3) (NEW :star:)
+* Parquet (S3) -> Pandas (Parallel)
 * CSV (S3) -> Pandas (One shot or Batching)
-* Glue Catalog Table -> Pandas (Parallel) (NEW :star:)
-* Athena -> Pandas (One shot, Batching or Parallel (NEW :star:))
-* Redshift -> Pandas (Parallel) (NEW :star:)
-* Redshift -> Parquet (S3) (NEW :star:)
+* Glue Catalog Table -> Pandas (Parallel)
+* Athena -> Pandas (One shot, Batching or Parallel)
+* Redshift -> Pandas (Parallel)
 * CloudWatch Logs Insights -> Pandas
+* Aurora -> Pandas (MySQL) (Via S3) (NEW :star:)
 * Encrypt Pandas Dataframes on S3 with KMS keys

 ### PySpark
@@ -60,6 +61,8 @@
 * Get EMR step state
 * Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]*)
 * Load and Unzip SageMaker jobs outputs
+* Redshift -> Parquet (S3)
+* Aurora -> CSV (S3) (MySQL) (NEW :star:)

 ## Installation

@@ -147,6 +150,22 @@ df = sess.pandas.read_sql_athena(
 )
 ```

+#### Reading from Glue Catalog (Parquet) to Pandas
+
+```py3
+import awswrangler as wr
+
+df = wr.pandas.read_table(database="DATABASE_NAME", table="TABLE_NAME")
+```
+
+#### Reading from S3 (Parquet) to Pandas
+
+```py3
+import awswrangler as wr
+
+df = wr.pandas.read_parquet(path="s3://...", columns=["c1", "c3"], filters=[("c5", "=", 0)])
+```
+
 #### Reading from S3 (CSV) to Pandas

 ```py3
@@ -227,6 +246,30 @@ df = wr.pandas.read_sql_redshift(
     temp_s3_path="s3://temp_path")
 ```

+#### Loading Pandas Dataframe to Aurora (MySQL/PostgreSQL)
+
+```py3
+import awswrangler as wr
+
+wr.pandas.to_aurora(
+    dataframe=df,
+    connection=con,
+    schema="...",
+    table="..."
+)
+```
+
+#### Extract Aurora query to Pandas DataFrame (MySQL)
+
+```py3
+import awswrangler as wr
+
+df = wr.pandas.read_sql_aurora(
+    sql="SELECT ...",
+    connection=con
+)
+```
+
 ### PySpark

 #### Loading PySpark Dataframe to Redshift
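A note on the two Aurora snippets added above: the `con` object is not created anywhere in the README diff. Presumably it comes from the `Aurora.generate_connection()` helper that this commit adds in `awswrangler/aurora.py`; a minimal sketch, assuming a reachable Aurora MySQL endpoint and placeholder credentials:

```py3
import awswrangler as wr

# Hypothetical cluster endpoint and credentials; replace with real values.
con = wr.Aurora.generate_connection(
    database="mydb",
    host="my-cluster.cluster-xxxxxxxx.us-east-1.rds.amazonaws.com",
    port=3306,
    user="admin",
    password="...",
    engine="mysql")  # or "postgres"
```

The same connection object can then be passed to both `wr.pandas.to_aurora()` and `wr.pandas.read_sql_aurora()` shown above, and closed with `con.close()` when no longer needed.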

awswrangler/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from awswrangler.cloudwatchlogs import CloudWatchLogs  # noqa
 from awswrangler.glue import Glue  # noqa
 from awswrangler.redshift import Redshift  # noqa
+from awswrangler.aurora import Aurora  # noqa
 from awswrangler.emr import EMR  # noqa
 from awswrangler.sagemaker import SageMaker  # noqa
 import awswrangler.utils  # noqa
@@ -38,6 +39,7 @@ def __getattr__(self, name):
 pandas = DynamicInstantiate("pandas")
 athena = DynamicInstantiate("athena")
 redshift = DynamicInstantiate("redshift")
+aurora = DynamicInstantiate("aurora")
 sagemaker = DynamicInstantiate("sagemaker")
 cloudwatchlogs = DynamicInstantiate("cloudwatchlogs")
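Together with the import added above, the new `aurora = DynamicInstantiate("aurora")` line mirrors the existing wiring for `redshift` and the other services: `wr.Aurora` exposes the class itself (including its static `generate_connection()`), while `wr.aurora` presumably resolves to a session-bound `Aurora` instance. A minimal sketch under that assumption, with a hypothetical manifest path:

```py3
import awswrangler as wr

# Assumed behaviour: the module-level handle proxies to the default session,
# so Aurora instance methods are reachable without building a Session manually.
paths = wr.aurora.extract_manifest_paths(path="s3://my-bucket/prefix/out.manifest")
```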

awswrangler/aurora.py

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
from typing import Union, List, Dict, Tuple, Any
import logging
import json

import pg8000  # type: ignore
import pymysql  # type: ignore
import pandas as pd  # type: ignore

from awswrangler import data_types
from awswrangler.exceptions import InvalidEngine, InvalidDataframeType, AuroraLoadError

logger = logging.getLogger(__name__)


class Aurora:
    def __init__(self, session):
        self._session = session
        self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)

    @staticmethod
    def _validate_connection(database: str,
                             host: str,
                             port: Union[str, int],
                             user: str,
                             password: str,
                             engine: str = "mysql",
                             tcp_keepalive: bool = True,
                             application_name: str = "aws-data-wrangler-validation",
                             validation_timeout: int = 5) -> None:
        if "postgres" in engine.lower():
            conn = pg8000.connect(database=database,
                                  host=host,
                                  port=int(port),
                                  user=user,
                                  password=password,
                                  ssl=True,
                                  application_name=application_name,
                                  tcp_keepalive=tcp_keepalive,
                                  timeout=validation_timeout)
        elif "mysql" in engine.lower():
            conn = pymysql.connect(database=database,
                                   host=host,
                                   port=int(port),
                                   user=user,
                                   password=password,
                                   program_name=application_name,
                                   connect_timeout=validation_timeout)
        else:
            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
        conn.close()

    @staticmethod
    def generate_connection(database: str,
                            host: str,
                            port: Union[str, int],
                            user: str,
                            password: str,
                            engine: str = "mysql",
                            tcp_keepalive: bool = True,
                            application_name: str = "aws-data-wrangler",
                            connection_timeout: int = 1_200_000,
                            validation_timeout: int = 5):
        """
        Generates a valid connection object

        :param database: The name of the database instance to connect with.
        :param host: The hostname of the Aurora server to connect with.
        :param port: The TCP/IP port of the Aurora server instance.
        :param user: The username to connect to the Aurora database with.
        :param password: The user password to connect to the server with.
        :param engine: "mysql" or "postgres"
        :param tcp_keepalive: If True then use TCP keepalive
        :param application_name: Application name
        :param connection_timeout: Connection Timeout
        :param validation_timeout: Timeout to try to validate the connection
        :return: PEP 249 compatible connection
        """
        Aurora._validate_connection(database=database,
                                    host=host,
                                    port=port,
                                    user=user,
                                    password=password,
                                    engine=engine,
                                    tcp_keepalive=tcp_keepalive,
                                    application_name=application_name,
                                    validation_timeout=validation_timeout)
        if "postgres" in engine.lower():
            conn = pg8000.connect(database=database,
                                  host=host,
                                  port=int(port),
                                  user=user,
                                  password=password,
                                  ssl=True,
                                  application_name=application_name,
                                  tcp_keepalive=tcp_keepalive,
                                  timeout=connection_timeout)
        elif "mysql" in engine.lower():
            conn = pymysql.connect(database=database,
                                   host=host,
                                   port=int(port),
                                   user=user,
                                   password=password,
                                   program_name=application_name,
                                   connect_timeout=connection_timeout)
        else:
            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
        return conn

    def write_load_manifest(self, manifest_path: str,
                            objects_paths: List[str]) -> Dict[str, List[Dict[str, Union[str, bool]]]]:
        manifest: Dict[str, List[Dict[str, Union[str, bool]]]] = {"entries": []}
        path: str
        for path in objects_paths:
            entry: Dict[str, Union[str, bool]] = {"url": path, "mandatory": True}
            manifest["entries"].append(entry)
        payload: str = json.dumps(manifest)
        bucket: str
        bucket, path = manifest_path.replace("s3://", "").split("/", 1)
        logger.info(f"payload: {payload}")
        self._client_s3.put_object(Body=payload, Bucket=bucket, Key=path)
        return manifest

    @staticmethod
    def load_table(dataframe: pd.DataFrame,
                   dataframe_type: str,
                   load_paths: List[str],
                   schema_name: str,
                   table_name: str,
                   connection: Any,
                   num_files,
                   mode: str = "append",
                   preserve_index: bool = False,
                   engine: str = "mysql",
                   region: str = "us-east-1"):
        """
        Load text/CSV files into an Aurora table using a manifest file.
        Creates the table if necessary.

        :param dataframe: Pandas or Spark Dataframe
        :param dataframe_type: "pandas" or "spark"
        :param load_paths: S3 paths to be loaded (e.g. s3://...)
        :param schema_name: Aurora schema
        :param table_name: Aurora table name
        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
        :param num_files: Number of files to be loaded
        :param mode: append or overwrite
        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
        :param engine: "mysql" or "postgres"
        :param region: AWS S3 bucket region (Required only for postgres engine)
        :return: None
        """
        with connection.cursor() as cursor:
            if mode == "overwrite":
                Aurora._create_table(cursor=cursor,
                                     dataframe=dataframe,
                                     dataframe_type=dataframe_type,
                                     schema_name=schema_name,
                                     table_name=table_name,
                                     preserve_index=preserve_index,
                                     engine=engine)

            for path in load_paths:
                sql = Aurora._get_load_sql(path=path,
                                           schema_name=schema_name,
                                           table_name=table_name,
                                           engine=engine,
                                           region=region)
                logger.debug(sql)
                cursor.execute(sql)

            if "mysql" in engine.lower():
                sql = ("-- AWS DATA WRANGLER\n"
                       f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
                       f"WHERE load_prefix = '{path}'")
                logger.debug(sql)
                cursor.execute(sql)
                num_files_loaded = cursor.fetchall()[0][0]
                if num_files_loaded != (num_files + 1):
                    connection.rollback()
                    raise AuroraLoadError(
                        f"Aurora load rolled back. {num_files_loaded} files counted. {num_files} expected.")

        connection.commit()
        logger.debug("Load committed.")

    @staticmethod
    def _parse_path(path):
        path2 = path.replace("s3://", "")
        parts = path2.partition("/")
        return parts[0], parts[2]

    @staticmethod
    def _get_load_sql(path: str, schema_name: str, table_name: str, engine: str, region: str = "us-east-1") -> str:
        if "postgres" in engine.lower():
            bucket, key = Aurora._parse_path(path=path)
            sql: str = ("-- AWS DATA WRANGLER\n"
                        "SELECT aws_s3.table_import_from_s3(\n"
                        f"'{schema_name}.{table_name}',\n"
                        "'',\n"
                        "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\\'')',\n"
                        f"'({bucket},{key},{region})')")
        elif "mysql" in engine.lower():
            sql = ("-- AWS DATA WRANGLER\n"
                   f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
                   "REPLACE\n"
                   f"INTO TABLE {schema_name}.{table_name}\n"
                   "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\\\\'\n"
                   "LINES TERMINATED BY '\\n'")
        else:
            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
        return sql

    @staticmethod
    def _create_table(cursor,
                      dataframe,
                      dataframe_type,
                      schema_name,
                      table_name,
                      preserve_index=False,
                      engine: str = "mysql"):
        """
        Creates Aurora table.

        :param cursor: A PEP 249 compatible cursor
        :param dataframe: Pandas or Spark Dataframe
        :param dataframe_type: "pandas" or "spark"
        :param schema_name: Aurora schema
        :param table_name: Aurora table name
        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
        :param engine: "mysql" or "postgres"
        :return: None
        """
        sql: str = f"-- AWS DATA WRANGLER\n" \
                   f"DROP TABLE IF EXISTS {schema_name}.{table_name}"
        logger.debug(f"Drop table query:\n{sql}")
        cursor.execute(sql)
        schema = Aurora._get_schema(dataframe=dataframe,
                                    dataframe_type=dataframe_type,
                                    preserve_index=preserve_index,
                                    engine=engine)
        cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
        sql = f"-- AWS DATA WRANGLER\n" f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n" f"{cols_str})"
        logger.debug(f"Create table query:\n{sql}")
        cursor.execute(sql)

    @staticmethod
    def _get_schema(dataframe,
                    dataframe_type: str,
                    preserve_index: bool,
                    engine: str = "mysql") -> List[Tuple[str, str]]:
        schema_built: List[Tuple[str, str]] = []
        if "postgres" in engine.lower():
            convert_func = data_types.pyarrow2postgres
        elif "mysql" in engine.lower():
            convert_func = data_types.pyarrow2mysql
        else:
            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
        if dataframe_type.lower() == "pandas":
            pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas(
                dataframe=dataframe, preserve_index=preserve_index, indexes_position="right")
            for name, dtype in pyarrow_schema:
                aurora_type: str = convert_func(dtype)
                schema_built.append((name, aurora_type))
        else:
            raise InvalidDataframeType(f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!")
        return schema_built

    def to_s3(self, sql: str, path: str, connection: Any, engine: str = "mysql") -> str:
        """
        Write a query result to S3

        :param sql: SQL Query
        :param path: AWS S3 path to write the data (e.g. s3://...)
        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
        :param engine: Only "mysql" is supported for now
        :return: Manifest S3 path
        """
        if "mysql" not in engine.lower():
            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql'!")
        path = path[:-1] if path[-1] == "/" else path
        self._session.s3.delete_objects(path=path)
        sql = f"{sql}\n" \
              f"INTO OUTFILE S3 '{path}'\n" \
              "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\\\\'\n" \
              "LINES TERMINATED BY '\\n'\n" \
              "MANIFEST ON\n" \
              "OVERWRITE ON"
        with connection.cursor() as cursor:
            logger.debug(sql)
            cursor.execute(sql)
        connection.commit()
        return path + ".manifest"

    def extract_manifest_paths(self, path: str) -> List[str]:
        bucket_name, key_path = Aurora._parse_path(path)
        body: bytes = self._client_s3.get_object(Bucket=bucket_name, Key=key_path)["Body"].read()
        return [x["url"] for x in json.loads(body.decode('utf-8'))["entries"]]
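For reference, the manifest that `write_load_manifest` uploads, and that `extract_manifest_paths` later reads back, is the plain JSON document consumed by Aurora MySQL's `LOAD DATA FROM S3 MANIFEST` statement built in `_get_load_sql`. A sketch of a two-entry manifest, using hypothetical bucket and key names:

```py3
import json

# Shape produced by Aurora.write_load_manifest() for two hypothetical CSV objects.
manifest = {
    "entries": [
        {"url": "s3://my-bucket/prefix/part-0000.csv", "mandatory": True},
        {"url": "s3://my-bucket/prefix/part-0001.csv", "mandatory": True},
    ]
}
print(json.dumps(manifest))  # this payload is what gets written to the manifest_path key
```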

0 commit comments