Commit e0356bb

Add first load tests

1 parent 32b27e6 commit e0356bb

File tree

8 files changed, +515 −72 lines

awswrangler/aurora.py

Lines changed: 166 additions & 2 deletions
@@ -1,17 +1,21 @@
-from typing import Union
+from typing import Union, List, Dict, Tuple, Any
 import logging
+import json
 
 import pg8000  # type: ignore
 import pymysql  # type: ignore
+import pandas as pd  # type: ignore
 
-from awswrangler.exceptions import InvalidEngine
+from awswrangler import data_types
+from awswrangler.exceptions import InvalidEngine, InvalidDataframeType, AuroraLoadError
 
 logger = logging.getLogger(__name__)
 
 
 class Aurora:
     def __init__(self, session):
         self._session = session
+        self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
 
     @staticmethod
     def _validate_connection(database: str,
@@ -101,3 +105,163 @@ def generate_connection(database: str,
         else:
             raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
         return conn
+
+    def write_load_manifest(self, manifest_path: str,
+                            objects_paths: List[str]) -> Dict[str, List[Dict[str, Union[str, bool]]]]:
+        manifest: Dict[str, List[Dict[str, Union[str, bool]]]] = {"entries": []}
+        path: str
+        for path in objects_paths:
+            entry: Dict[str, Union[str, bool]] = {"url": path, "mandatory": True}
+            manifest["entries"].append(entry)
+        payload: str = json.dumps(manifest)
+        bucket: str
+        bucket, path = manifest_path.replace("s3://", "").split("/", 1)
+        logger.info(f"payload: {payload}")
+        self._client_s3.put_object(Body=payload, Bucket=bucket, Key=path)
+        return manifest
+
+    @staticmethod
+    def load_table(dataframe: pd.DataFrame,
+                   dataframe_type: str,
+                   load_paths: List[str],
+                   schema_name: str,
+                   table_name: str,
+                   connection: Any,
+                   num_files,
+                   mode: str = "append",
+                   preserve_index: bool = False,
+                   engine: str = "mysql",
+                   region: str = "us-east-1"):
+        """
+        Load text/CSV files into an Aurora table using a manifest file.
+        Creates the table if necessary.
+
+        :param dataframe: Pandas or Spark Dataframe
+        :param dataframe_type: "pandas" or "spark"
+        :param load_paths: S3 paths to be loaded (e.g. s3://...)
+        :param schema_name: Aurora schema
+        :param table_name: Aurora table name
+        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
+        :param num_files: Number of files to be loaded
+        :param mode: append or overwrite
+        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :param engine: "mysql" or "postgres"
+        :param region: AWS S3 bucket region (Required only for postgres engine)
+        :return: None
+        """
+        with connection.cursor() as cursor:
+            if mode == "overwrite":
+                Aurora._create_table(cursor=cursor,
+                                     dataframe=dataframe,
+                                     dataframe_type=dataframe_type,
+                                     schema_name=schema_name,
+                                     table_name=table_name,
+                                     preserve_index=preserve_index,
+                                     engine=engine)
+
+            for path in load_paths:
+                sql = Aurora._get_load_sql(path=path,
+                                           schema_name=schema_name,
+                                           table_name=table_name,
+                                           engine=engine,
+                                           region=region)
+                logger.debug(sql)
+                cursor.execute(sql)
+
+            if "mysql" in engine.lower():
+                sql = ("-- AWS DATA WRANGLER\n"
+                       f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
+                       f"WHERE load_prefix = '{path}'")
+                logger.debug(sql)
+                cursor.execute(sql)
+                num_files_loaded = cursor.fetchall()[0][0]
+                if num_files_loaded != (num_files + 1):
+                    connection.rollback()
+                    raise AuroraLoadError(
+                        f"Aurora load rolled back. {num_files_loaded} files counted. {num_files} expected.")
+
+        connection.commit()
+        logger.debug("Load committed.")
+
+    @staticmethod
+    def _parse_path(path):
+        path2 = path.replace("s3://", "")
+        parts = path2.partition("/")
+        return parts[0], parts[2]
+
+    @staticmethod
+    def _get_load_sql(path: str, schema_name: str, table_name: str, engine: str, region: str = "us-east-1") -> str:
+        if "postgres" in engine.lower():
+            bucket, key = Aurora._parse_path(path=path)
+            sql: str = ("-- AWS DATA WRANGLER\n"
+                        "SELECT aws_s3.table_import_from_s3(\n"
+                        f"'{schema_name}.{table_name}',\n"
+                        "'',\n"
+                        "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\\'')',\n"
+                        f"'({bucket},{key},{region})')")
+        elif "mysql" in engine.lower():
+            sql = ("-- AWS DATA WRANGLER\n"
+                   f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
+                   "REPLACE\n"
+                   f"INTO TABLE {schema_name}.{table_name}\n"
+                   "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\\\\'\n"
+                   "LINES TERMINATED BY '\\n'")
+        else:
+            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
+        return sql
+
+    @staticmethod
+    def _create_table(cursor,
+                      dataframe,
+                      dataframe_type,
+                      schema_name,
+                      table_name,
+                      preserve_index=False,
+                      engine: str = "mysql"):
+        """
+        Creates an Aurora table.
+
+        :param cursor: A PEP 249 compatible cursor
+        :param dataframe: Pandas or Spark Dataframe
+        :param dataframe_type: "pandas" or "spark"
+        :param schema_name: Aurora schema
+        :param table_name: Aurora table name
+        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :param engine: "mysql" or "postgres"
+        :return: None
+        """
+        sql: str = f"-- AWS DATA WRANGLER\n" \
+                   f"DROP TABLE IF EXISTS {schema_name}.{table_name}"
+        logger.debug(f"Drop table query:\n{sql}")
+        cursor.execute(sql)
+        schema = Aurora._get_schema(dataframe=dataframe,
+                                    dataframe_type=dataframe_type,
+                                    preserve_index=preserve_index,
+                                    engine=engine)
+        cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
+        sql = (f"-- AWS DATA WRANGLER\n" f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n" f"{cols_str})")
+        logger.debug(f"Create table query:\n{sql}")
+        cursor.execute(sql)
+
+    @staticmethod
+    def _get_schema(dataframe,
+                    dataframe_type: str,
+                    preserve_index: bool,
+                    engine: str = "mysql") -> List[Tuple[str, str]]:
+        schema_built: List[Tuple[str, str]] = []
+        if "postgres" in engine.lower():
+            convert_func = data_types.pyarrow2postgres
+        elif "mysql" in engine.lower():
+            convert_func = data_types.pyarrow2mysql
+        else:
+            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
+        if dataframe_type.lower() == "pandas":
+            pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas(
+                dataframe=dataframe, preserve_index=preserve_index, indexes_position="right")
+            for name, dtype in pyarrow_schema:
+                aurora_type: str = convert_func(dtype)
+                schema_built.append((name, aurora_type))
+        else:
+            raise InvalidDataframeType(f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!")
+        return schema_built
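
For reference, the manifest that write_load_manifest uploads is the same layout Aurora MySQL's LOAD DATA FROM S3 MANIFEST statement consumes. A minimal sketch of the payload, using hypothetical object paths (any s3:// URLs written by to_csv would do):

import json

# Hypothetical object paths for illustration only.
objects_paths = [
    "s3://my-bucket/temp_aurora_abc/0.csv",
    "s3://my-bucket/temp_aurora_abc/1.csv",
]

# Same structure Aurora.write_load_manifest builds: one entry per object,
# each marked mandatory so a missing file aborts the load.
manifest = {"entries": [{"url": p, "mandatory": True} for p in objects_paths]}
print(json.dumps(manifest))
# {"entries": [{"url": "s3://my-bucket/temp_aurora_abc/0.csv", "mandatory": true}, ...]}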
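
The two SQL shapes produced by _get_load_sql can be previewed directly; a sketch with hypothetical schema, table, and bucket names (_get_load_sql is a private static method, so calling it this way is illustration only). Note that on the MySQL path, load_table afterwards counts rows in mysql.aurora_s3_load_history for the manifest's prefix and expects num_files + 1, apparently counting the manifest object alongside the data files.

from awswrangler.aurora import Aurora

# Hypothetical names; demonstrates the MySQL manifest-based statement.
print(Aurora._get_load_sql(path="s3://my-bucket/temp/manifest.json",
                           schema_name="test",
                           table_name="load",
                           engine="mysql"))
# -- AWS DATA WRANGLER
# LOAD DATA FROM S3 MANIFEST 's3://my-bucket/temp/manifest.json'
# REPLACE
# INTO TABLE test.load
# ...

# And the PostgreSQL aws_s3 extension call, one object at a time.
print(Aurora._get_load_sql(path="s3://my-bucket/temp/0.csv",
                           schema_name="test",
                           table_name="load",
                           engine="postgres",
                           region="us-east-1"))
# -- AWS DATA WRANGLER
# SELECT aws_s3.table_import_from_s3(
# 'test.load',
# ...
# '(my-bucket,temp/0.csv,us-east-1)')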

awswrangler/data_types.py

Lines changed: 52 additions & 0 deletions
@@ -203,6 +203,58 @@ def pyarrow2redshift(dtype: pa.types) -> str:
         raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
 
 
+def pyarrow2postgres(dtype: pa.types) -> str:
+    dtype_str = str(dtype).lower()
+    if dtype_str == "int16":
+        return "SMALLINT"
+    elif dtype_str == "int32":
+        return "INT"
+    elif dtype_str == "int64":
+        return "BIGINT"
+    elif dtype_str == "float":
+        return "FLOAT4"
+    elif dtype_str == "double":
+        return "FLOAT8"
+    elif dtype_str == "bool":
+        return "BOOLEAN"
+    elif dtype_str == "string":
+        return "VARCHAR(256)"
+    elif dtype_str.startswith("timestamp"):
+        return "TIMESTAMP"
+    elif dtype_str.startswith("date"):
+        return "DATE"
+    elif dtype_str.startswith("decimal"):
+        return dtype_str.replace(" ", "").upper()
+    else:
+        raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
+
+
+def pyarrow2mysql(dtype: pa.types) -> str:
+    dtype_str = str(dtype).lower()
+    if dtype_str == "int16":
+        return "SMALLINT"
+    elif dtype_str == "int32":
+        return "INT"
+    elif dtype_str == "int64":
+        return "BIGINT"
+    elif dtype_str == "float":
+        return "FLOAT"
+    elif dtype_str == "double":
+        return "DOUBLE"
+    elif dtype_str == "bool":
+        return "BOOLEAN"
+    elif dtype_str == "string":
+        return "VARCHAR(256)"
+    elif dtype_str.startswith("timestamp"):
+        return "TIMESTAMP"
+    elif dtype_str.startswith("date"):
+        return "DATE"
+    elif dtype_str.startswith("decimal"):
+        return dtype_str.replace(" ", "").upper()
+    else:
+        raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
+
+
 def python2athena(python_type: type) -> str:
     python_type_str: str = str(python_type)
     if python_type_str == "<class 'int'>":
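
Since the module already depends on pyarrow (imported as pa), the new mappings can be sanity-checked in a few lines. A sketch; the decimal case is omitted because its string form varies across pyarrow releases:

import pyarrow as pa

from awswrangler.data_types import pyarrow2mysql, pyarrow2postgres

# Both helpers dispatch on the lowercased string form of the pyarrow type.
for dtype in [pa.int32(), pa.int64(), pa.float64(), pa.bool_(), pa.string(), pa.timestamp("ns")]:
    print(f"{dtype}: mysql={pyarrow2mysql(dtype)} postgres={pyarrow2postgres(dtype)}")
# e.g. double: mysql=DOUBLE postgres=FLOAT8
#      string: mysql=VARCHAR(256) postgres=VARCHAR(256)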

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,10 @@ class RedshiftLoadError(Exception):
     pass
 
 
+class AuroraLoadError(Exception):
+    pass
+
+
 class AthenaQueryError(Exception):
     pass
 

awswrangler/pandas.py

Lines changed: 79 additions & 3 deletions
@@ -19,10 +19,11 @@
 from awswrangler import data_types
 from awswrangler.exceptions import (UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object,
                                     LineTerminatorNotFound, EmptyDataframe, InvalidSerDe, InvalidCompression,
-                                    InvalidParameters)
+                                    InvalidParameters, InvalidEngine)
 from awswrangler.utils import calculate_bounders
 from awswrangler import s3
 from awswrangler.athena import Athena
+from awswrangler.aurora import Aurora
 
 logger = logging.getLogger(__name__)
 
@@ -834,9 +835,9 @@ def data_to_s3(self,
                    procs_io_bound=None,
                    cast_columns=None,
                    extra_args=None):
-        if not procs_cpu_bound:
+        if procs_cpu_bound is None:
             procs_cpu_bound = self._session.procs_cpu_bound
-        if not procs_io_bound:
+        if procs_io_bound is None:
             procs_io_bound = self._session.procs_io_bound
         logger.debug(f"procs_cpu_bound: {procs_cpu_bound}")
         logger.debug(f"procs_io_bound: {procs_io_bound}")
@@ -1473,3 +1474,78 @@ def read_sql_aurora(self,
         else:
             self._session.s3.delete_objects(path=temp_s3_path)
             raise e
+
+    def to_aurora(self,
+                  dataframe: pd.DataFrame,
+                  connection: Any,
+                  schema: str,
+                  table: str,
+                  engine: str = "mysql",
+                  temp_s3_path: Optional[str] = None,
+                  preserve_index: bool = False,
+                  mode: str = "append",
+                  procs_cpu_bound: Optional[int] = None,
+                  procs_io_bound: Optional[int] = None,
+                  inplace=True) -> None:
+        """
+        Load a Pandas Dataframe as a table on Aurora.
+
+        :param dataframe: Pandas Dataframe
+        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
+        :param schema: The Aurora schema for the table
+        :param table: The name of the desired Aurora table
+        :param engine: "mysql" or "postgres"
+        :param temp_s3_path: S3 path to write temporary files (e.g. s3://BUCKET_NAME/ANY_NAME/)
+        :param preserve_index: Should we preserve the Dataframe index?
+        :param mode: append or overwrite
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        :param procs_io_bound: Number of cores used for I/O bound tasks
+        :param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact
+        :return: None
+        """
+        if temp_s3_path is None:
+            if self._session.aurora_temp_s3_path is not None:
+                temp_s3_path = self._session.aurora_temp_s3_path
+            else:
+                guid: str = pa.compat.guid()
+                temp_directory = f"temp_aurora_{guid}"
+                temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/"
+        temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/"
+        logger.debug(f"temp_s3_path: {temp_s3_path}")
+
+        paths: List[str] = self.to_csv(dataframe=dataframe,
+                                       path=temp_s3_path,
+                                       sep=",",
+                                       preserve_index=preserve_index,
+                                       mode="overwrite",
+                                       procs_cpu_bound=procs_cpu_bound,
+                                       procs_io_bound=procs_io_bound,
+                                       inplace=inplace)
+
+        load_paths: List[str]
+        region: str = "us-east-1"
+        if "postgres" in engine.lower():
+            load_paths = paths.copy()
+            bucket, _ = Pandas._parse_path(path=load_paths[0])
+            region = self._session.s3.get_bucket_region(bucket=bucket)
+        elif "mysql" in engine.lower():
+            manifest_path: str = f"{temp_s3_path}manifest_{pa.compat.guid()}.json"
+            self._session.aurora.write_load_manifest(manifest_path=manifest_path, objects_paths=paths)
+            load_paths = [manifest_path]
+        else:
+            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
+        logger.debug(f"load_paths: {load_paths}")
+
+        Aurora.load_table(dataframe=dataframe,
+                          dataframe_type="pandas",
+                          load_paths=load_paths,
+                          schema_name=schema,
+                          table_name=table,
+                          connection=connection,
+                          num_files=len(paths),
+                          mode=mode,
+                          preserve_index=preserve_index,
+                          engine=engine,
+                          region=region)
+
+        self._session.s3.delete_objects(path=temp_s3_path, procs_io_bound=procs_io_bound)
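
End to end, the new API is meant to be driven from a Session. A minimal sketch, assuming an Aurora MySQL cluster is reachable from the caller; generate_connection's host/port/user/password parameters are inferred for illustration, since only the database and engine arguments appear in this diff:

import pandas as pd

import awswrangler
from awswrangler.aurora import Aurora

# Connection parameters beyond database/engine are assumptions for this sketch.
conn = Aurora.generate_connection(database="test",
                                  host="my-cluster.cluster-xxxx.us-east-1.rds.amazonaws.com",
                                  port=3306,
                                  user="admin",
                                  password="my-password",
                                  engine="mysql")

session = awswrangler.Session()
df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})
session.pandas.to_aurora(dataframe=df,
                         connection=conn,
                         schema="test",
                         table="load",
                         engine="mysql",
                         mode="overwrite")  # stages CSVs on S3, then loads via manifest
conn.close()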
