Skip to content

Commit 4209c44

Browse files
authored
Merge pull request #93 from awslabs/aurora
Aurora MySQL Load check the number of files loaded after commit
2 parents bc9ebda + e0ddef1 commit 4209c44

File tree

3 files changed

+124
-16
lines changed

3 files changed

+124
-16
lines changed

awswrangler/aurora.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Union, List, Dict, Tuple, Any
22
import logging
33
import json
4+
import warnings
45

56
import pg8000 # type: ignore
67
import pymysql # type: ignore
@@ -158,7 +159,6 @@ def load_table(dataframe: pd.DataFrame,
158159
table_name=table_name,
159160
preserve_index=preserve_index,
160161
engine=engine)
161-
162162
for path in load_paths:
163163
sql = Aurora._get_load_sql(path=path,
164164
schema_name=schema_name,
@@ -167,22 +167,21 @@ def load_table(dataframe: pd.DataFrame,
167167
region=region)
168168
logger.debug(sql)
169169
cursor.execute(sql)
170-
171-
if "mysql" in engine.lower():
172-
sql = ("-- AWS DATA WRANGLER\n"
173-
f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
174-
f"WHERE load_prefix = '{path}'")
175-
logger.debug(sql)
176-
cursor.execute(sql)
177-
num_files_loaded = cursor.fetchall()[0][0]
178-
if num_files_loaded != (num_files + 1):
179-
connection.rollback()
180-
raise AuroraLoadError(
181-
f"Aurora load rolled back. {num_files_loaded} files counted. {num_files} expected.")
182-
183170
connection.commit()
184171
logger.debug("Load committed.")
185172

173+
if "mysql" in engine.lower():
174+
with connection.cursor() as cursor:
175+
sql = ("-- AWS DATA WRANGLER\n"
176+
f"SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
177+
f"WHERE load_prefix = '{path}'")
178+
logger.debug(sql)
179+
cursor.execute(sql)
180+
num_files_loaded = cursor.fetchall()[0][0]
181+
if num_files_loaded != (num_files + 1):
182+
raise AuroraLoadError(
183+
f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
184+
186185
@staticmethod
187186
def _parse_path(path):
188187
path2 = path.replace("s3://", "")
@@ -233,7 +232,14 @@ def _create_table(cursor,
233232
sql: str = f"-- AWS DATA WRANGLER\n" \
234233
f"DROP TABLE IF EXISTS {schema_name}.{table_name}"
235234
logger.debug(f"Drop table query:\n{sql}")
236-
cursor.execute(sql)
235+
if "postgres" in engine.lower():
236+
cursor.execute(sql)
237+
elif "mysql" in engine.lower():
238+
with warnings.catch_warnings():
239+
warnings.filterwarnings(action="ignore", message=".*Unknown table.*")
240+
cursor.execute(sql)
241+
else:
242+
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
237243
schema = Aurora._get_schema(dataframe=dataframe,
238244
dataframe_type=dataframe_type,
239245
preserve_index=preserve_index,

awswrangler/pandas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1473,7 +1473,7 @@ def to_aurora(self,
14731473
:param engine: "mysql" or "postgres"
14741474
:param temp_s3_path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
14751475
:param preserve_index: Should we preserve the Dataframe index?
1476-
:param mode: append, overwrite or upsert
1476+
:param mode: append or overwrite
14771477
:param procs_cpu_bound: Number of cores used for CPU bound tasks
14781478
:param procs_io_bound: Number of cores used for I/O bound tasks
14791479
:param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact

testing/test_awswrangler/test_pandas.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,108 @@ def test_read_csv_list_iterator(bucket, sample, row_num):
17791779
assert total_count == row_num * n
17801780

17811781

1782+
def test_aurora_mysql_load_append(bucket, mysql_parameters):
1783+
n: int = 10_000
1784+
df = pd.DataFrame({"id": list((range(n))), "value": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])})
1785+
conn = Aurora.generate_connection(database="mysql",
1786+
host=mysql_parameters["MysqlAddress"],
1787+
port=3306,
1788+
user="test",
1789+
password=mysql_parameters["Password"],
1790+
engine="mysql")
1791+
path = f"s3://{bucket}/test_aurora_mysql_load_append"
1792+
1793+
# LOAD
1794+
wr.pandas.to_aurora(dataframe=df,
1795+
connection=conn,
1796+
schema="test",
1797+
table="test_aurora_mysql_load_append",
1798+
mode="overwrite",
1799+
temp_s3_path=path)
1800+
with conn.cursor() as cursor:
1801+
cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_append")
1802+
count = cursor.fetchall()[0][0]
1803+
assert count == len(df.index)
1804+
1805+
# APPEND
1806+
wr.pandas.to_aurora(dataframe=df,
1807+
connection=conn,
1808+
schema="test",
1809+
table="test_aurora_mysql_load_append",
1810+
mode="append",
1811+
temp_s3_path=path)
1812+
with conn.cursor() as cursor:
1813+
cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_append")
1814+
count = cursor.fetchall()[0][0]
1815+
assert count == len(df.index) * 2
1816+
1817+
# RESET
1818+
wr.pandas.to_aurora(dataframe=df,
1819+
connection=conn,
1820+
schema="test",
1821+
table="test_aurora_mysql_load_append",
1822+
mode="overwrite",
1823+
temp_s3_path=path)
1824+
with conn.cursor() as cursor:
1825+
cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_append")
1826+
count = cursor.fetchall()[0][0]
1827+
assert count == len(df.index)
1828+
1829+
conn.close()
1830+
1831+
1832+
def test_aurora_postgres_load_append(bucket, postgres_parameters):
1833+
df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"]})
1834+
conn = Aurora.generate_connection(database="postgres",
1835+
host=postgres_parameters["PostgresAddress"],
1836+
port=3306,
1837+
user="test",
1838+
password=postgres_parameters["Password"],
1839+
engine="postgres")
1840+
path = f"s3://{bucket}/test_aurora_postgres_load_append"
1841+
1842+
# LOAD
1843+
wr.pandas.to_aurora(dataframe=df,
1844+
connection=conn,
1845+
schema="public",
1846+
table="test_aurora_postgres_load_append",
1847+
mode="overwrite",
1848+
temp_s3_path=path,
1849+
engine="postgres")
1850+
with conn.cursor() as cursor:
1851+
cursor.execute("SELECT count(*) FROM public.test_aurora_postgres_load_append")
1852+
count = cursor.fetchall()[0][0]
1853+
assert count == len(df.index)
1854+
1855+
# APPEND
1856+
wr.pandas.to_aurora(dataframe=df,
1857+
connection=conn,
1858+
schema="public",
1859+
table="test_aurora_postgres_load_append",
1860+
mode="append",
1861+
temp_s3_path=path,
1862+
engine="postgres")
1863+
with conn.cursor() as cursor:
1864+
cursor.execute("SELECT count(*) FROM public.test_aurora_postgres_load_append")
1865+
count = cursor.fetchall()[0][0]
1866+
assert count == len(df.index) * 2
1867+
1868+
# RESET
1869+
wr.pandas.to_aurora(dataframe=df,
1870+
connection=conn,
1871+
schema="public",
1872+
table="test_aurora_postgres_load_append",
1873+
mode="overwrite",
1874+
temp_s3_path=path,
1875+
engine="postgres")
1876+
with conn.cursor() as cursor:
1877+
cursor.execute("SELECT count(*) FROM public.test_aurora_postgres_load_append")
1878+
count = cursor.fetchall()[0][0]
1879+
assert count == len(df.index)
1880+
1881+
conn.close()
1882+
1883+
17821884
def test_to_csv_metadata(
17831885
session,
17841886
bucket,

0 commit comments

Comments
 (0)