
Commit 7d2106f

Fix DataFrame sanitize for single files in to_parquet(). #240
1 parent c6e5a19 commit 7d2106f

3 files changed: +25 -9 lines changed

awswrangler/catalog.py

Lines changed: 1 addition & 1 deletion
@@ -640,7 +640,7 @@ def table(
 
 def _sanitize_name(name: str) -> str:
     name = "".join(c for c in unicodedata.normalize("NFD", name) if unicodedata.category(c) != "Mn")  # strip accents
-    name = re.sub("[^A-Za-z0-9_]+", "_", name)  # Removing non alphanumeric characters
+    name = re.sub("[^A-Za-z0-9_]+", "_", name)  # Replacing non alphanumeric characters by underscore
     return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()  # Converting CamelCase to snake_case
 
 
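For context, here is a minimal standalone sketch of the three steps applied by _sanitize_name; the helper name sanitize_name_sketch and the asserts are illustrative, not library code:

import re
import unicodedata

def sanitize_name_sketch(name: str) -> str:
    # Strip accents: decompose to NFD and drop combining marks.
    name = "".join(c for c in unicodedata.normalize("NFD", name) if unicodedata.category(c) != "Mn")
    # Replace each run of non-alphanumeric characters by a single underscore ("c**--2" -> "c_2").
    name = re.sub("[^A-Za-z0-9_]+", "_", name)
    # Convert CamelCase to snake_case and lowercase everything ("camelCase" -> "camel_case").
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name).lower()

assert sanitize_name_sketch("C0") == "c0"
assert sanitize_name_sketch("camelCase") == "camel_case"
assert sanitize_name_sketch("c**--2") == "c_2"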

awswrangler/s3.py

Lines changed: 10 additions & 7 deletions
@@ -1043,11 +1043,19 @@ def to_parquet( # pylint: disable=too-many-arguments
         )
     if df.empty is True:
         raise exceptions.EmptyDataFrame()
-    session: boto3.Session = _utils.ensure_session(session=boto3_session)
+
+    # Sanitize table to respect Athena's standards
     partition_cols = partition_cols if partition_cols else []
     dtype = dtype if dtype else {}
     columns_comments = columns_comments if columns_comments else {}
     partitions_values: Dict[str, List[str]] = {}
+    df = catalog.sanitize_dataframe_columns_names(df=df)
+    partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
+    dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
+    columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()}
+    df = catalog.drop_duplicated_columns(df=df)
+
+    session: boto3.Session = _utils.ensure_session(session=boto3_session)
     cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
     fs: s3fs.S3FileSystem = _utils.get_fs(session=session, s3_additional_kwargs=s3_additional_kwargs)
     compression_ext: Optional[str] = _COMPRESSION_2_EXT.get(compression, None)
@@ -1075,16 +1083,11 @@ def to_parquet( # pylint: disable=too-many-arguments
         ]
     else:
         mode = "append" if mode is None else mode
-        if (database is not None) and (table is not None):  # Normalize table to respect Athena's standards
-            df = catalog.sanitize_dataframe_columns_names(df=df)
-            partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
-            dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
-            columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()}
+        if (database is not None) and (table is not None):
             exist: bool = catalog.does_table_exist(database=database, table=table, boto3_session=session)
             if (exist is True) and (mode in ("append", "overwrite_partitions")):
                 for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items():
                     dtype[k] = v
-        df = catalog.drop_duplicated_columns(df=df)
         paths, partitions_values = _to_parquet_dataset(
             df=df,
             path=path,
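With the sanitization moved ahead of both branches, a plain to_parquet(df, path) to a single file now gets Athena-compliant column names even when no database/table is passed. A rough sketch of what that pre-write normalization amounts to, reproduced outside to_parquet with made-up argument values (the expected outputs follow from the new test below):

import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"C0": [0, 1], "camelCase": [2, 3], "c**--2": [4, 5]})
partition_cols = ["camelCase"]
dtype = {"c**--2": "BIGINT"}
columns_comments = {"C0": "first column"}

# Same catalog helpers to_parquet() now calls unconditionally before writing.
df = wr.catalog.sanitize_dataframe_columns_names(df=df)
partition_cols = [wr.catalog.sanitize_column_name(p) for p in partition_cols]
dtype = {wr.catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
columns_comments = {wr.catalog.sanitize_column_name(k): v for k, v in columns_comments.items()}
df = wr.catalog.drop_duplicated_columns(df=df)

print(list(df.columns))   # ['c0', 'camel_case', 'c_2']
print(partition_cols)     # ['camel_case']
print(dtype)              # {'c_2': 'bigint'}
print(columns_comments)   # {'c0': 'first column'}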

testing/test_awswrangler/test_data_lake.py

Lines changed: 14 additions & 1 deletion
@@ -1681,4 +1681,17 @@ def test_athena_undefined_column(database):
     with pytest.raises(wr.exceptions.InvalidArgumentValue):
         wr.athena.read_sql_query("SELECT 1", database)
     with pytest.raises(wr.exceptions.InvalidArgumentValue):
-        wr.athena.read_sql_query("SELECT NULL", database)
+        wr.athena.read_sql_query("SELECT NULL AS my_null", database)
+
+
+def test_to_parquet_file_sanitize(path):
+    df = pd.DataFrame({"C0": [0, 1], "camelCase": [2, 3], "c**--2": [4, 5]})
+    path_file = f"{path}0.parquet"
+    wr.s3.to_parquet(df, path_file)
+    wr.s3.wait_objects_exist([path_file])
+    df2 = wr.s3.read_parquet(path_file)
+    assert df.shape == df2.shape
+    assert list(df2.columns) == ["c0", "camel_case", "c_2"]
+    assert df2.c0.sum() == 1
+    assert df2.camel_case.sum() == 5
+    assert df2.c_2.sum() == 9
