
Commit 43c720c

Handling null values for Pandas.to_aurora()
1 parent 0d3402c commit 43c720c
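In short, this commit teaches Pandas.to_csv() to forward na_rep and quoting to pandas and uses that from Pandas.to_aurora() so missing values survive the CSV round trip into Aurora. A minimal usage sketch of the targeted behavior follows; the connection name, table name, and S3 paths are placeholders, not values from the repository.

    import numpy as np
    import pandas as pd
    import awswrangler as wr

    # DataFrame with missing values in numeric, string and timestamp-like columns
    df = pd.DataFrame({
        "id": [1, 2, 3],
        "value": [1.5, np.nan, 3.0],   # NaN should arrive in Aurora as SQL NULL
        "name": ["a", None, "c"],      # None should arrive in Aurora as SQL NULL
    })

    # Illustrative call; connection, table and temp_s3_path are placeholders
    wr.pandas.to_aurora(dataframe=df,
                        connection="my-glue-connection",
                        schema="test",
                        table="my_table",
                        mode="overwrite",
                        temp_s3_path="s3://my-bucket/tmp/",
                        engine="mysql")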

File tree: 3 files changed (+133, -9 lines)

awswrangler/pandas.py

Lines changed: 30 additions & 7 deletions
@@ -644,9 +644,11 @@ def _apply_dates_to_generator(generator, parse_dates):
     def to_csv(self,
                dataframe: pd.DataFrame,
                path: str,
-               sep: str = ",",
+               sep: Optional[str] = None,
+               na_rep: Optional[str] = None,
+               quoting: Optional[int] = None,
                escapechar: Optional[str] = None,
-               serde: str = "OpenCSVSerDe",
+               serde: Optional[str] = "OpenCSVSerDe",
               database: Optional[str] = None,
               table: Optional[str] = None,
               partition_cols: Optional[List[str]] = None,
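With the extended signature, a call that bypasses the Glue SerDe handling and controls how missing values are rendered could look like the sketch below. The bucket path and column values are placeholders; na_rep and quoting are simply forwarded to pandas.DataFrame.to_csv when no SerDe is requested.

    import csv
    import pandas as pd
    import awswrangler as wr

    df = pd.DataFrame({"col1": [1.0, None, 3.0], "col2": ["a", None, "c"]})

    # Hypothetical path; writes plain CSV objects to S3 without registering a Glue table
    wr.pandas.to_csv(dataframe=df,
                     path="s3://my-bucket/prefix/",
                     serde=None,                  # skip Athena/Glue SerDe defaults
                     sep=",",
                     na_rep="NULL",               # how missing values are written
                     quoting=csv.QUOTE_MINIMAL,
                     escapechar="\"")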
@@ -665,8 +667,10 @@ def to_csv(self,
         :param dataframe: Pandas Dataframe
         :param path: AWS S3 path (E.g. s3://bucket-name/folder_name/
         :param sep: Same as pandas.to_csv()
+        :param na_rep: Same as pandas.to_csv()
+        :param quoting: Same as pandas.to_csv()
         :param escapechar: Same as pandas.to_csv()
-        :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe)
+        :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe) (For Athena/Glue Catalog only)
         :param database: AWS Glue Database name
         :param table: AWS Glue table name
         :param partition_cols: List of columns names that will be partitions on S3
@@ -680,9 +684,17 @@ def to_csv(self,
         :param columns_comments: Columns names and the related comments (Optional[Dict[str, str]])
         :return: List of objects written on S3
         """
-        if serde not in Pandas.VALID_CSV_SERDES:
+        if (serde is not None) and (serde not in Pandas.VALID_CSV_SERDES):
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
-        extra_args: Dict[str, Optional[str]] = {"sep": sep, "serde": serde, "escapechar": escapechar}
+        if (database is not None) and (serde is None):
+            raise InvalidParameters("It is not possible to write to a Glue Database without a SerDe.")
+        extra_args: Dict[str, Optional[str]] = {
+            "sep": sep,
+            "na_rep": na_rep,
+            "serde": serde,
+            "escapechar": escapechar,
+            "quoting": quoting
+        }
         return self.to_s3(dataframe=dataframe,
                           path=path,
                           file_format="csv",
@@ -1053,17 +1065,24 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_

         serde = extra_args.get("serde")
         if serde is None:
-            escapechar = extra_args.get("escapechar")
+            escapechar: Optional[str] = extra_args.get("escapechar")
             if escapechar is not None:
                 csv_extra_args["escapechar"] = escapechar
+            quoting: Optional[int] = extra_args.get("quoting")
+            if quoting is not None:
+                csv_extra_args["quoting"] = quoting
+            na_rep: Optional[str] = extra_args.get("na_rep")
+            if na_rep is not None:
+                csv_extra_args["na_rep"] = na_rep
         else:
             if serde == "OpenCSVSerDe":
                 csv_extra_args["quoting"] = csv.QUOTE_ALL
                 csv_extra_args["escapechar"] = "\\"
             elif serde == "LazySimpleSerDe":
                 csv_extra_args["quoting"] = csv.QUOTE_NONE
                 csv_extra_args["escapechar"] = "\\"
-        csv_buffer = bytes(
+        logger.debug(f"csv_extra_args: {csv_extra_args}")
+        csv_buffer: bytes = bytes(
             dataframe.to_csv(None, header=False, index=preserve_index, compression=compression, **csv_extra_args),
             "utf-8")
         Pandas._write_csv_to_s3_retrying(fs=fs, path=path, buffer=csv_buffer)
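When no SerDe is involved, these extra_args end up as plain keyword arguments to pandas.DataFrame.to_csv. A self-contained sketch of that rendering step, with values chosen to match what to_aurora passes and a made-up DataFrame:

    import csv
    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"a": [1.0, np.nan], "b": ["x,y", None]})

    # Same kwargs the wrapper collects into csv_extra_args when serde is None
    payload = df.to_csv(None,
                        header=False,
                        index=False,
                        sep=",",
                        na_rep="NULL",              # missing values rendered as NULL
                        quoting=csv.QUOTE_MINIMAL,  # quote only where needed (e.g. "x,y")
                        escapechar="\"")
    print(payload)
    # 1.0,"x,y"
    # NULL,NULL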
@@ -1554,9 +1573,13 @@ def to_aurora(self,
         temp_s3_path = self._session.athena.create_athena_bucket() + temp_directory + "/"
         temp_s3_path = temp_s3_path if temp_s3_path[-1] == "/" else temp_s3_path + "/"
         logger.debug(f"temp_s3_path: {temp_s3_path}")
+        na_rep: str = "NULL" if "mysql" in engine.lower() else ""
         paths: List[str] = self.to_csv(dataframe=dataframe,
                                        path=temp_s3_path,
+                                       serde=None,
                                        sep=",",
+                                       na_rep=na_rep,
+                                       quoting=csv.QUOTE_MINIMAL,
                                        escapechar="\"",
                                        preserve_index=preserve_index,
                                        mode="overwrite",

awswrangler/s3.py

Lines changed: 2 additions & 2 deletions
@@ -308,8 +308,8 @@ def get_objects_sizes(self, objects_paths: List[str], procs_io_bound: Optional[i
             receive_pipes[i].close()
         return objects_sizes

-    def copy_listed_objects(self, objects_paths, source_path, target_path, mode="append", procs_io_bound=None):
-        if not procs_io_bound:
+    def copy_listed_objects(self, objects_paths: List[str], source_path: str, target_path: str, mode: str = "append", procs_io_bound: Optional[int] = None):
+        if procs_io_bound is None:
             procs_io_bound = self._session.procs_io_bound
         logger.debug(f"procs_io_bound: {procs_io_bound}")
         logger.debug(f"len(objects_paths): {len(objects_paths)}")

testing/test_awswrangler/test_pandas.py

Lines changed: 101 additions & 0 deletions
@@ -1417,7 +1417,9 @@ def test_read_parquet_dataset(session, bucket):
                               preserve_index=False,
                               procs_cpu_bound=4,
                               partition_cols=["partition"])
+    sleep(15)
     df2 = session.pandas.read_parquet(path=path)
+    wr.s3.delete_objects(path=path)
     assert len(list(df.columns)) == len(list(df2.columns))
     assert len(df.index) == len(df2.index)


@@ -2068,3 +2070,102 @@ def test_read_sql_athena_empty(ctas_approach):
     """
     df = wr.pandas.read_sql_athena(sql=sql, ctas_approach=ctas_approach)
     print(df)
+
+
+def test_aurora_postgres_load_special2(bucket, postgres_parameters):
+    dt = lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f")
+    df = pd.DataFrame({
+        "integer1": [0, 1, np.NaN, 3],
+        "integer2": [8986, 9735, 9918, 9150],
+        "string1": ["O", "P", "P", "O"],
+        "string2": ["050100", "010101", "010101", "050100"],
+        "string3": ["A", "R", "A", "R"],
+        "string4": ["SGD", "SGD", "SGD", "SGD"],
+        "float1": [0.0, 1800000.0, np.NaN, 0.0],
+        "string5": ["0000296722", "0000199396", "0000298592", "0000196380"],
+        "string6": [None, "C", "C", None],
+        "timestamp1": [dt("2020-01-07 00:00:00.000"), None, dt("2020-01-07 00:00:00.000"),
+                       dt("2020-01-07 00:00:00.000")],
+        "string7": ["XXX", "XXX", "XXX", "XXX"],
+        "timestamp2": [dt("2020-01-10 10:34:55.863"), dt("2020-01-10 10:34:55.864"), dt("2020-01-10 10:34:55.865"),
+                       dt("2020-01-10 10:34:55.866")],
+    })
+    df = pd.concat([df for _ in range(10_000)])
+    path = f"s3://{bucket}/test_aurora_postgres_special"
+    wr.pandas.to_aurora(dataframe=df,
+                        connection="aws-data-wrangler-postgres",
+                        schema="public",
+                        table="test_aurora_postgres_load_special2",
+                        mode="overwrite",
+                        temp_s3_path=path,
+                        engine="postgres",
+                        procs_cpu_bound=1)
+    conn = Aurora.generate_connection(database="postgres",
+                                      host=postgres_parameters["PostgresAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=postgres_parameters["Password"],
+                                      engine="postgres")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM public.test_aurora_postgres_load_special2")
+        assert cursor.fetchall()[0][0] == len(df.index)
+        cursor.execute("SELECT timestamp2 FROM public.test_aurora_postgres_load_special2 limit 4")
+        rows = cursor.fetchall()
+        assert rows[0][0] == dt("2020-01-10 10:34:55.863")
+        assert rows[1][0] == dt("2020-01-10 10:34:55.864")
+        assert rows[2][0] == dt("2020-01-10 10:34:55.865")
+        assert rows[3][0] == dt("2020-01-10 10:34:55.866")
+        cursor.execute("SELECT integer1, float1, string6, timestamp1 FROM public.test_aurora_postgres_load_special2 limit 4")
+        rows = cursor.fetchall()
+        assert rows[2][0] is None
+        assert rows[2][1] is None
+        assert rows[0][2] is None
+        assert rows[1][3] is None
+    conn.close()
+
+
+def test_aurora_mysql_load_special2(bucket, mysql_parameters):
+    dt = lambda x: datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f")
+    df = pd.DataFrame({
+        "integer1": [0, 1, np.NaN, 3],
+        "integer2": [8986, 9735, 9918, 9150],
+        "string1": ["O", "P", "P", "O"],
+        "string2": ["050100", "010101", "010101", "050100"],
+        "string3": ["A", "R", "A", "R"],
+        "string4": ["SGD", "SGD", "SGD", "SGD"],
+        "float1": [0.0, 1800000.0, np.NaN, 0.0],
+        "string5": ["0000296722", "0000199396", "0000298592", "0000196380"],
+        "string6": [None, "C", "C", None],
+        "timestamp1": [dt("2020-01-07 00:00:00.000"), None, dt("2020-01-07 00:00:00.000"),
+                       dt("2020-01-07 00:00:00.000")],
+        "string7": ["XXX", "XXX", "XXX", "XXX"],
+        "timestamp2": [dt("2020-01-10 10:34:55.863"), dt("2020-01-10 10:34:55.864"), dt("2020-01-10 10:34:55.865"),
+                       dt("2020-01-10 10:34:55.866")],
+    })
+    df = pd.concat([df for _ in range(10_000)])
+    path = f"s3://{bucket}/test_aurora_mysql_load_special2"
+    wr.pandas.to_aurora(dataframe=df,
+                        connection="aws-data-wrangler-mysql",
+                        schema="test",
+                        table="test_aurora_mysql_load_special2",
+                        mode="overwrite",
+                        temp_s3_path=path,
+                        engine="mysql",
+                        procs_cpu_bound=1)
+    conn = Aurora.generate_connection(database="mysql",
+                                      host=mysql_parameters["MysqlAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=mysql_parameters["Password"],
+                                      engine="mysql")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT count(*) FROM test.test_aurora_mysql_load_special2")
+        assert cursor.fetchall()[0][0] == len(df.index)
+        cursor.execute(
+            "SELECT integer1, float1, string6, timestamp1 FROM test.test_aurora_mysql_load_special2 limit 4")
+        rows = cursor.fetchall()
+        assert rows[2][0] is None
+        assert rows[2][1] is None
+        assert rows[0][2] is None
+        assert rows[1][3] is None
+    conn.close()
