Commit 9280adf

Merge pull request #1 from igorborgest/feature/partition_columns
Updating branch, resolving conflicts, adding more tests
2 parents: 52bc8db + 6df511d

File tree

5 files changed: +159 −42 lines changed

awswrangler/athena.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -31,25 +31,30 @@ def _type_athena2pandas(dtype):
             return "bool"
         elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
             return "object"
-        elif dtype in ["timestamp", "date"]:
+        elif dtype == "timestamp":
             return "datetime64"
+        elif dtype == "date":
+            return "date"
         else:
             raise UnsupportedType(f"Unsupported Athena type: {dtype}")
 
     def get_query_dtype(self, query_execution_id):
         cols_metadata = self.get_query_columns_metadata(
             query_execution_id=query_execution_id)
         dtype = {}
+        parse_timestamps = []
         parse_dates = []
         for col_name, col_type in cols_metadata.items():
             ptype = Athena._type_athena2pandas(dtype=col_type)
-            if ptype == "datetime64":
-                parse_dates.append(col_name)
+            if ptype in ["datetime64", "date"]:
+                parse_timestamps.append(col_name)
+                if ptype == "date":
+                    parse_dates.append(col_name)
             else:
                 dtype[col_name] = ptype
         logger.debug(f"dtype: {dtype}")
         logger.debug(f"parse_dates: {parse_dates}")
-        return dtype, parse_dates
+        return dtype, parse_timestamps, parse_dates
 
     def create_athena_bucket(self):
         """
```

awswrangler/exceptions.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -64,7 +64,3 @@ class QueryCancelled(Exception):
 
 class QueryFailed(Exception):
     pass
-
-
-class PartitionColumnTypeNotFound(Exception):
-    pass
```

awswrangler/glue.py

Lines changed: 58 additions & 32 deletions
```diff
@@ -3,7 +3,9 @@
 import logging
 from datetime import datetime, date
 
-from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat, PartitionColumnTypeNotFound
+import pyarrow
+
+from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat
 
 logger = logging.getLogger(__name__)
 
@@ -43,6 +45,28 @@ def get_table_python_types(self, database, table):
         dtypes = self.get_table_athena_types(database=database, table=table)
         return {k: Glue.type_athena2python(v) for k, v in dtypes.items()}
 
+    @staticmethod
+    def type_pyarrow2athena(dtype):
+        dtype = str(dtype).lower()
+        if dtype == "int32":
+            return "int"
+        elif dtype == "int64":
+            return "bigint"
+        elif dtype == "float":
+            return "float"
+        elif dtype == "double":
+            return "double"
+        elif dtype == "bool":
+            return "boolean"
+        elif dtype == "string":
+            return "string"
+        elif dtype.startswith("timestamp"):
+            return "timestamp"
+        elif dtype.startswith("date"):
+            return "date"
+        else:
+            raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
+
     @staticmethod
     def type_pandas2athena(dtype):
         dtype = dtype.lower()
```
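
`pyarrow.Schema.from_pandas` is what now drives the type inference that `type_pyarrow2athena` consumes. A sketch of that flow with an illustrative DataFrame (column names are arbitrary; the trailing comments show what `type_pyarrow2athena` would return for each inferred Arrow type):

```python
import pandas
import pyarrow

# Illustrative DataFrame; the column names are arbitrary.
df = pandas.DataFrame({
    "my_int": [1, 2],
    "my_float": [1.0, 2.0],
    "my_bool": [True, False],
    "my_string": ["a", "b"],
    "my_timestamp": pandas.to_datetime(["2019-01-01", "2019-01-02"]),
})

# pyarrow infers one Arrow type per column from the pandas dtypes.
for field in pyarrow.Schema.from_pandas(df, preserve_index=False):
    print(field.name, str(field.type))
# my_int int64               -> type_pyarrow2athena: "bigint"
# my_float double            -> "double"
# my_bool bool               -> "boolean"
# my_string string           -> "string"
# my_timestamp timestamp[ns] -> "timestamp" (matched via startswith)
```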
```diff
@@ -58,7 +82,7 @@ def type_pandas2athena(dtype):
             return "boolean"
         elif dtype == "object":
             return "string"
-        elif dtype[:10] == "datetime64":
+        elif dtype.startswith("datetime64"):
             return "timestamp"
         else:
             raise UnsupportedType(f"Unsupported Pandas type: {dtype}")
@@ -114,8 +138,7 @@ def metadata_to_glue(self,
         schema, partition_cols_schema = Glue._build_schema(
             dataframe=dataframe,
             partition_cols=partition_cols,
-            preserve_index=preserve_index,
-            cast_columns=cast_columns)
+            preserve_index=preserve_index)
         table = table if table else Glue._parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
@@ -132,7 +155,6 @@ def metadata_to_glue(self,
         if partition_cols:
             partitions_tuples = Glue._parse_partitions_tuples(
                 objects_paths=objects_paths, partition_cols=partition_cols)
-            print(partitions_tuples)
             self.add_partitions(
                 database=database,
                 table=table,
@@ -190,9 +212,6 @@ def add_partitions(self, database, table, partition_paths, file_format):
         for _ in range(pages_num):
             page = partitions[:100]
             del partitions[:100]
-            print(database)
-            print(table)
-            print(page)
             self._client_glue.batch_create_partition(DatabaseName=database,
                                                      TableName=table,
                                                      PartitionInputList=page)
```
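
The deleted `print` calls were debug leftovers; the paging loop around them stays because Glue's `batch_create_partition` accepts at most 100 entries in `PartitionInputList` per call. A standalone sketch of that chunking pattern (`client` and `partitions` are placeholders for the boto3 Glue client and the prepared partition inputs):

```python
import math

def add_partitions_paged(client, database, table, partitions):
    # batch_create_partition caps PartitionInputList at 100 entries,
    # so submit the inputs in pages of at most 100.
    pages_num = int(math.ceil(len(partitions) / 100.0))
    for _ in range(pages_num):
        page = partitions[:100]
        del partitions[:100]  # consume the list page by page
        client.batch_create_partition(DatabaseName=database,
                                      TableName=table,
                                      PartitionInputList=page)
```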
```diff
@@ -202,36 +221,43 @@ def get_connection_details(self, name):
             Name=name, HidePassword=False)["Connection"]
 
     @staticmethod
-    def _build_schema(dataframe,
-                      partition_cols,
-                      preserve_index,
-                      cast_columns=None):
+    def _extract_pyarrow_schema(dataframe, preserve_index):
+        cols = []
+        schema = []
+        for name, dtype in dataframe.dtypes.to_dict().items():
+            dtype = str(dtype)
+            if str(dtype) == "Int64":
+                schema.append((name, "int64"))
+            else:
+                cols.append(name)
+
+        # Convert pyarrow.Schema to list of tuples (e.g. [(name1, type1), (name2, type2)...])
+        schema += [(str(x.name), str(x.type))
+                   for x in pyarrow.Schema.from_pandas(
+                       df=dataframe[cols], preserve_index=preserve_index)]
+        logger.debug(f"schema: {schema}")
+        return schema
+
+    @staticmethod
+    def _build_schema(dataframe, partition_cols, preserve_index):
         logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
         if not partition_cols:
             partition_cols = []
+
+        pyarrow_schema = Glue._extract_pyarrow_schema(
+            dataframe=dataframe, preserve_index=preserve_index)
+
         schema_built = []
-        partition_cols_schema_built = []
-        if preserve_index:
-            name = str(
-                dataframe.index.name) if dataframe.index.name else "index"
-            dataframe.index.name = "index"
-            dtype = str(dataframe.index.dtype)
-            athena_type = Glue.type_pandas2athena(dtype)
-            if name not in partition_cols:
-                schema_built.append((name, athena_type))
-            else:
-                partition_cols_schema_built.append((name, athena_type))
-        for col in dataframe.columns:
-            name = str(col)
-            if cast_columns and name in cast_columns:
-                dtype = cast_columns[name]
+        partition_cols_types = {}
+        for name, dtype in pyarrow_schema:
+            athena_type = Glue.type_pyarrow2athena(dtype)
+            if name in partition_cols:
+                partition_cols_types[name] = athena_type
             else:
-                dtype = str(dataframe[name].dtype)
-            athena_type = Glue.type_pandas2athena(dtype)
-            if name not in partition_cols:
                 schema_built.append((name, athena_type))
-            else:
-                partition_cols_schema_built.append((name, athena_type))
+
+        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+
         logger.debug(f"schema_built:\n{schema_built}")
         logger.debug(
             f"partition_cols_schema_built:\n{partition_cols_schema_built}")
```

awswrangler/pandas.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -419,14 +419,16 @@ def read_sql_athena(self,
             message_error = f"Query error: {reason}"
             raise AthenaQueryError(message_error)
         else:
-            dtype, parse_dates = self._session.athena.get_query_dtype(
+            dtype, parse_timestamps, parse_dates = self._session.athena.get_query_dtype(
                 query_execution_id=query_execution_id)
             path = f"{s3_output}{query_execution_id}.csv"
             ret = self.read_csv(path=path,
                                 dtype=dtype,
-                                parse_dates=parse_dates,
+                                parse_dates=parse_timestamps,
                                 quoting=csv.QUOTE_ALL,
                                 max_result_size=max_result_size)
+            for col in parse_dates:
+                ret[col] = ret[col].dt.date
             return ret
 
     def to_csv(
```
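
Pandas has no native dtype for bare dates, which is why the `date` columns ride along in `parse_dates` as timestamps and are downcast afterwards with `.dt.date`. A minimal sketch against an in-memory CSV standing in for Athena's result file:

```python
import io
import pandas

# In-memory CSV standing in for the s3://.../{query_execution_id}.csv result.
body = "my_timestamp,my_date\n2018-01-01 04:03:02.001,2019-02-02\n"

# Step 1: parse both columns as datetime64 (the parse_timestamps list).
df = pandas.read_csv(io.StringIO(body), parse_dates=["my_timestamp", "my_date"])

# Step 2: downcast the date-only columns (the parse_dates list).
for col in ["my_date"]:
    df[col] = df[col].dt.date

print(type(df.loc[0, "my_timestamp"]))  # pandas Timestamp (datetime64-backed)
print(type(df.loc[0, "my_date"]))       # <class 'datetime.date'>
```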

testing/test_awswrangler/test_pandas.py

Lines changed: 88 additions & 0 deletions
```diff
@@ -489,3 +489,91 @@ def test_to_csv_with_sep(
     assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
     assert dataframe[dataframe["id"] == 0].iloc[0]["name"] == dataframe2[
         dataframe2["id"] == 0].iloc[0]["name"]
+
+
+@pytest.mark.parametrize("index, partition_cols", [
+    (None, []),
+    ("default", []),
+    ("my_date", []),
+    ("my_timestamp", []),
+    (None, ["my_int"]),
+    ("default", ["my_int"]),
+    ("my_date", ["my_int"]),
+    ("my_timestamp", ["my_int"]),
+    (None, ["my_float"]),
+    ("default", ["my_float"]),
+    ("my_date", ["my_float"]),
+    ("my_timestamp", ["my_float"]),
+    (None, ["my_bool"]),
+    ("default", ["my_bool"]),
+    ("my_date", ["my_bool"]),
+    ("my_timestamp", ["my_bool"]),
+    (None, ["my_date"]),
+    ("default", ["my_date"]),
+    ("my_date", ["my_date"]),
+    ("my_timestamp", ["my_date"]),
+    (None, ["my_timestamp"]),
+    ("default", ["my_timestamp"]),
+    ("my_date", ["my_timestamp"]),
+    ("my_timestamp", ["my_timestamp"]),
+    (None, ["my_timestamp", "my_date"]),
+    ("default", ["my_date", "my_timestamp"]),
+    ("my_date", ["my_timestamp", "my_date"]),
+    ("my_timestamp", ["my_date", "my_timestamp"]),
+    (None, ["my_bool", "my_timestamp", "my_date"]),
+    ("default", ["my_date", "my_timestamp", "my_int"]),
+    ("my_date", ["my_timestamp", "my_float", "my_date"]),
+    ("my_timestamp", ["my_int", "my_float", "my_bool", "my_date", "my_timestamp"]),
+])
+def test_to_parquet_types(session, bucket, database, index, partition_cols):
+    dataframe = pandas.read_csv("data_samples/complex.csv",
+                                dtype={"my_int_with_null": "Int64"},
+                                parse_dates=["my_timestamp", "my_date"])
+    dataframe["my_date"] = dataframe["my_date"].dt.date
+    dataframe["my_bool"] = True
+
+    preserve_index = True
+    if not index:
+        preserve_index = False
+    elif index != "default":
+        dataframe["new_index"] = dataframe[index]
+        dataframe = dataframe.set_index("new_index")
+
+    session.pandas.to_parquet(dataframe=dataframe,
+                              database=database,
+                              path=f"s3://{bucket}/test/",
+                              preserve_index=preserve_index,
+                              partition_cols=partition_cols,
+                              mode="overwrite",
+                              procs_cpu_bound=1)
+    sleep(1)
+    dataframe2 = session.pandas.read_sql_athena(sql="select * from test",
+                                                database=database)
+    for row in dataframe2.itertuples():
+        if index:
+            if index == "default":
+                ex_index_col = 8 - len(partition_cols)
+                assert isinstance(row[ex_index_col], numpy.int64)
+            elif index == "my_date":
+                assert isinstance(row.new_index, date)
+            elif index == "my_timestamp":
+                assert isinstance(row.new_index, datetime)
+        assert isinstance(row.my_timestamp, datetime)
+        assert type(row.my_date) == date
+        assert isinstance(row.my_float, float)
+        assert isinstance(row.my_int, numpy.int64)
+        assert isinstance(row.my_string, str)
+        assert isinstance(row.my_bool, bool)
+        assert str(row.my_timestamp) == "2018-01-01 04:03:02.001000"
+        assert str(row.my_date) == "2019-02-02"
+        assert str(row.my_float) == "12345.6789"
+        assert str(row.my_int) == "123456789"
+        assert str(row.my_bool) == "True"
+        assert str(
+            row.my_string
+        ) == "foo\nboo\nbar\nFOO\nBOO\nBAR\nxxxxx\nÁÃÀÂÇ\n汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå"
+    assert len(dataframe.index) == len(dataframe2.index)
+    if index:
+        assert (len(list(dataframe.columns)) + 1) == len(list(dataframe2.columns))
+    else:
+        assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
```
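
The positional check `row[ex_index_col]` relies on how `itertuples` numbers its elements: the frame's index comes first, and the partition columns coming back in different positions from Athena shift where the restored index lands, hence the `8 - len(partition_cols)` offset. A toy example of that indexing:

```python
import pandas

df = pandas.DataFrame({"a": [10], "b": [20]})

for row in df.itertuples():
    print(row[0])  # 0  -> the frame's index, always element 0
    print(row.a)   # 10 -> named access
    print(row[2])  # 20 -> positional access counts the index first
```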
