Commit 31ab849

Adding support for casting data types from Pandas to Parquet (nested types not supported yet).
1 parent d8972d2 commit 31ab849

File tree

4 files changed: +129 −24 lines

awswrangler/glue.py

Lines changed: 38 additions & 14 deletions
@@ -45,6 +45,32 @@ def get_table_python_types(self, database, table):
         dtypes = self.get_table_athena_types(database=database, table=table)
         return {k: Glue.type_athena2python(v) for k, v in dtypes.items()}
 
+    @staticmethod
+    def type_athena2pyarrow(dtype):
+        dtype = dtype.lower()
+        if dtype == "tinyint":
+            return "int8"
+        if dtype == "smallint":
+            return "int16"
+        elif dtype in ["int", "integer"]:
+            return "int32"
+        elif dtype == "bigint":
+            return "int64"
+        elif dtype == "float":
+            return "float32"
+        elif dtype == "double":
+            return "float64"
+        elif dtype in ["boolean", "bool"]:
+            return "bool"
+        elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
+            return "string"
+        elif dtype == "timestamp":
+            return "timestamp[ns]"
+        elif dtype == "date":
+            return "date32"
+        else:
+            raise UnsupportedType(f"Unsupported Athena type: {dtype}")
+
     @staticmethod
     def type_pyarrow2athena(dtype):
         dtype_str = str(dtype).lower()

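Note: the new type_athena2pyarrow helper maps Athena/Glue type names onto pyarrow type alias strings, so a cast target can be declared once in Athena terms and reused on the pyarrow side. A minimal sketch of how those aliases resolve, assuming only pyarrow and the module layout shown in this diff (the snippet is illustrative, not part of the commit):

    import pyarrow as pa

    from awswrangler.glue import Glue

    # Athena type names in, pyarrow alias strings out.
    assert Glue.type_athena2pyarrow("bigint") == "int64"
    assert Glue.type_athena2pyarrow("timestamp") == "timestamp[ns]"

    # The alias strings resolve to concrete pyarrow DataTypes, which is what
    # the cast applied in awswrangler/pandas.py ultimately relies on.
    date_type = pa.type_for_alias(Glue.type_athena2pyarrow("date"))  # date32[day]
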
@@ -241,21 +267,15 @@ def get_connection_details(self, name):
             Name=name, HidePassword=False)["Connection"]
 
     @staticmethod
-    def _extract_pyarrow_schema(dataframe, preserve_index, cast_columns=None):
+    def _extract_pyarrow_schema(dataframe, preserve_index):
         cols = []
         cols_dtypes = {}
         schema = []
 
-        casted = []
-        if cast_columns is not None:
-            casted = cast_columns.keys()
-
         for name, dtype in dataframe.dtypes.to_dict().items():
             dtype = str(dtype)
             if dtype == "Int64":
                 cols_dtypes[name] = "int64"
-            elif name in casted:
-                cols_dtypes[name] = cast_columns[name]
             else:
                 cols.append(name)
 
@@ -281,18 +301,22 @@ def _build_schema(dataframe,
             partition_cols = []
 
         pyarrow_schema = Glue._extract_pyarrow_schema(
-            dataframe=dataframe,
-            preserve_index=preserve_index,
-            cast_columns=cast_columns)
+            dataframe=dataframe, preserve_index=preserve_index)
 
         schema_built = []
         partition_cols_types = {}
         for name, dtype in pyarrow_schema:
-            athena_type = Glue.type_pyarrow2athena(dtype)
-            if name in partition_cols:
-                partition_cols_types[name] = athena_type
+            if (cast_columns is not None) and (name in cast_columns.keys()):
+                if name in partition_cols:
+                    partition_cols_types[name] = cast_columns[name]
+                else:
+                    schema_built.append((name, cast_columns[name]))
             else:
-                schema_built.append((name, athena_type))
+                athena_type = Glue.type_pyarrow2athena(dtype)
+                if name in partition_cols:
+                    partition_cols_types[name] = athena_type
+                else:
+                    schema_built.append((name, athena_type))
 
         partition_cols_schema_built = [(name, partition_cols_types[name])
                                        for name in partition_cols]

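Note: cast_columns is no longer consumed while the pyarrow schema is extracted; it is applied when the Athena schema is built, and a user-supplied type always takes precedence over the type inferred from pyarrow. A standalone sketch of that precedence, using a hypothetical helper (resolve_athena_type is not part of the commit):

    from awswrangler.glue import Glue

    def resolve_athena_type(name, pyarrow_dtype, cast_columns=None):
        # A user-supplied Athena type, if present, overrides inference.
        if cast_columns is not None and name in cast_columns:
            return cast_columns[name]
        # Otherwise map the pyarrow dtype to an Athena type as before.
        return Glue.type_pyarrow2athena(pyarrow_dtype)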

awswrangler/pandas.py

Lines changed: 8 additions & 5 deletions
@@ -14,7 +14,7 @@
     AthenaQueryError, EmptyS3Object, LineTerminatorNotFound, EmptyDataframe, \
     InvalidSerDe, InvalidCompression
 from awswrangler.utils import calculate_bounders
-from awswrangler import s3
+from awswrangler import s3, glue
 
 logger = logging.getLogger(__name__)
 
@@ -859,18 +859,21 @@ def write_parquet_dataframe(dataframe,
             if str(dtype) == "Int64":
                 dataframe[name] = dataframe[name].astype("float64")
                 casted_in_pandas.append(name)
-                cast_columns[name] = "int64"
+                cast_columns[name] = "bigint"
                 logger.debug(f"Casting column {name} Int64 to float64")
         table = pyarrow.Table.from_pandas(df=dataframe,
                                           preserve_index=preserve_index,
                                           safe=False)
         if cast_columns:
             for col_name, dtype in cast_columns.items():
                 col_index = table.column_names.index(col_name)
-                table = table.set_column(col_index,
-                                         table.column(col_name).cast(dtype))
+                pyarrow_dtype = glue.Glue.type_athena2pyarrow(dtype)
+                table = table.set_column(
+                    col_index,
+                    table.column(col_name).cast(pyarrow_dtype))
                 logger.debug(
-                    f"Casting column {col_name} ({col_index}) to {dtype}")
+                    f"Casting column {col_name} ({col_index}) to {dtype} ({pyarrow_dtype})"
+                )
         with fs.open(path, "wb") as f:
             parquet.write_table(table,
                                 f,

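Note: after this change the values accepted by cast_columns are Athena type names ("bigint", "double", "date", ...) rather than pyarrow names; write_parquet_dataframe translates them through glue.Glue.type_athena2pyarrow before casting the pyarrow table. A hedged usage sketch mirroring the call pattern in the test suite (bucket and database names are placeholders, and the Session setup assumes the package defaults):

    import pandas

    import awswrangler

    session = awswrangler.Session()
    df = pandas.DataFrame({"id": [1, 2], "value": [1.5, None]})
    session.pandas.to_parquet(dataframe=df,
                              database="my_db",
                              path="s3://my-bucket/test/",
                              preserve_index=False,
                              mode="overwrite",
                              cast_columns={"id": "bigint", "value": "double"})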

testing/test_awswrangler/test_cloudwatchlogs.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def logstream(cloudformation_outputs, loggroup):
     if token:
         args["sequenceToken"] = token
     client.put_log_events(**args)
-    sleep(180)
+    sleep(300)
     yield logstream
 
 

testing/test_awswrangler/test_pandas.py

Lines changed: 82 additions & 4 deletions
@@ -99,7 +99,7 @@ def logstream(cloudformation_outputs, loggroup):
     if token:
         args["sequenceToken"] = token
     client.put_log_events(**args)
-    sleep(180)
+    sleep(300)
     yield logstream
 
 
@@ -243,22 +243,21 @@ def test_to_s3(
     assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
 
 
-def test_to_parquet_with_cast(
+def test_to_parquet_with_cast_int(
         session,
         bucket,
         database,
 ):
     dataframe = pandas.read_csv("data_samples/nano.csv",
                                 dtype={"id": "Int64"},
                                 parse_dates=["date", "time"])
-    print(dataframe.dtypes)
     session.pandas.to_parquet(dataframe=dataframe,
                               database=database,
                               path=f"s3://{bucket}/test/",
                               preserve_index=False,
                               mode="overwrite",
                               procs_cpu_bound=1,
-                              cast_columns={"value": "int64"})
+                              cast_columns={"value": "int"})
     dataframe2 = None
     for counter in range(10):
         dataframe2 = session.pandas.read_sql_athena(sql="select * from test",
@@ -664,3 +663,82 @@ def test_to_parquet_lists(session, bucket, database):
     val = dataframe[dataframe["id"] == 0].iloc[0]["col_list_int"]
     val2 = dataframe2[dataframe2["id"] == 0].iloc[0]["col_list_int"]
     assert val == val2
+
+
+def test_to_parquet_cast(session, bucket, database):
+    dataframe = pandas.DataFrame({
+        "id": [0, 1],
+        "col_int": [[1, 2], [3, 4, 5]],
+        "col_float": [[1.0, 2.0, 3.0], [4.0, 5.0]],
+        "col_string": [["foo"], ["boo", "bar"]],
+        "col_timestamp": [[datetime(2019, 1, 1),
+                           datetime(2019, 1, 2)], [datetime(2019, 1, 3)]],
+        "col_date": [[date(2019, 1, 1), date(2019, 1, 2)], [date(2019, 1, 3)]],
+        "col_list_int": [[[1]], [[2, 3], [4, 5, 6]]],
+        "col_list_list_string": [[[["foo"]]], [[["boo", "bar"]]]],
+    })
+    paths = session.pandas.to_parquet(dataframe=dataframe,
+                                      database=database,
+                                      path=f"s3://{bucket}/test/",
+                                      preserve_index=False,
+                                      mode="overwrite",
+                                      procs_cpu_bound=1)
+    assert len(paths) == 1
+    dataframe2 = None
+    for counter in range(10):
+        dataframe2 = session.pandas.read_sql_athena(
+            sql="select id, col_int, col_float, col_list_int from test",
+            database=database)
+        if len(dataframe.index) == len(dataframe2.index):
+            break
+        sleep(2)
+    assert len(dataframe.index) == len(dataframe2.index)
+    assert 4 == len(list(dataframe2.columns))
+    val = dataframe[dataframe["id"] == 0].iloc[0]["col_list_int"]
+    val2 = dataframe2[dataframe2["id"] == 0].iloc[0]["col_list_int"]
+    assert val == val2
+
+
+def test_to_parquet_with_cast_null(
+        session,
+        bucket,
+        database,
+):
+    dataframe = pandas.DataFrame({
+        "id": [0, 1],
+        "col_null_tinyint": [None, None],
+        "col_null_smallint": [None, None],
+        "col_null_int": [None, None],
+        "col_null_bigint": [None, None],
+        "col_null_float": [None, None],
+        "col_null_double": [None, None],
+        "col_null_string": [None, None],
+        "col_null_date": [None, None],
+        "col_null_timestamp": [None, None],
+    })
+    session.pandas.to_parquet(dataframe=dataframe,
+                              database=database,
+                              path=f"s3://{bucket}/test/",
+                              preserve_index=False,
+                              mode="overwrite",
+                              procs_cpu_bound=1,
+                              cast_columns={
+                                  "col_null_tinyint": "tinyint",
+                                  "col_null_smallint": "smallint",
+                                  "col_null_int": "int",
+                                  "col_null_bigint": "bigint",
+                                  "col_null_float": "float",
+                                  "col_null_double": "double",
+                                  "col_null_string": "string",
+                                  "col_null_date": "date",
+                                  "col_null_timestamp": "timestamp",
+                              })
+    dataframe2 = None
+    for counter in range(10):
+        dataframe2 = session.pandas.read_sql_athena(sql="select * from test",
+                                                    database=database)
+        if len(dataframe.index) == len(dataframe2.index):
+            break
+        sleep(2)
+    assert len(dataframe.index) == len(dataframe2.index)
+    assert len(list(dataframe.columns)) == len(list(dataframe2.columns))

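Note on test_to_parquet_with_cast_null: a pandas column holding only None values comes through pyarrow as the null type, which has no usable Athena equivalent, so cast_columns is what pins each column to a concrete type. A small illustration with plain pyarrow (assumes pyarrow only, not part of the commit):

    import pyarrow as pa

    arr = pa.array([None, None])
    print(arr.type)                    # null
    print(arr.cast(pa.int32()).type)   # int32, values remain all-null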
