
Commit 5bff8c2

Merge pull request #100 from awslabs/aurora
Aurora improvements
2 parents 4c4170e + d5e7f8b commit 5bff8c2

File tree: 3 files changed (+198, -86 lines)

awswrangler/aurora.py

Lines changed: 2 additions & 2 deletions
@@ -202,14 +202,14 @@ def _get_load_sql(path: str, schema_name: str, table_name: str, engine: str, reg
                     "SELECT aws_s3.table_import_from_s3(\n"
                     f"'{schema_name}.{table_name}',\n"
                     "'',\n"
-                    "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\\'')',\n"
+                    "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\"'')',\n"
                     f"'({bucket},{key},{region})')")
     elif "mysql" in engine.lower():
         sql = ("-- AWS DATA WRANGLER\n"
                f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
                "REPLACE\n"
                f"INTO TABLE {schema_name}.{table_name}\n"
-               "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\\\\'\n"
+               "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\"'\n"
                "LINES TERMINATED BY '\\n'")
     else:
         raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
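
Note on the escaping change above: both load statements now use the quote character itself as the CSV escape, so an embedded double quote arrives as a doubled quote ("") while a backslash is just an ordinary byte. A minimal, self-contained sketch of that convention using Python's csv module (placeholder data, not code from this commit):

import csv
import io

# Illustration only: with quote-doubling, '"' inside a field is written as '""' and
# '\' needs no special handling, which is what ESCAPE ''"'' (PostgreSQL aws_s3) and
# ESCAPED BY '"' (MySQL LOAD DATA FROM S3) now expect on the load side.
buf = io.StringIO()
writer = csv.writer(buf, quoting=csv.QUOTE_ALL, doublequote=True, lineterminator="\n")
writer.writerow([1, "foo", "\\"])      # the backslash passes through unchanged
writer.writerow([2, 'say "hi"', '"'])  # embedded quotes are doubled on output
print(buf.getvalue(), end="")
# "1","foo","\"
# "2","say ""hi""",""""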

awswrangler/pandas.py

Lines changed: 112 additions & 84 deletions
@@ -637,17 +637,18 @@ def _apply_dates_to_generator(generator, parse_dates):
             yield df
 
     def to_csv(self,
-               dataframe,
-               path,
-               sep=",",
-               serde="OpenCSVSerDe",
+               dataframe: pd.DataFrame,
+               path: str,
+               sep: str = ",",
+               escapechar: Optional[str] = None,
+               serde: str = "OpenCSVSerDe",
                database: Optional[str] = None,
-               table=None,
-               partition_cols=None,
-               preserve_index=True,
-               mode="append",
-               procs_cpu_bound=None,
-               procs_io_bound=None,
+               table: Optional[str] = None,
+               partition_cols: Optional[List[str]] = None,
+               preserve_index: bool = True,
+               mode: str = "append",
+               procs_cpu_bound: Optional[int] = None,
+               procs_io_bound: Optional[int] = None,
                inplace=True,
                description: Optional[str] = None,
                parameters: Optional[Dict[str, str]] = None,
@@ -659,6 +660,7 @@ def to_csv(self,
         :param dataframe: Pandas Dataframe
         :param path: AWS S3 path (E.g. s3://bucket-name/folder_name/
         :param sep: Same as pandas.to_csv()
+        :param escapechar: Same as pandas.to_csv()
         :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe)
         :param database: AWS Glue Database name
         :param table: AWS Glue table name
@@ -675,7 +677,7 @@ def to_csv(self,
         """
         if serde not in Pandas.VALID_CSV_SERDES:
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
-        extra_args = {"sep": sep, "serde": serde}
+        extra_args = {"sep": sep, "serde": serde, "escapechar": escapechar}
         return self.to_s3(dataframe=dataframe,
                           path=path,
                           file_format="csv",
@@ -1041,8 +1043,13 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_
         sep = extra_args.get("sep")
         if sep is not None:
             csv_extra_args["sep"] = sep
+
         serde = extra_args.get("serde")
-        if serde is not None:
+        if serde is None:
+            escapechar = extra_args.get("escapechar")
+            if escapechar is not None:
+                csv_extra_args["escapechar"] = escapechar
+        else:
             if serde == "OpenCSVSerDe":
                 csv_extra_args["quoting"] = csv.QUOTE_ALL
                 csv_extra_args["escapechar"] = "\\"
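
A condensed sketch of the precedence this branch establishes (a simplified stand-in, not the actual write_csv_dataframe): an explicit SerDe keeps its own escaping, otherwise the caller-supplied escapechar is forwarded to pandas.to_csv.

import csv
from typing import Dict, Optional


def resolve_csv_escaping(serde: Optional[str], escapechar: Optional[str]) -> Dict:
    # Simplified view of the branch above; only the OpenCSVSerDe case is modelled here.
    args: Dict = {}
    if serde is None:
        if escapechar is not None:
            args["escapechar"] = escapechar
    elif serde == "OpenCSVSerDe":
        args["quoting"] = csv.QUOTE_ALL
        args["escapechar"] = "\\"
    return args


print(resolve_csv_escaping(serde=None, escapechar="\""))            # caller's escapechar wins
print(resolve_csv_escaping(serde="OpenCSVSerDe", escapechar="\""))  # SerDe settings win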
@@ -1511,7 +1518,7 @@ def to_aurora(self,
         Load Pandas Dataframe as a Table on Aurora
 
         :param dataframe: Pandas Dataframe
-        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param schema: The Redshift Schema for the table
         :param table: The name of the desired Redshift table
         :param engine: "mysql" or "postgres"
@@ -1523,58 +1530,66 @@ def to_aurora(self,
         :param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact
         :return: None
         """
-        if temp_s3_path is None:
-            if self._session.aurora_temp_s3_path is not None:
-                temp_s3_path = self._session.aurora_temp_s3_path
-            else:
-                guid: str = pa.compat.guid()
-                temp_directory = f"temp_aurora_{guid}"
-                temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/"
-        temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/"
-        logger.debug(f"temp_s3_path: {temp_s3_path}")
-
-        paths: List[str] = self.to_csv(dataframe=dataframe,
-                                       path=temp_s3_path,
-                                       sep=",",
-                                       preserve_index=preserve_index,
-                                       mode="overwrite",
-                                       procs_cpu_bound=procs_cpu_bound,
-                                       procs_io_bound=procs_io_bound,
-                                       inplace=inplace)
-
-        load_paths: List[str]
-        region: str = "us-east-1"
-        if "postgres" in engine.lower():
-            load_paths = paths.copy()
-            bucket, _ = Pandas._parse_path(path=load_paths[0])
-            region = self._session.s3.get_bucket_region(bucket=bucket)
-        elif "mysql" in engine.lower():
-            manifest_path: str = f"{temp_s3_path}manifest_{pa.compat.guid()}.json"
-            self._session.aurora.write_load_manifest(manifest_path=manifest_path, objects_paths=paths)
-            load_paths = [manifest_path]
-        else:
-            raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
-        logger.debug(f"load_paths: {load_paths}")
-
-        Aurora.load_table(dataframe=dataframe,
-                          dataframe_type="pandas",
-                          load_paths=load_paths,
-                          schema_name=schema,
-                          table_name=table,
-                          connection=connection,
-                          num_files=len(paths),
-                          mode=mode,
-                          preserve_index=preserve_index,
-                          engine=engine,
-                          region=region)
-
-        if "postgres" in engine.lower():
-            self._session.s3.delete_listed_objects(objects_paths=load_paths, procs_io_bound=procs_io_bound)
-        elif "mysql" in engine.lower():
-            self._session.s3.delete_listed_objects(objects_paths=load_paths + [manifest_path],
-                                                   procs_io_bound=procs_io_bound)
-        else:
+        if ("postgres" not in engine.lower()) and ("mysql" not in engine.lower()):
             raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
+        generated_conn: bool = False
+        if type(connection) == str:
+            logger.debug("Glue connection (str) provided.")
+            connection = self._session.glue.get_connection(name=connection)
+            generated_conn = True
+        try:
+            if temp_s3_path is None:
+                if self._session.aurora_temp_s3_path is not None:
+                    temp_s3_path = self._session.aurora_temp_s3_path
+                else:
+                    guid: str = pa.compat.guid()
+                    temp_directory = f"temp_aurora_{guid}"
+                    temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/"
+            temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/"
+            logger.debug(f"temp_s3_path: {temp_s3_path}")
+            paths: List[str] = self.to_csv(dataframe=dataframe,
+                                           path=temp_s3_path,
+                                           sep=",",
+                                           escapechar="\"",
+                                           preserve_index=preserve_index,
+                                           mode="overwrite",
+                                           procs_cpu_bound=procs_cpu_bound,
+                                           procs_io_bound=procs_io_bound,
+                                           inplace=inplace)
+            load_paths: List[str]
+            region: str = "us-east-1"
+            if "postgres" in engine.lower():
+                load_paths = paths.copy()
+                bucket, _ = Pandas._parse_path(path=load_paths[0])
+                region = self._session.s3.get_bucket_region(bucket=bucket)
+            elif "mysql" in engine.lower():
+                manifest_path: str = f"{temp_s3_path}manifest_{pa.compat.guid()}.json"
+                self._session.aurora.write_load_manifest(manifest_path=manifest_path, objects_paths=paths)
+                load_paths = [manifest_path]
+            logger.debug(f"load_paths: {load_paths}")
+            Aurora.load_table(dataframe=dataframe,
+                              dataframe_type="pandas",
+                              load_paths=load_paths,
+                              schema_name=schema,
+                              table_name=table,
+                              connection=connection,
+                              num_files=len(paths),
+                              mode=mode,
+                              preserve_index=preserve_index,
+                              engine=engine,
+                              region=region)
+            if "postgres" in engine.lower():
+                self._session.s3.delete_listed_objects(objects_paths=load_paths, procs_io_bound=procs_io_bound)
+            elif "mysql" in engine.lower():
+                self._session.s3.delete_listed_objects(objects_paths=load_paths + [manifest_path],
+                                                       procs_io_bound=procs_io_bound)
+        except Exception as ex:
+            connection.rollback()
+            if generated_conn is True:
+                connection.close()
+            raise ex
+        if generated_conn is True:
+            connection.close()
 
     def read_sql_aurora(self,
                         sql: str,
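
With the block above, to_aurora also accepts a Glue connection name: the connection is resolved via session.glue.get_connection(), rolled back on failure, and closed again if it was created internally. A hypothetical call in the style of the new tests (connection, schema and table names are placeholders):

import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"id": [1, 2], "value": ['say "hi"', "back\\slash"]})

# Placeholder identifiers for illustration; any PEP 249 connection still works as before.
wr.pandas.to_aurora(dataframe=df,
                    connection="my-glue-postgres-connection",
                    schema="public",
                    table="my_table",
                    engine="postgres",
                    mode="overwrite")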
@@ -1587,7 +1602,7 @@ def read_sql_aurora(self,
         Convert a query result in a Pandas Dataframe.
 
         :param sql: SQL Query
-        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
+        :param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
         :param col_names: List of column names. Default (None) is use columns IDs as column names.
         :param temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) (Default uses the Athena's results bucket)
         :param engine: Only "mysql" by now
@@ -1596,25 +1611,38 @@ def read_sql_aurora(self,
         """
         if "mysql" not in engine.lower():
             raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql'!")
-        guid: str = pa.compat.guid()
-        name: str = f"temp_aurora_{guid}"
-        if temp_s3_path is None:
-            if self._session.aurora_temp_s3_path is not None:
-                temp_s3_path = self._session.aurora_temp_s3_path
-            else:
-                temp_s3_path = self._session.athena.create_athena_bucket()
-        temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path
-        temp_s3_path = f"{temp_s3_path}/{name}"
-        logger.debug(f"temp_s3_path: {temp_s3_path}")
-        manifest_path: str = self._session.aurora.to_s3(sql=sql,
-                                                        path=temp_s3_path,
-                                                        connection=connection,
-                                                        engine=engine)
-        paths: List[str] = self._session.aurora.extract_manifest_paths(path=manifest_path)
-        logger.debug(f"paths: {paths}")
-        ret: Union[pd.DataFrame, Iterator[pd.DataFrame]]
-        ret = self.read_csv_list(paths=paths, max_result_size=max_result_size, header=None, names=col_names)
-        self._session.s3.delete_listed_objects(objects_paths=paths + [manifest_path])
+        generated_conn: bool = False
+        if type(connection) == str:
+            logger.debug("Glue connection (str) provided.")
+            connection = self._session.glue.get_connection(name=connection)
+            generated_conn = True
+        try:
+            guid: str = pa.compat.guid()
+            name: str = f"temp_aurora_{guid}"
+            if temp_s3_path is None:
+                if self._session.aurora_temp_s3_path is not None:
+                    temp_s3_path = self._session.aurora_temp_s3_path
+                else:
+                    temp_s3_path = self._session.athena.create_athena_bucket()
+            temp_s3_path = temp_s3_path[:-1] if temp_s3_path[-1] == "/" else temp_s3_path
+            temp_s3_path = f"{temp_s3_path}/{name}"
+            logger.debug(f"temp_s3_path: {temp_s3_path}")
+            manifest_path: str = self._session.aurora.to_s3(sql=sql,
+                                                            path=temp_s3_path,
+                                                            connection=connection,
+                                                            engine=engine)
+            paths: List[str] = self._session.aurora.extract_manifest_paths(path=manifest_path)
+            logger.debug(f"paths: {paths}")
+            ret: Union[pd.DataFrame, Iterator[pd.DataFrame]]
+            ret = self.read_csv_list(paths=paths, max_result_size=max_result_size, header=None, names=col_names)
+            self._session.s3.delete_listed_objects(objects_paths=paths + [manifest_path])
+        except Exception as ex:
+            connection.rollback()
+            if generated_conn is True:
+                connection.close()
+            raise ex
+        if generated_conn is True:
+            connection.close()
         return ret
 
     def read_csv_list(
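
read_sql_aurora gets the same connection handling: a Glue connection name can be passed directly, and the temporary S3 objects plus manifest are cleaned up after the result is read back. A hypothetical call (connection name and query are placeholders; the engine is limited to "mysql", as the docstring above states):

import awswrangler as wr

# Placeholder connection name and query for illustration.
df = wr.pandas.read_sql_aurora(sql="SELECT * FROM test.my_table",
                               connection="my-glue-mysql-connection",
                               engine="mysql")
print(df.head())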

testing/test_awswrangler/test_pandas.py

Lines changed: 84 additions & 0 deletions
@@ -1917,3 +1917,87 @@ def test_to_csv_metadata(
     assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
     assert len(session.glue.tables(database=database, search_text="boo bar").index) == 1
     assert len(session.glue.tables(database=database, search_text="value").index) > 0
+
+
+def test_aurora_postgres_load_special(bucket, postgres_parameters):
+    df = pd.DataFrame({
+        "id": [1, 2, 3, 4],
+        "value": ["foo", "boo", "bar", "abc"],
+        "special": ["\\", "\"", "\\\\\\\\", "\"\"\"\""]
+    })
+
+    path = f"s3://{bucket}/test_aurora_postgres_slash"
+    wr.pandas.to_aurora(
+        dataframe=df,
+        connection="aws-data-wrangler-postgres",
+        schema="public",
+        table="test_aurora_postgres_special",
+        mode="overwrite",
+        temp_s3_path=path,
+        engine="postgres",
+        procs_cpu_bound=4
+    )
+    conn = Aurora.generate_connection(database="postgres",
+                                      host=postgres_parameters["PostgresAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=postgres_parameters["Password"],
+                                      engine="postgres")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT * FROM public.test_aurora_postgres_special")
+        rows = cursor.fetchall()
+        assert len(rows) == len(df.index)
+        assert rows[0][0] == 1
+        assert rows[1][0] == 2
+        assert rows[2][0] == 3
+        assert rows[0][1] == "foo"
+        assert rows[1][1] == "boo"
+        assert rows[2][1] == "bar"
+        assert rows[3][1] == "abc"
+        assert rows[0][2] == "\\"
+        assert rows[1][2] == "\""
+        assert rows[2][2] == "\\\\\\\\"
+        assert rows[3][2] == "\"\"\"\""
+    conn.close()
+
+
+def test_aurora_mysql_load_special(bucket, mysql_parameters):
+    df = pd.DataFrame({
+        "id": [1, 2, 3, 4],
+        "value": ["foo", "boo", "bar", "abc"],
+        "special": ["\\", "\"", "\\\\\\\\", "\"\"\"\""]
+    })
+
+    path = f"s3://{bucket}/test_aurora_mysql_special"
+    wr.pandas.to_aurora(
+        dataframe=df,
+        connection="aws-data-wrangler-mysql",
+        schema="test",
+        table="test_aurora_mysql_special",
+        mode="overwrite",
+        temp_s3_path=path,
+        engine="mysql",
+        procs_cpu_bound=1
+    )
+    conn = Aurora.generate_connection(database="mysql",
+                                      host=mysql_parameters["MysqlAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=mysql_parameters["Password"],
+                                      engine="mysql")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT * FROM test.test_aurora_mysql_special")
+        rows = cursor.fetchall()
+        assert len(rows) == len(df.index)
+        assert rows[0][0] == 1
+        assert rows[1][0] == 2
+        assert rows[2][0] == 3
+        assert rows[0][1] == "foo"
+        assert rows[1][1] == "boo"
+        assert rows[2][1] == "bar"
+        assert rows[3][1] == "abc"
+        assert rows[0][2] == "\\"
+        assert rows[1][2] == "\""
+        assert rows[2][2] == "\\\\\\\\"
+        assert rows[3][2] == "\"\"\"\""
+    conn.close()
