
Commit b1d90f4

Merge pull request #108 from awslabs/aurora
Handling null values for Pandas.to_aurora() and eventual consistency for postgres load
2 parents 0d3402c + 2f1cab6 commit b1d90f4

File tree

4 files changed: +291 additions, -38 deletions

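Both changes surface through Pandas.to_aurora(), which stages the Dataframe as CSV on S3 and then loads it into Aurora. Below is a minimal usage sketch in the style of the 0.x-era API; the generate_connection() argument names, host, and credentials are assumptions for illustration, not taken from this diff.

import awswrangler
import pandas as pd

# A None value that should round-trip to SQL NULL after this change.
df = pd.DataFrame({"id": [1, 2], "value": ["foo", None]})

# Argument names follow the docstrings in this diff ("Can be generated with
# Aurora.generate_connection()"); host and credentials are placeholders.
conn = awswrangler.Aurora.generate_connection(database="mysql",
                                              host="my-cluster.cluster-xxx.us-east-1.rds.amazonaws.com",
                                              port=3306,
                                              user="test",
                                              password="...",
                                              engine="mysql")

session = awswrangler.Session()
session.pandas.to_aurora(dataframe=df,
                         connection=conn,
                         schema="test",
                         table="my_table",
                         engine="mysql",
                         mode="overwrite")
conn.close()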

awswrangler/aurora.py

Lines changed: 137 additions & 23 deletions
@@ -1,12 +1,14 @@
 from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any
-from logging import getLogger, Logger
+from logging import getLogger, Logger, INFO
 import json
 import warnings
 
 import pg8000  # type: ignore
+from pg8000 import ProgrammingError  # type: ignore
 import pymysql  # type: ignore
 import pandas as pd  # type: ignore
 from boto3 import client  # type: ignore
+import tenacity  # type: ignore
 
 from awswrangler import data_types
 from awswrangler.exceptions import InvalidEngine, InvalidDataframeType, AuroraLoadError
@@ -134,7 +136,7 @@ def load_table(dataframe: pd.DataFrame,
                    schema_name: str,
                    table_name: str,
                    connection: Any,
-                   num_files,
+                   num_files: int,
                    mode: str = "append",
                    preserve_index: bool = False,
                    engine: str = "mysql",
@@ -156,6 +158,54 @@ def load_table(dataframe: pd.DataFrame,
         :param region: AWS S3 bucket region (Required only for postgres engine)
         :return: None
         """
+        if "postgres" in engine.lower():
+            Aurora.load_table_postgres(dataframe=dataframe,
+                                       dataframe_type=dataframe_type,
+                                       load_paths=load_paths,
+                                       schema_name=schema_name,
+                                       table_name=table_name,
+                                       connection=connection,
+                                       mode=mode,
+                                       preserve_index=preserve_index,
+                                       region=region)
+        elif "mysql" in engine.lower():
+            Aurora.load_table_mysql(dataframe=dataframe,
+                                    dataframe_type=dataframe_type,
+                                    manifest_path=load_paths[0],
+                                    schema_name=schema_name,
+                                    table_name=table_name,
+                                    connection=connection,
+                                    mode=mode,
+                                    preserve_index=preserve_index,
+                                    num_files=num_files)
+        else:
+            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
+
+    @staticmethod
+    def load_table_postgres(dataframe: pd.DataFrame,
+                            dataframe_type: str,
+                            load_paths: List[str],
+                            schema_name: str,
+                            table_name: str,
+                            connection: Any,
+                            mode: str = "append",
+                            preserve_index: bool = False,
+                            region: str = "us-east-1"):
+        """
+        Load text/CSV files into an Aurora table using a manifest file.
+        Creates the table if necessary.
+
+        :param dataframe: Pandas or Spark Dataframe
+        :param dataframe_type: "pandas" or "spark"
+        :param load_paths: S3 paths to be loaded (E.g. S3://...)
+        :param schema_name: Aurora schema
+        :param table_name: Aurora table name
+        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
+        :param mode: append or overwrite
+        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :param region: AWS S3 bucket region (Required only for postgres engine)
+        :return: None
+        """
         with connection.cursor() as cursor:
             if mode == "overwrite":
                 Aurora._create_table(cursor=cursor,
@@ -164,30 +214,94 @@ def load_table(dataframe: pd.DataFrame,
                                      schema_name=schema_name,
                                      table_name=table_name,
                                      preserve_index=preserve_index,
-                                     engine=engine)
-            for path in load_paths:
-                sql = Aurora._get_load_sql(path=path,
-                                           schema_name=schema_name,
-                                           table_name=table_name,
-                                           engine=engine,
-                                           region=region)
-                logger.debug(sql)
+                                     engine="postgres")
+            connection.commit()
+            logger.debug("CREATE TABLE committed.")
+        for path in load_paths:
+            Aurora._load_object_postgres_with_retry(connection=connection,
+                                                    schema_name=schema_name,
+                                                    table_name=table_name,
+                                                    path=path,
+                                                    region=region)
+
+    @staticmethod
+    @tenacity.retry(retry=tenacity.retry_if_exception_type(exception_types=ProgrammingError),
+                    wait=tenacity.wait_random_exponential(multiplier=0.5),
+                    stop=tenacity.stop_after_attempt(max_attempt_number=5),
+                    reraise=True,
+                    after=tenacity.after_log(logger, INFO))
+    def _load_object_postgres_with_retry(connection: Any, schema_name: str, table_name: str, path: str,
+                                         region: str) -> None:
+        with connection.cursor() as cursor:
+            sql = Aurora._get_load_sql(path=path,
+                                       schema_name=schema_name,
+                                       table_name=table_name,
+                                       engine="postgres",
+                                       region=region)
+            logger.debug(sql)
+            try:
                 cursor.execute(sql)
+            except ProgrammingError as ex:
+                if "The file has been modified" in str(ex):
+                    connection.rollback()
+                raise ex
+        connection.commit()
+        logger.debug(f"Load committed for: {path}.")
 
-        connection.commit()
-        logger.debug("Load committed.")
+    @staticmethod
+    def load_table_mysql(dataframe: pd.DataFrame,
+                         dataframe_type: str,
+                         manifest_path: str,
+                         schema_name: str,
+                         table_name: str,
+                         connection: Any,
+                         num_files: int,
+                         mode: str = "append",
+                         preserve_index: bool = False):
+        """
+        Load text/CSV files into an Aurora table using a manifest file.
+        Creates the table if necessary.
 
-        if "mysql" in engine.lower():
-            with connection.cursor() as cursor:
-                sql = ("-- AWS DATA WRANGLER\n"
-                       f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
-                       f"WHERE load_prefix = '{path}'")
-                logger.debug(sql)
-                cursor.execute(sql)
-                num_files_loaded = cursor.fetchall()[0][0]
-                if num_files_loaded != (num_files + 1):
-                    raise AuroraLoadError(
-                        f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
+        :param dataframe: Pandas or Spark Dataframe
+        :param dataframe_type: "pandas" or "spark"
+        :param manifest_path: S3 manifest path to be loaded (E.g. S3://...)
+        :param schema_name: Aurora schema
+        :param table_name: Aurora table name
+        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
+        :param num_files: Number of files to be loaded
+        :param mode: append or overwrite
+        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :return: None
+        """
+        with connection.cursor() as cursor:
+            if mode == "overwrite":
+                Aurora._create_table(cursor=cursor,
+                                     dataframe=dataframe,
+                                     dataframe_type=dataframe_type,
+                                     schema_name=schema_name,
+                                     table_name=table_name,
+                                     preserve_index=preserve_index,
+                                     engine="mysql")
+            sql = Aurora._get_load_sql(path=manifest_path,
+                                       schema_name=schema_name,
+                                       table_name=table_name,
+                                       engine="mysql")
+            logger.debug(sql)
+            cursor.execute(sql)
+            logger.debug(f"Load done for: {manifest_path}")
+        connection.commit()
+        logger.debug("Load committed.")
+
+        with connection.cursor() as cursor:
+            sql = ("-- AWS DATA WRANGLER\n"
+                   f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
+                   f"WHERE load_prefix = '{manifest_path}'")
+            logger.debug(sql)
+            cursor.execute(sql)
+            num_files_loaded = cursor.fetchall()[0][0]
+            if num_files_loaded != (num_files + 1):
+                raise AuroraLoadError(
                    f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
 
     @staticmethod
     def _parse_path(path):
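The eventual-consistency fix above hinges on the @tenacity.retry decorator: when the just-written S3 object is not yet consistently readable by the load statement, pg8000 raises a ProgrammingError, the transaction is rolled back, and the load is retried with jittered exponential backoff, up to five attempts, before the error is re-raised. A self-contained sketch of that retry pattern follows; flaky_s3_load and its attempt counter are illustrative stand-ins, not awswrangler code.

from logging import INFO, basicConfig, getLogger

import tenacity

basicConfig(level=INFO)
logger = getLogger(__name__)


class ProgrammingError(Exception):
    """Stand-in for pg8000.ProgrammingError (illustrative only)."""


attempts = {"count": 0}


@tenacity.retry(retry=tenacity.retry_if_exception_type(ProgrammingError),
                wait=tenacity.wait_random_exponential(multiplier=0.5),  # jittered backoff
                stop=tenacity.stop_after_attempt(5),  # give up after 5 tries
                reraise=True,  # surface the original error, not a RetryError
                after=tenacity.after_log(logger, INFO))  # log each failed attempt
def flaky_s3_load() -> str:
    # Simulate S3 eventual consistency: the first two reads see a stale object.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ProgrammingError("The file has been modified")
    return "loaded"


print(flaky_s3_load())  # logs two failed attempts, then prints "loaded"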

awswrangler/pandas.py

Lines changed: 31 additions & 8 deletions
@@ -644,9 +644,11 @@ def _apply_dates_to_generator(generator, parse_dates):
     def to_csv(self,
                dataframe: pd.DataFrame,
                path: str,
-               sep: str = ",",
+               sep: Optional[str] = None,
+               na_rep: Optional[str] = None,
+               quoting: Optional[int] = None,
                escapechar: Optional[str] = None,
-               serde: str = "OpenCSVSerDe",
+               serde: Optional[str] = "OpenCSVSerDe",
                database: Optional[str] = None,
                table: Optional[str] = None,
                partition_cols: Optional[List[str]] = None,
@@ -665,8 +667,10 @@ def to_csv(self,
         :param dataframe: Pandas Dataframe
         :param path: AWS S3 path (E.g. s3://bucket-name/folder_name/)
         :param sep: Same as pandas.to_csv()
+        :param na_rep: Same as pandas.to_csv()
+        :param quoting: Same as pandas.to_csv()
         :param escapechar: Same as pandas.to_csv()
-        :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe)
+        :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe) (For Athena/Glue Catalog only)
         :param database: AWS Glue Database name
         :param table: AWS Glue table name
         :param partition_cols: List of columns names that will be partitions on S3
@@ -680,9 +684,17 @@ def to_csv(self,
         :param columns_comments: Columns names and the related comments (Optional[Dict[str, str]])
         :return: List of objects written on S3
         """
-        if serde not in Pandas.VALID_CSV_SERDES:
+        if (serde is not None) and (serde not in Pandas.VALID_CSV_SERDES):
             raise InvalidSerDe(f"{serde} is not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
-        extra_args: Dict[str, Optional[str]] = {"sep": sep, "serde": serde, "escapechar": escapechar}
+        if (database is not None) and (serde is None):
+            raise InvalidParameters("It is not possible to write to a Glue Database without a SerDe.")
+        extra_args: Dict[str, Optional[Union[str, int]]] = {
+            "sep": sep,
+            "na_rep": na_rep,
+            "serde": serde,
+            "escapechar": escapechar,
+            "quoting": quoting
+        }
         return self.to_s3(dataframe=dataframe,
                           path=path,
                           file_format="csv",
@@ -767,7 +779,7 @@ def to_s3(self,
              procs_cpu_bound=None,
              procs_io_bound=None,
              cast_columns=None,
-             extra_args: Optional[Dict[str, Optional[str]]] = None,
+             extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None,
              inplace: bool = True,
              description: Optional[str] = None,
              parameters: Optional[Dict[str, str]] = None,
@@ -1053,17 +1065,24 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_
 
        serde = extra_args.get("serde")
        if serde is None:
-            escapechar = extra_args.get("escapechar")
+            escapechar: Optional[str] = extra_args.get("escapechar")
            if escapechar is not None:
                csv_extra_args["escapechar"] = escapechar
+            quoting: Optional[int] = extra_args.get("quoting")
+            if quoting is not None:
+                csv_extra_args["quoting"] = quoting
+            na_rep: Optional[str] = extra_args.get("na_rep")
+            if na_rep is not None:
+                csv_extra_args["na_rep"] = na_rep
        else:
            if serde == "OpenCSVSerDe":
                csv_extra_args["quoting"] = csv.QUOTE_ALL
                csv_extra_args["escapechar"] = "\\"
            elif serde == "LazySimpleSerDe":
                csv_extra_args["quoting"] = csv.QUOTE_NONE
                csv_extra_args["escapechar"] = "\\"
-        csv_buffer = bytes(
+        logger.debug(f"csv_extra_args: {csv_extra_args}")
+        csv_buffer: bytes = bytes(
            dataframe.to_csv(None, header=False, index=preserve_index, compression=compression, **csv_extra_args),
            "utf-8")
        Pandas._write_csv_to_s3_retrying(fs=fs, path=path, buffer=csv_buffer)
@@ -1554,9 +1573,13 @@ def to_aurora(self,
        temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/"
        temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/"
        logger.debug(f"temp_s3_path: {temp_s3_path}")
+        na_rep: str = "NULL" if "mysql" in engine.lower() else ""
        paths: List[str] = self.to_csv(dataframe=dataframe,
                                       path=temp_s3_path,
+                                       serde=None,
                                       sep=",",
+                                       na_rep=na_rep,
+                                       quoting=csv.QUOTE_MINIMAL,
                                       escapechar="\"",
                                       preserve_index=preserve_index,
                                       mode="overwrite",

awswrangler/s3.py

Lines changed: 7 additions & 2 deletions
@@ -308,8 +308,13 @@ def get_objects_sizes(self, objects_paths: List[str], procs_io_bound: Optional[i
             receive_pipes[i].close()
         return objects_sizes
 
-    def copy_listed_objects(self, objects_paths, source_path, target_path, mode="append", procs_io_bound=None):
-        if not procs_io_bound:
+    def copy_listed_objects(self,
+                            objects_paths: List[str],
+                            source_path: str,
+                            target_path: str,
+                            mode: str = "append",
+                            procs_io_bound: Optional[int] = None):
+        if procs_io_bound is None:
             procs_io_bound = self._session.procs_io_bound
         logger.debug(f"procs_io_bound: {procs_io_bound}")
         logger.debug(f"len(objects_paths): {len(objects_paths)}")
