Commit 2f1cab6

Handling eventual consistency for Aurora postgres load

1 parent 43c720c commit 2f1cab6

4 files changed: +177 −48 lines changed
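Why retries are needed: at the time of this commit, Amazon S3 reads could be eventually consistent, so an object the wrangler has just written may still read back stale when Aurora PostgreSQL imports it from S3, and pg8000 surfaces that as a ProgrammingError containing "The file has been modified". The commit therefore splits load_table() into engine-specific helpers, commits the CREATE TABLE in its own transaction, and wraps each per-object PostgreSQL load in a tenacity retry with randomized exponential backoff. Below is a minimal standalone sketch of that retry pattern; the function name and body are illustrative, while the decorator arguments mirror the commit.

import logging

import tenacity
from pg8000 import ProgrammingError

logger = logging.getLogger(__name__)


@tenacity.retry(retry=tenacity.retry_if_exception_type(ProgrammingError),
                wait=tenacity.wait_random_exponential(multiplier=0.5),  # jittered backoff starting around 0.5s
                stop=tenacity.stop_after_attempt(5),  # give up after five attempts
                reraise=True,  # re-raise the original error instead of tenacity.RetryError
                after=tenacity.after_log(logger, logging.INFO))  # log each failed attempt
def flaky_s3_load() -> None:
    """Stand-in for the S3 import statement; fails until S3 serves a consistent object."""
    ...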

awswrangler/aurora.py

Lines changed: 137 additions & 23 deletions
@@ -1,12 +1,14 @@
 from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any
-from logging import getLogger, Logger
+from logging import getLogger, Logger, INFO
 import json
 import warnings
 
 import pg8000  # type: ignore
+from pg8000 import ProgrammingError  # type: ignore
 import pymysql  # type: ignore
 import pandas as pd  # type: ignore
 from boto3 import client  # type: ignore
+import tenacity  # type: ignore
 
 from awswrangler import data_types
 from awswrangler.exceptions import InvalidEngine, InvalidDataframeType, AuroraLoadError
@@ -134,7 +136,7 @@ def load_table(dataframe: pd.DataFrame,
                    schema_name: str,
                    table_name: str,
                    connection: Any,
-                   num_files,
+                   num_files: int,
                    mode: str = "append",
                    preserve_index: bool = False,
                    engine: str = "mysql",
@@ -156,6 +158,54 @@ def load_table(dataframe: pd.DataFrame,
         :param region: AWS S3 bucket region (Required only for postgres engine)
         :return: None
         """
+        if "postgres" in engine.lower():
+            Aurora.load_table_postgres(dataframe=dataframe,
+                                       dataframe_type=dataframe_type,
+                                       load_paths=load_paths,
+                                       schema_name=schema_name,
+                                       table_name=table_name,
+                                       connection=connection,
+                                       mode=mode,
+                                       preserve_index=preserve_index,
+                                       region=region)
+        elif "mysql" in engine.lower():
+            Aurora.load_table_mysql(dataframe=dataframe,
+                                    dataframe_type=dataframe_type,
+                                    manifest_path=load_paths[0],
+                                    schema_name=schema_name,
+                                    table_name=table_name,
+                                    connection=connection,
+                                    mode=mode,
+                                    preserve_index=preserve_index,
+                                    num_files=num_files)
+        else:
+            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
+
+    @staticmethod
+    def load_table_postgres(dataframe: pd.DataFrame,
+                            dataframe_type: str,
+                            load_paths: List[str],
+                            schema_name: str,
+                            table_name: str,
+                            connection: Any,
+                            mode: str = "append",
+                            preserve_index: bool = False,
+                            region: str = "us-east-1"):
+        """
+        Load text/CSV files into a Aurora table using a manifest file.
+        Creates the table if necessary.
+
+        :param dataframe: Pandas or Spark Dataframe
+        :param dataframe_type: "pandas" or "spark"
+        :param load_paths: S3 paths to be loaded (E.g. S3://...)
+        :param schema_name: Aurora schema
+        :param table_name: Aurora table name
+        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
+        :param mode: append or overwrite
+        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :param region: AWS S3 bucket region (Required only for postgres engine)
+        :return: None
+        """
         with connection.cursor() as cursor:
             if mode == "overwrite":
                 Aurora._create_table(cursor=cursor,
@@ -164,30 +214,94 @@ def load_table(dataframe: pd.DataFrame,
                                      schema_name=schema_name,
                                      table_name=table_name,
                                      preserve_index=preserve_index,
-                                     engine=engine)
-            for path in load_paths:
-                sql = Aurora._get_load_sql(path=path,
-                                           schema_name=schema_name,
-                                           table_name=table_name,
-                                           engine=engine,
-                                           region=region)
-                logger.debug(sql)
+                                     engine="postgres")
+                connection.commit()
+                logger.debug("CREATE TABLE committed.")
+        for path in load_paths:
+            Aurora._load_object_postgres_with_retry(connection=connection,
+                                                    schema_name=schema_name,
+                                                    table_name=table_name,
+                                                    path=path,
+                                                    region=region)
+
+    @staticmethod
+    @tenacity.retry(retry=tenacity.retry_if_exception_type(exception_types=ProgrammingError),
+                    wait=tenacity.wait_random_exponential(multiplier=0.5),
+                    stop=tenacity.stop_after_attempt(max_attempt_number=5),
+                    reraise=True,
+                    after=tenacity.after_log(logger, INFO))
+    def _load_object_postgres_with_retry(connection: Any, schema_name: str, table_name: str, path: str,
+                                         region: str) -> None:
+        with connection.cursor() as cursor:
+            sql = Aurora._get_load_sql(path=path,
+                                       schema_name=schema_name,
+                                       table_name=table_name,
+                                       engine="postgres",
+                                       region=region)
+            logger.debug(sql)
+            try:
                 cursor.execute(sql)
+            except ProgrammingError as ex:
+                if "The file has been modified" in str(ex):
+                    connection.rollback()
+                raise ex
+        connection.commit()
+        logger.debug(f"Load committed for: {path}.")
 
-        connection.commit()
-        logger.debug("Load committed.")
+    @staticmethod
+    def load_table_mysql(dataframe: pd.DataFrame,
+                         dataframe_type: str,
+                         manifest_path: str,
+                         schema_name: str,
+                         table_name: str,
+                         connection: Any,
+                         num_files: int,
+                         mode: str = "append",
+                         preserve_index: bool = False):
+        """
+        Load text/CSV files into a Aurora table using a manifest file.
+        Creates the table if necessary.
 
-        if "mysql" in engine.lower():
-            with connection.cursor() as cursor:
-                sql = ("-- AWS DATA WRANGLER\n"
-                       f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
-                       f"WHERE load_prefix = '{path}'")
-                logger.debug(sql)
-                cursor.execute(sql)
-                num_files_loaded = cursor.fetchall()[0][0]
-                if num_files_loaded != (num_files + 1):
-                    raise AuroraLoadError(
-                        f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
+        :param dataframe: Pandas or Spark Dataframe
+        :param dataframe_type: "pandas" or "spark"
+        :param manifest_path: S3 manifest path to be loaded (E.g. S3://...)
+        :param schema_name: Aurora schema
+        :param table_name: Aurora table name
+        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
+        :param num_files: Number of files to be loaded
+        :param mode: append or overwrite
+        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :return: None
+        """
+        with connection.cursor() as cursor:
+            if mode == "overwrite":
+                Aurora._create_table(cursor=cursor,
+                                     dataframe=dataframe,
+                                     dataframe_type=dataframe_type,
+                                     schema_name=schema_name,
+                                     table_name=table_name,
+                                     preserve_index=preserve_index,
+                                     engine="mysql")
+            sql = Aurora._get_load_sql(path=manifest_path,
+                                       schema_name=schema_name,
+                                       table_name=table_name,
+                                       engine="mysql")
+            logger.debug(sql)
+            cursor.execute(sql)
+            logger.debug(f"Load done for: {manifest_path}")
+        connection.commit()
+        logger.debug("Load committed.")
+
+        with connection.cursor() as cursor:
+            sql = ("-- AWS DATA WRANGLER\n"
+                   f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
+                   f"WHERE load_prefix = '{manifest_path}'")
+            logger.debug(sql)
+            cursor.execute(sql)
+            num_files_loaded = cursor.fetchall()[0][0]
+            if num_files_loaded != (num_files + 1):
+                raise AuroraLoadError(
                    f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
 
     @staticmethod
     def _parse_path(path):
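For orientation, here is a hypothetical call into the new code path (host, bucket, and credentials are illustrative, not from the commit; generate_connection() is shown with the keyword arguments the tests below use): load_table() now dispatches on engine, and the postgres branch loads each staged object in its own retried transaction.

import pandas as pd

from awswrangler.aurora import Aurora

df = pd.DataFrame({"col0": [1, 2]})  # illustrative frame; the matching CSV objects are assumed staged in S3
conn = Aurora.generate_connection(database="postgres",
                                  host="example-cluster.cluster-xyz.us-east-1.rds.amazonaws.com",
                                  port=3306,
                                  user="test",
                                  password="example",
                                  engine="postgres")
Aurora.load_table(dataframe=df,
                  dataframe_type="pandas",
                  load_paths=["s3://example-bucket/staged/0.csv"],  # illustrative path
                  schema_name="public",
                  table_name="my_table",
                  connection=conn,
                  num_files=1,
                  mode="overwrite",
                  engine="postgres",
                  region="us-east-1")
conn.close()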

awswrangler/pandas.py

Lines changed: 2 additions & 2 deletions
@@ -688,7 +688,7 @@ def to_csv(self,
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
         if (database is not None) and (serde is None):
             raise InvalidParameters(f"It is not possible write to a Glue Database without a SerDe.")
-        extra_args: Dict[str, Optional[str]] = {
+        extra_args: Dict[str, Optional[Union[str, int]]] = {
             "sep": sep,
             "na_rep": na_rep,
             "serde": serde,
@@ -779,7 +779,7 @@ def to_s3(self,
               procs_cpu_bound=None,
               procs_io_bound=None,
               cast_columns=None,
-              extra_args: Optional[Dict[str, Optional[str]]] = None,
+              extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None,
               inplace: bool = True,
               description: Optional[str] = None,
               parameters: Optional[Dict[str, str]] = None,
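The only substantive change in this file widens the extra_args annotation on the to_csv() → to_s3() path so the dict may carry integer values alongside strings. Illustrative only; the int-valued key below is hypothetical, since the commit changes the type, not the contents:

from typing import Dict, Optional, Union

extra_args: Dict[str, Optional[Union[str, int]]] = {
    "sep": ",",  # str value, as before
    "na_rep": None,  # Optional still admits None
    "serde": "OpenCSVSerDe",
    "num_partitions": 4,  # hypothetical int entry that Dict[str, Optional[str]] would have rejected
}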

awswrangler/s3.py

Lines changed: 6 additions & 1 deletion
@@ -308,7 +308,12 @@ def get_objects_sizes(self, objects_paths: List[str], procs_io_bound: Optional[i
             receive_pipes[i].close()
         return objects_sizes
 
-    def copy_listed_objects(self, objects_paths: List[str], source_path: str, target_path: str, mode: str = "append", procs_io_bound: Optional[int] = None):
+    def copy_listed_objects(self,
+                            objects_paths: List[str],
+                            source_path: str,
+                            target_path: str,
+                            mode: str = "append",
+                            procs_io_bound: Optional[int] = None):
         if procs_io_bound is None:
             procs_io_bound = self._session.procs_io_bound
         logger.debug(f"procs_io_bound: {procs_io_bound}")

testing/test_awswrangler/test_pandas.py

Lines changed: 32 additions & 22 deletions
@@ -1937,7 +1937,7 @@ def test_aurora_postgres_load_special(bucket, postgres_parameters):
             Decimal((0, (1, 9, 9), -2)),
             Decimal((0, (1, 9, 9), -2)),
             Decimal((0, (1, 9, 0), -2)),
-            Decimal((0, (3, 1, 2), -2))
+            None
         ]
     })
 
@@ -1978,7 +1978,7 @@ def test_aurora_postgres_load_special(bucket, postgres_parameters):
     assert rows[0][4] == Decimal((0, (1, 9, 9), -2))
     assert rows[1][4] == Decimal((0, (1, 9, 9), -2))
     assert rows[2][4] == Decimal((0, (1, 9, 0), -2))
-    assert rows[3][4] == Decimal((0, (3, 1, 2), -2))
+    assert rows[3][4] is None
     conn.close()
 
 
@@ -1992,7 +1992,7 @@ def test_aurora_mysql_load_special(bucket, mysql_parameters):
             Decimal((0, (1, 9, 9), -2)),
             Decimal((0, (1, 9, 9), -2)),
             Decimal((0, (1, 9, 0), -2)),
-            Decimal((0, (3, 1, 2), -2))
+            None
         ]
     })
 
@@ -2004,7 +2004,7 @@ def test_aurora_mysql_load_special(bucket, mysql_parameters):
                         mode="overwrite",
                         temp_s3_path=path,
                         engine="mysql",
-                        procs_cpu_bound=1)
+                        procs_cpu_bound=4)
     conn = Aurora.generate_connection(database="mysql",
                                       host=mysql_parameters["MysqlAddress"],
                                       port=3306,
@@ -2033,7 +2033,7 @@ def test_aurora_mysql_load_special(bucket, mysql_parameters):
     assert rows[0][4] == Decimal((0, (1, 9, 9), -2))
     assert rows[1][4] == Decimal((0, (1, 9, 9), -2))
     assert rows[2][4] == Decimal((0, (1, 9, 0), -2))
-    assert rows[3][4] == Decimal((0, (3, 1, 2), -2))
+    assert rows[3][4] is None
     conn.close()
 
 
@@ -2073,7 +2073,7 @@ def test_read_sql_athena_empty(ctas_approach):
 
 
 def test_aurora_postgres_load_special2(bucket, postgres_parameters):
-    dt = lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f")
+    dt = lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f")  # noqa
     df = pd.DataFrame({
         "integer1": [0, 1, np.NaN, 3],
         "integer2": [8986, 9735, 9918, 9150],
@@ -2084,11 +2084,17 @@ def test_aurora_postgres_load_special2(bucket, postgres_parameters):
         "float1": [0.0, 1800000.0, np.NaN, 0.0],
         "string5": ["0000296722", "0000199396", "0000298592", "0000196380"],
         "string6": [None, "C", "C", None],
-        "timestamp1": [dt("2020-01-07 00:00:00.000"), None, dt("2020-01-07 00:00:00.000"),
-                       dt("2020-01-07 00:00:00.000")],
+        "timestamp1":
+        [dt("2020-01-07 00:00:00.000"), None,
+         dt("2020-01-07 00:00:00.000"),
+         dt("2020-01-07 00:00:00.000")],
         "string7": ["XXX", "XXX", "XXX", "XXX"],
-        "timestamp2": [dt("2020-01-10 10:34:55.863"), dt("2020-01-10 10:34:55.864"), dt("2020-01-10 10:34:55.865"),
-                       dt("2020-01-10 10:34:55.866")],
+        "timestamp2": [
+            dt("2020-01-10 10:34:55.863"),
+            dt("2020-01-10 10:34:55.864"),
+            dt("2020-01-10 10:34:55.865"),
+            dt("2020-01-10 10:34:55.866")
+        ],
     })
     df = pd.concat([df for _ in range(10_000)])
     path = f"s3://{bucket}/test_aurora_postgres_special"
@@ -2098,8 +2104,7 @@ def test_aurora_postgres_load_special2(bucket, postgres_parameters):
                         table="test_aurora_postgres_load_special2",
                         mode="overwrite",
                         temp_s3_path=path,
-                        engine="postgres",
-                        procs_cpu_bound=1)
+                        engine="postgres")
     conn = Aurora.generate_connection(database="postgres",
                                       host=postgres_parameters["PostgresAddress"],
                                       port=3306,
@@ -2115,7 +2120,8 @@ def test_aurora_postgres_load_special2(bucket, postgres_parameters):
         assert rows[1][0] == dt("2020-01-10 10:34:55.864")
         assert rows[2][0] == dt("2020-01-10 10:34:55.865")
         assert rows[3][0] == dt("2020-01-10 10:34:55.866")
-        cursor.execute("SELECT integer1, float1, string6, timestamp1 FROM public.test_aurora_postgres_load_special2 limit 4")
+        cursor.execute(
+            "SELECT integer1, float1, string6, timestamp1 FROM public.test_aurora_postgres_load_special2 limit 4")
         rows = cursor.fetchall()
         assert rows[2][0] is None
         assert rows[2][1] is None
@@ -2125,7 +2131,7 @@ def test_aurora_postgres_load_special2(bucket, postgres_parameters):
 
 
 def test_aurora_mysql_load_special2(bucket, mysql_parameters):
-    dt = lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f")
+    dt = lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f")  # noqa
     df = pd.DataFrame({
         "integer1": [0, 1, np.NaN, 3],
         "integer2": [8986, 9735, 9918, 9150],
@@ -2136,11 +2142,17 @@ def test_aurora_mysql_load_special2(bucket, mysql_parameters):
         "float1": [0.0, 1800000.0, np.NaN, 0.0],
         "string5": ["0000296722", "0000199396", "0000298592", "0000196380"],
         "string6": [None, "C", "C", None],
-        "timestamp1": [dt("2020-01-07 00:00:00.000"), None, dt("2020-01-07 00:00:00.000"),
-                       dt("2020-01-07 00:00:00.000")],
+        "timestamp1":
+        [dt("2020-01-07 00:00:00.000"), None,
+         dt("2020-01-07 00:00:00.000"),
+         dt("2020-01-07 00:00:00.000")],
         "string7": ["XXX", "XXX", "XXX", "XXX"],
-        "timestamp2": [dt("2020-01-10 10:34:55.863"), dt("2020-01-10 10:34:55.864"), dt("2020-01-10 10:34:55.865"),
-                       dt("2020-01-10 10:34:55.866")],
+        "timestamp2": [
+            dt("2020-01-10 10:34:55.863"),
+            dt("2020-01-10 10:34:55.864"),
+            dt("2020-01-10 10:34:55.865"),
+            dt("2020-01-10 10:34:55.866")
+        ],
     })
     df = pd.concat([df for _ in range(10_000)])
     path = f"s3://{bucket}/test_aurora_mysql_load_special2"
@@ -2150,8 +2162,7 @@ def test_aurora_mysql_load_special2(bucket, mysql_parameters):
                         table="test_aurora_mysql_load_special2",
                         mode="overwrite",
                         temp_s3_path=path,
-                        engine="mysql",
-                        procs_cpu_bound=1)
+                        engine="mysql")
     conn = Aurora.generate_connection(database="mysql",
                                       host=mysql_parameters["MysqlAddress"],
                                       port=3306,
@@ -2161,8 +2172,7 @@ def test_aurora_mysql_load_special2(bucket, mysql_parameters):
     with conn.cursor() as cursor:
         cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_special2")
         assert cursor.fetchall()[0][0] == len(df.index)
-        cursor.execute(
-            "SELECT integer1, float1, string6, timestamp1 FROM test.test_aurora_mysql_load_special2 limit 4")
+        cursor.execute("SELECT integer1, float1, string6, timestamp1 FROM test.test_aurora_mysql_load_special2 limit 4")
         rows = cursor.fetchall()
         assert rows[2][0] is None
         assert rows[2][1] is None
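A reading aid for the assertions above: the tests build decimals with the standard library's (sign, digits, exponent) tuple constructor, so the replaced value Decimal((0, (3, 1, 2), -2)) is just 3.12; the new rows swap it for None to exercise NULL round-tripping through the load path.

from decimal import Decimal

# (sign, digits, exponent): sign 0 means positive; digits 1,9,9 with exponent -2 give 1.99
assert Decimal((0, (1, 9, 9), -2)) == Decimal("1.99")
assert Decimal((0, (1, 9, 0), -2)) == Decimal("1.90")
assert Decimal((0, (3, 1, 2), -2)) == Decimal("3.12")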
