Skip to content

Commit 95f2b36

Browse files
authored
Merge pull request #119 from awslabs/unload-aurora-null
Fix a bug when unloading NULL values from Aurora
2 parents f9617f0 + 14a3562 commit 95f2b36

File tree

2 files changed

+129
-23
lines changed

2 files changed

+129
-23
lines changed

awswrangler/pandas.py

Lines changed: 97 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def read_csv(
7474
escapechar=None,
7575
parse_dates: Union[bool, Dict, List] = False,
7676
infer_datetime_format=False,
77+
na_values: Optional[Union[str, List[str]]] = None,
78+
keep_default_na: bool = True,
79+
na_filter: bool = True,
7780
encoding="utf-8",
7881
converters=None,
7982
):
@@ -98,6 +101,9 @@ def read_csv(
98101
:param escapechar: Same as pandas.read_csv()
99102
:param parse_dates: Same as pandas.read_csv()
100103
:param infer_datetime_format: Same as pandas.read_csv()
104+
:param na_values: Same as pandas.read_csv()
105+
:param keep_default_na: Same as pandas.read_csv()
106+
:param na_filter: Same as pandas.read_csv()
101107
:param encoding: Same as pandas.read_csv()
102108
:param converters: Same as pandas.read_csv()
103109
:return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
@@ -120,6 +126,9 @@ def read_csv(
120126
escapechar=escapechar,
121127
parse_dates=parse_dates,
122128
infer_datetime_format=infer_datetime_format,
129+
na_values=na_values,
130+
keep_default_na=keep_default_na,
131+
na_filter=na_filter,
123132
encoding=encoding,
124133
converters=converters)
125134
else:
@@ -139,6 +148,9 @@ def read_csv(
139148
escapechar=escapechar,
140149
parse_dates=parse_dates,
141150
infer_datetime_format=infer_datetime_format,
151+
na_values=na_values,
152+
keep_default_na=keep_default_na,
153+
na_filter=na_filter,
142154
encoding=encoding,
143155
converters=converters)
144156
return ret
@@ -161,6 +173,9 @@ def _read_csv_iterator(
161173
escapechar=None,
162174
parse_dates: Union[bool, Dict, List] = False,
163175
infer_datetime_format=False,
176+
na_values: Optional[Union[str, List[str]]] = None,
177+
keep_default_na: bool = True,
178+
na_filter: bool = True,
164179
encoding="utf-8",
165180
converters=None,
166181
):
@@ -185,6 +200,9 @@ def _read_csv_iterator(
185200
:param escapechar: Same as pandas.read_csv()
186201
:param parse_dates: Same as pandas.read_csv()
187202
:param infer_datetime_format: Same as pandas.read_csv()
203+
:param na_values: Same as pandas.read_csv()
204+
:param keep_default_na: Same as pandas.read_csv()
205+
:param na_filter: Same as pandas.read_csv()
188206
:param encoding: Same as pandas.read_csv()
189207
:param converters: Same as pandas.read_csv()
190208
:return: Pandas Dataframe
@@ -211,6 +229,9 @@ def _read_csv_iterator(
211229
escapechar=escapechar,
212230
parse_dates=parse_dates,
213231
infer_datetime_format=infer_datetime_format,
232+
na_values=na_values,
233+
keep_default_na=keep_default_na,
234+
na_filter=na_filter,
214235
encoding=encoding,
215236
converters=converters)
216237
else:
@@ -251,6 +272,9 @@ def _read_csv_iterator(
251272
header=header,
252273
names=names,
253274
usecols=usecols,
275+
na_values=na_values,
276+
keep_default_na=keep_default_na,
277+
na_filter=na_filter,
254278
sep=sep,
255279
thousands=thousands,
256280
decimal=decimal,
@@ -371,6 +395,9 @@ def _read_csv_once(
371395
escapechar=None,
372396
parse_dates: Union[bool, Dict, List] = False,
373397
infer_datetime_format=False,
398+
na_values: Optional[Union[str, List[str]]] = None,
399+
keep_default_na: bool = True,
400+
na_filter: bool = True,
374401
encoding=None,
375402
converters=None,
376403
):
@@ -395,6 +422,9 @@ def _read_csv_once(
395422
:param escapechar: Same as pandas.read_csv()
396423
:param parse_dates: Same as pandas.read_csv()
397424
:param infer_datetime_format: Same as pandas.read_csv()
425+
:param na_values: Same as pandas.read_csv()
426+
:param keep_default_na: Same as pandas.read_csv()
427+
:param na_filter: Same as pandas.read_csv()
398428
:param encoding: Same as pandas.read_csv()
399429
:param converters: Same as pandas.read_csv()
400430
:return: Pandas Dataframe
@@ -409,6 +439,9 @@ def _read_csv_once(
409439
header=header,
410440
names=names,
411441
usecols=usecols,
442+
na_values=na_values,
443+
keep_default_na=keep_default_na,
444+
na_filter=na_filter,
412445
sep=sep,
413446
thousands=thousands,
414447
decimal=decimal,
@@ -443,6 +476,9 @@ def _read_csv_once_remote(send_pipe: mp.connection.Connection,
443476
escapechar=None,
444477
parse_dates: Union[bool, Dict, List] = False,
445478
infer_datetime_format=False,
479+
na_values: Optional[Union[str, List[str]]] = None,
480+
keep_default_na: bool = True,
481+
na_filter: bool = True,
446482
encoding=None,
447483
converters=None):
448484
df: pd.DataFrame = Pandas._read_csv_once(session_primitives=session_primitives,
@@ -461,6 +497,9 @@ def _read_csv_once_remote(send_pipe: mp.connection.Connection,
461497
escapechar=escapechar,
462498
parse_dates=parse_dates,
463499
infer_datetime_format=infer_datetime_format,
500+
na_values=na_values,
501+
keep_default_na=keep_default_na,
502+
na_filter=na_filter,
464503
encoding=encoding,
465504
converters=converters)
466505
send_pipe.send(df)
@@ -869,7 +908,7 @@ def to_s3(self,
869908
logger.debug(f"cast_columns: {cast_columns}")
870909
partition_cols = [Athena.normalize_column_name(x) for x in partition_cols]
871910
logger.debug(f"partition_cols: {partition_cols}")
872-
if extra_args is not None and "columns" in extra_args:
911+
if extra_args is not None and "columns" in extra_args and extra_args["columns"] is not None:
873912
extra_args["columns"] = [Athena.normalize_column_name(x) for x in extra_args["columns"]] # type: ignore
874913
dataframe = Pandas.drop_duplicated_columns(dataframe=dataframe, inplace=inplace)
875914
if compression is not None:
@@ -1754,7 +1793,12 @@ def read_sql_aurora(self,
17541793
paths: List[str] = self._session.aurora.extract_manifest_paths(path=manifest_path)
17551794
logger.debug(f"paths: {paths}")
17561795
ret: Union[pd.DataFrame, Iterator[pd.DataFrame]]
1757-
ret = self.read_csv_list(paths=paths, max_result_size=max_result_size, header=None, names=col_names)
1796+
ret = self.read_csv_list(paths=paths,
1797+
max_result_size=max_result_size,
1798+
header=None,
1799+
names=col_names,
1800+
na_values=["\\N"],
1801+
keep_default_na=False)
17581802
self._session.s3.delete_listed_objects(objects_paths=paths + [manifest_path])
17591803
except Exception as ex:
17601804
connection.rollback()
@@ -1782,6 +1826,9 @@ def read_csv_list(
17821826
escapechar=None,
17831827
parse_dates: Union[bool, Dict, List] = False,
17841828
infer_datetime_format=False,
1829+
na_values: Optional[Union[str, List[str]]] = None,
1830+
keep_default_na: bool = True,
1831+
na_filter: bool = True,
17851832
encoding="utf-8",
17861833
converters=None,
17871834
procs_cpu_bound: Optional[int] = None,
@@ -1807,6 +1854,9 @@ def read_csv_list(
18071854
:param escapechar: Same as pandas.read_csv()
18081855
:param parse_dates: Same as pandas.read_csv()
18091856
:param infer_datetime_format: Same as pandas.read_csv()
1857+
:param na_values: Same as pandas.read_csv()
1858+
:param keep_default_na: Same as pandas.read_csv()
1859+
:param na_filter: Same as pandas.read_csv()
18101860
:param encoding: Same as pandas.read_csv()
18111861
:param converters: Same as pandas.read_csv()
18121862
:param procs_cpu_bound: Number of cores used for CPU bound tasks
@@ -1828,35 +1878,40 @@ def read_csv_list(
18281878
escapechar=escapechar,
18291879
parse_dates=parse_dates,
18301880
infer_datetime_format=infer_datetime_format,
1881+
na_values=na_values,
1882+
keep_default_na=keep_default_na,
1883+
na_filter=na_filter,
18311884
encoding=encoding,
18321885
converters=converters)
18331886
else:
18341887
procs_cpu_bound = procs_cpu_bound if procs_cpu_bound is not None else self._session.procs_cpu_bound if self._session.procs_cpu_bound is not None else 1
18351888
logger.debug(f"procs_cpu_bound: {procs_cpu_bound}")
1836-
df: Optional[pd.DataFrame] = None
18371889
session_primitives = self._session.primitives
18381890
if len(paths) == 1:
18391891
path = paths[0]
18401892
bucket_name, key_path = Pandas._parse_path(path)
18411893
logger.debug(f"path: {path}")
1842-
df = self._read_csv_once(session_primitives=self._session.primitives,
1843-
bucket_name=bucket_name,
1844-
key_path=key_path,
1845-
header=header,
1846-
names=names,
1847-
usecols=usecols,
1848-
dtype=dtype,
1849-
sep=sep,
1850-
thousands=thousands,
1851-
decimal=decimal,
1852-
lineterminator=lineterminator,
1853-
quotechar=quotechar,
1854-
quoting=quoting,
1855-
escapechar=escapechar,
1856-
parse_dates=parse_dates,
1857-
infer_datetime_format=infer_datetime_format,
1858-
encoding=encoding,
1859-
converters=converters)
1894+
df: pd.DataFrame = self._read_csv_once(session_primitives=self._session.primitives,
1895+
bucket_name=bucket_name,
1896+
key_path=key_path,
1897+
header=header,
1898+
names=names,
1899+
usecols=usecols,
1900+
dtype=dtype,
1901+
sep=sep,
1902+
thousands=thousands,
1903+
decimal=decimal,
1904+
lineterminator=lineterminator,
1905+
quotechar=quotechar,
1906+
quoting=quoting,
1907+
escapechar=escapechar,
1908+
parse_dates=parse_dates,
1909+
infer_datetime_format=infer_datetime_format,
1910+
na_values=na_values,
1911+
keep_default_na=keep_default_na,
1912+
na_filter=na_filter,
1913+
encoding=encoding,
1914+
converters=converters)
18601915
else:
18611916
procs = []
18621917
receive_pipes = []
@@ -1869,7 +1924,7 @@ def read_csv_list(
18691924
target=self._read_csv_once_remote,
18701925
args=(send_pipe, session_primitives, bucket_name, key_path, header, names, usecols, dtype, sep,
18711926
thousands, decimal, lineterminator, quotechar, quoting, escapechar, parse_dates,
1872-
infer_datetime_format, encoding, converters),
1927+
infer_datetime_format, na_values, keep_default_na, na_filter, encoding, converters),
18731928
)
18741929
proc.daemon = False
18751930
proc.start()
@@ -1906,6 +1961,9 @@ def _read_csv_list_iterator(
19061961
escapechar=None,
19071962
parse_dates: Union[bool, Dict, List] = False,
19081963
infer_datetime_format=False,
1964+
na_values: Optional[Union[str, List[str]]] = None,
1965+
keep_default_na: bool = True,
1966+
na_filter: bool = True,
19091967
encoding="utf-8",
19101968
converters=None,
19111969
):
@@ -1930,6 +1988,9 @@ def _read_csv_list_iterator(
19301988
:param escapechar: Same as pandas.read_csv()
19311989
:param parse_dates: Same as pandas.read_csv()
19321990
:param infer_datetime_format: Same as pandas.read_csv()
1991+
:param na_values: Same as pandas.read_csv()
1992+
:param keep_default_na: Same as pandas.read_csv()
1993+
:param na_filter: Same as pandas.read_csv()
19331994
:param encoding: Same as pandas.read_csv()
19341995
:param converters: Same as pandas.read_csv()
19351996
:return: Iterator of iterators of Pandas Dataframes
@@ -1953,6 +2014,9 @@ def _read_csv_list_iterator(
19532014
escapechar=escapechar,
19542015
parse_dates=parse_dates,
19552016
infer_datetime_format=infer_datetime_format,
2017+
na_values=na_values,
2018+
keep_default_na=keep_default_na,
2019+
na_filter=na_filter,
19562020
encoding=encoding,
19572021
converters=converters)
19582022

@@ -1973,6 +2037,9 @@ def read_csv_prefix(
19732037
escapechar=None,
19742038
parse_dates: Union[bool, Dict, List] = False,
19752039
infer_datetime_format=False,
2040+
na_values: Optional[Union[str, List[str]]] = None,
2041+
keep_default_na: bool = True,
2042+
na_filter: bool = True,
19762043
encoding="utf-8",
19772044
converters=None,
19782045
procs_cpu_bound: Optional[int] = None,
@@ -1998,6 +2065,9 @@ def read_csv_prefix(
19982065
:param escapechar: Same as pandas.read_csv()
19992066
:param parse_dates: Same as pandas.read_csv()
20002067
:param infer_datetime_format: Same as pandas.read_csv()
2068+
:param na_values: Same as pandas.read_csv()
2069+
:param keep_default_na: Same as pandas.read_csv()
2070+
:param na_filter: Same as pandas.read_csv()
20012071
:param encoding: Same as pandas.read_csv()
20022072
:param converters: Same as pandas.read_csv()
20032073
:param procs_cpu_bound: Number of cores used for CPU bound tasks
@@ -2020,5 +2090,9 @@ def read_csv_prefix(
20202090
escapechar=escapechar,
20212091
parse_dates=parse_dates,
20222092
infer_datetime_format=infer_datetime_format,
2093+
na_values=na_values,
2094+
keep_default_na=keep_default_na,
2095+
na_filter=na_filter,
20232096
encoding=encoding,
2024-
converters=converters)
2097+
converters=converters,
2098+
procs_cpu_bound=procs_cpu_bound)

testing/test_awswrangler/test_pandas.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2309,3 +2309,35 @@ def test_aurora_mysql_load_columns(bucket, mysql_parameters):
23092309
assert rows[4][1] == "boo"
23102310
assert rows[5][1] == "bar"
23112311
conn.close()
2312+
2313+
2314+
def test_aurora_mysql_unload_null(bucket, mysql_parameters):
    """Round-trip a DataFrame with NULLs through Aurora MySQL and back.

    Writes a frame containing None in string, float, and nullable-int
    columns, reads it back with read_sql_aurora (which unloads via S3 and
    treats MySQL's ``\\N`` marker as NA), and asserts the result is equal
    to the original — i.e. NULLs survive the unload and empty strings are
    not collapsed into NA.

    :param bucket: S3 bucket fixture used for the temp unload paths
    :param mysql_parameters: fixture dict with MysqlAddress and Password
    """
    df = pd.DataFrame({
        "id": [1, 2, 3, 4, 5],
        # Mix of empty string and None: "" must stay "" and None must stay NA.
        "c_str": ["foo", "", None, "bar", None],
        "c_float": [1.1, None, 3.3, None, 5.5],
        "c_int": [1, 2, None, 3, 4],
    })
    # Nullable integer dtype so None is representable alongside ints.
    df["c_int"] = df["c_int"].astype("Int64")
    conn = Aurora.generate_connection(database="mysql",
                                      host=mysql_parameters["MysqlAddress"],
                                      port=3306,
                                      user="test",
                                      password=mysql_parameters["Password"],
                                      engine="mysql")
    try:
        path = f"s3://{bucket}/test_aurora_mysql_unload_complex"
        wr.pandas.to_aurora(dataframe=df,
                            connection=conn,
                            schema="test",
                            table="test_aurora_mysql_unload_complex",
                            mode="overwrite",
                            temp_s3_path=path)
        path2 = f"s3://{bucket}/test_aurora_mysql_unload_complex2"
        df2 = wr.pandas.read_sql_aurora(sql="SELECT * FROM test.test_aurora_mysql_unload_complex",
                                        connection=conn,
                                        col_names=["id", "c_str", "c_float", "c_int"],
                                        temp_s3_path=path2)
        # Unload yields object/float columns; restore Int64 before comparing.
        df2["c_int"] = df2["c_int"].astype("Int64")
        assert df.equals(df2)
    finally:
        # Close even if the round-trip or the assertion fails.
        conn.close()

0 commit comments

Comments
 (0)