
Commit 5f9cc93

Merge pull request #39 from awslabs/normalize-athena-names
Normalize Athena names (tables and columns)
2 parents e434664 + be53a6b

File tree: 4 files changed, +104 -6 lines

    awswrangler/athena.py
    awswrangler/glue.py
    awswrangler/pandas.py
    testing/test_awswrangler/test_pandas.py

awswrangler/athena.py

Lines changed: 32 additions & 0 deletions
@@ -1,6 +1,8 @@
 from time import sleep
 import logging
 import ast
+import re
+import unicodedata

 from awswrangler import data_types
 from awswrangler.exceptions import QueryFailed, QueryCancelled
@@ -128,3 +130,33 @@ def repair_table(self, database, table, s3_output=None):
                               s3_output=s3_output)
         self.wait_query(query_execution_id=query_id)
         return query_id
+
+    @staticmethod
+    def _normalize_name(name):
+        name = "".join(c for c in unicodedata.normalize("NFD", name)
+                       if unicodedata.category(c) != "Mn")
+        name = name.replace(" ", "_")
+        name = name.replace("-", "_")
+        name = name.replace(".", "_")
+        name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+        name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name)
+        name = name.lower()
+        return re.sub(r"(_)\1+", "\\1", name)  # remove repeated underscores
+
+    @staticmethod
+    def normalize_column_name(name):
+        """
+        https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html
+        :param name: column name (str)
+        :return: normalized column name (str)
+        """
+        return Athena._normalize_name(name=name)
+
+    @staticmethod
+    def normalize_table_name(name):
+        """
+        https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html
+        :param name: table name (str)
+        :return: normalized table name (str)
+        """
+        return Athena._normalize_name(name=name)
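
For a quick sense of what the new helpers do, here is a small usage sketch (assuming a build of awswrangler that includes this commit; the expected values mirror the test cases at the bottom of this diff):

    from awswrangler.athena import Athena

    # Accents are stripped (NFD decomposition, then dropping combining marks),
    # CamelCase is split, spaces/dashes/dots become underscores, the result is
    # lowercased, and repeated underscores are collapsed.
    print(Athena.normalize_column_name("CammelCase"))         # cammel_case
    print(Athena.normalize_column_name("With Spaces"))        # with_spaces
    print(Athena.normalize_column_name("With-Dash"))          # with_dash
    print(Athena.normalize_column_name("Ãccént"))             # accent
    print(Athena.normalize_table_name("TestTable-with.dot"))  # test_table_with_dot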

awswrangler/glue.py

Lines changed: 15 additions & 5 deletions
@@ -3,6 +3,7 @@
 import logging

 from awswrangler import data_types
+from awswrangler.athena import Athena
 from awswrangler.exceptions import UnsupportedFileFormat, InvalidSerDe, ApiError

 logger = logging.getLogger(__name__)
@@ -64,7 +65,7 @@ def metadata_to_glue(self,
             indexes_position=indexes_position,
             cast_columns=cast_columns)
         table = table if table else Glue.parse_table_name(path)
-        table = table.lower().replace(".", "_")
+        table = Athena.normalize_table_name(name=table)
         if mode == "overwrite":
             self.delete_table_if_exists(database=database, table=table)
         exists = self.does_table_exists(database=database, table=table)
@@ -124,8 +125,13 @@ def create_table(self,
         self._client_glue.create_table(DatabaseName=database,
                                        TableInput=table_input)

-    def add_partitions(self, database, table, partition_paths, file_format,
-                       compression, extra_args=None):
+    def add_partitions(self,
+                       database,
+                       table,
+                       partition_paths,
+                       file_format,
+                       compression,
+                       extra_args=None):
         if not partition_paths:
             return None
         partitions = list()
@@ -207,8 +213,12 @@ def parse_table_name(path):
         return path.rpartition("/")[2]

     @staticmethod
-    def csv_table_definition(table, partition_cols_schema, schema, path,
-                             compression, extra_args=None):
+    def csv_table_definition(table,
+                             partition_cols_schema,
+                             schema,
+                             path,
+                             compression,
+                             extra_args=None):
         if extra_args is None:
             extra_args = {}
         if partition_cols_schema is None:
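
To illustrate the behavioral change in metadata_to_glue above, a minimal before/after sketch (the "old" expression is the inline logic removed in this hunk; the comments show what each expression produces):

    from awswrangler.athena import Athena

    table = "TestTable-with.dot"
    old = table.lower().replace(".", "_")          # "testtable-with_dot" (dash kept, no CamelCase split)
    new = Athena.normalize_table_name(name=table)  # "test_table_with_dot" (Athena-safe name)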

awswrangler/pandas.py

Lines changed: 11 additions & 1 deletion
@@ -17,7 +17,7 @@
                                      EmptyDataframe, InvalidSerDe,
                                      InvalidCompression)
 from awswrangler.utils import calculate_bounders
-from awswrangler import s3
+from awswrangler import s3, athena

 logger = logging.getLogger(__name__)

@@ -599,6 +599,7 @@ def to_s3(self,
         :param extra_args: Extra arguments specific for each file formats (E.g. "sep" for CSV)
         :return: List of objects written on S3
         """
+        Pandas.normalize_columns_names_athena(dataframe, inplace=True)
         if compression is not None:
             compression = compression.lower()
         file_format = file_format.lower()
@@ -1024,3 +1025,12 @@ def read_log_query(self,
                 new_row[col_name] = col["value"]
             pre_df.append(new_row)
         return pandas.DataFrame(pre_df)
+
+    @staticmethod
+    def normalize_columns_names_athena(dataframe, inplace=True):
+        if inplace is False:
+            dataframe = dataframe.copy(deep=True)
+        dataframe.columns = [
+            athena.Athena.normalize_column_name(x) for x in dataframe.columns
+        ]
+        return dataframe
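
A short sketch of the new helper on its own (assuming the Pandas class is importable from awswrangler.pandas, as the tests below use it). Note that to_s3 calls the helper with inplace=True, so the caller's DataFrame columns are renamed as a side effect of writing:

    import pandas
    from awswrangler.pandas import Pandas

    df = pandas.DataFrame({"CammelCase": [1], "With Spaces": [2], "With-Dash": [3]})

    # inplace=False works on a deep copy and leaves the original column names untouched
    df2 = Pandas.normalize_columns_names_athena(df, inplace=False)
    print(list(df2.columns))  # ['cammel_case', 'with_spaces', 'with_dash']
    print(list(df.columns))   # ['CammelCase', 'With Spaces', 'With-Dash']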

testing/test_awswrangler/test_pandas.py

Lines changed: 46 additions & 0 deletions
@@ -778,3 +778,49 @@ def test_read_sql_athena_with_time_zone(session, bucket, database):
     assert len(dataframe.columns) == 2
     assert dataframe["type"][0] == "timestamp with time zone"
     assert dataframe["value"][0].year == datetime.utcnow().year
+
+
+def test_normalize_columns_names_athena():
+    dataframe = pandas.DataFrame({
+        "CammelCase": [1, 2, 3],
+        "With Spaces": [4, 5, 6],
+        "With-Dash": [7, 8, 9],
+        "Ãccént": [10, 11, 12],
+    })
+    Pandas.normalize_columns_names_athena(dataframe=dataframe, inplace=True)
+    assert dataframe.columns[0] == "cammel_case"
+    assert dataframe.columns[1] == "with_spaces"
+    assert dataframe.columns[2] == "with_dash"
+    assert dataframe.columns[3] == "accent"
+
+
+def test_to_parquet_with_normalize(
+        session,
+        bucket,
+        database,
+):
+    dataframe = pandas.DataFrame({
+        "CammelCase": [1, 2, 3],
+        "With Spaces": [4, 5, 6],
+        "With-Dash": [7, 8, 9],
+        "Ãccént": [10, 11, 12],
+        "with.dot": [10, 11, 12],
+    })
+    session.pandas.to_parquet(dataframe=dataframe,
+                              database=database,
+                              path=f"s3://{bucket}/TestTable-with.dot/",
+                              mode="overwrite")
+    dataframe2 = None
+    for counter in range(10):
+        dataframe2 = session.pandas.read_sql_athena(
+            sql="select * from test_table_with_dot", database=database)
+        if len(dataframe.index) == len(dataframe2.index):
+            break
+        sleep(2)
+    assert len(dataframe.index) == len(dataframe2.index)
+    assert (len(list(dataframe.columns)) + 1) == len(list(dataframe2.columns))
+    assert dataframe2.columns[0] == "cammel_case"
+    assert dataframe2.columns[1] == "with_spaces"
+    assert dataframe2.columns[2] == "with_dash"
+    assert dataframe2.columns[3] == "accent"
+    assert dataframe2.columns[4] == "with_dot"
