Commit 7574474

Merge pull request #33 from awslabs/pandas-list-support

Adding support for lists to Pandas

2 parents 561f139 + 31ab849

File tree: 5 files changed (+198, -40 lines)

awswrangler/athena.py

Lines changed: 7 additions & 1 deletion
@@ -1,5 +1,6 @@
 from time import sleep
 import logging
+import ast
 
 from awswrangler.exceptions import UnsupportedType, QueryFailed, QueryCancelled
 
@@ -35,6 +36,8 @@ def _type_athena2pandas(dtype):
             return "datetime64"
         elif dtype == "date":
             return "date"
+        elif dtype == "array":
+            return "literal_eval"
         else:
             raise UnsupportedType(f"Unsupported Athena type: {dtype}")

@@ -44,18 +47,21 @@ def get_query_dtype(self, query_execution_id):
         dtype = {}
         parse_timestamps = []
         parse_dates = []
+        converters = {}
         for col_name, col_type in cols_metadata.items():
             ptype = Athena._type_athena2pandas(dtype=col_type)
             if ptype in ["datetime64", "date"]:
                 parse_timestamps.append(col_name)
                 if ptype == "date":
                     parse_dates.append(col_name)
+            elif ptype == "literal_eval":
+                converters[col_name] = ast.literal_eval
             else:
                 dtype[col_name] = ptype
         logger.debug(f"dtype: {dtype}")
         logger.debug(f"parse_timestamps: {parse_timestamps}")
         logger.debug(f"parse_dates: {parse_dates}")
-        return dtype, parse_timestamps, parse_dates
+        return dtype, parse_timestamps, parse_dates, converters
 
     def create_athena_bucket(self):
         """

awswrangler/glue.py

Lines changed: 50 additions & 24 deletions
@@ -45,25 +45,53 @@ def get_table_python_types(self, database, table):
         dtypes = self.get_table_athena_types(database=database, table=table)
         return {k: Glue.type_athena2python(v) for k, v in dtypes.items()}
 
+    @staticmethod
+    def type_athena2pyarrow(dtype):
+        dtype = dtype.lower()
+        if dtype == "tinyint":
+            return "int8"
+        if dtype == "smallint":
+            return "int16"
+        elif dtype in ["int", "integer"]:
+            return "int32"
+        elif dtype == "bigint":
+            return "int64"
+        elif dtype == "float":
+            return "float32"
+        elif dtype == "double":
+            return "float64"
+        elif dtype in ["boolean", "bool"]:
+            return "bool"
+        elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
+            return "string"
+        elif dtype == "timestamp":
+            return "timestamp[ns]"
+        elif dtype == "date":
+            return "date32"
+        else:
+            raise UnsupportedType(f"Unsupported Athena type: {dtype}")
+
     @staticmethod
     def type_pyarrow2athena(dtype):
-        dtype = str(dtype).lower()
-        if dtype == "int32":
+        dtype_str = str(dtype).lower()
+        if dtype_str == "int32":
             return "int"
-        elif dtype == "int64":
+        elif dtype_str == "int64":
             return "bigint"
-        elif dtype == "float":
+        elif dtype_str == "float":
             return "float"
-        elif dtype == "double":
+        elif dtype_str == "double":
             return "double"
-        elif dtype == "bool":
+        elif dtype_str == "bool":
             return "boolean"
-        elif dtype == "string":
+        elif dtype_str == "string":
             return "string"
-        elif dtype.startswith("timestamp"):
+        elif dtype_str.startswith("timestamp"):
             return "timestamp"
-        elif dtype.startswith("date"):
+        elif dtype_str.startswith("date"):
             return "date"
+        elif dtype_str.startswith("list"):
+            return f"array<{Glue.type_pyarrow2athena(dtype.value_type)}>"
         else:
             raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
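The new list branch is why type_pyarrow2athena now keeps dtype as a pyarrow type object instead of converting it to a string up front: it needs dtype.value_type to recurse on the element type, mapping nested pyarrow lists onto nested Athena array<...> types. A quick sketch of the mapping, assuming the class is imported as shown:

import pyarrow
from awswrangler.glue import Glue

print(Glue.type_pyarrow2athena(pyarrow.list_(pyarrow.int64())))
# -> array<bigint>
print(Glue.type_pyarrow2athena(pyarrow.list_(pyarrow.list_(pyarrow.string()))))
# -> array<array<string>> (recursion on value_type)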

@@ -239,28 +267,22 @@ def get_connection_details(self, name):
             Name=name, HidePassword=False)["Connection"]
 
     @staticmethod
-    def _extract_pyarrow_schema(dataframe, preserve_index, cast_columns=None):
+    def _extract_pyarrow_schema(dataframe, preserve_index):
         cols = []
         cols_dtypes = {}
         schema = []
 
-        casted = []
-        if cast_columns is not None:
-            casted = cast_columns.keys()
-
         for name, dtype in dataframe.dtypes.to_dict().items():
             dtype = str(dtype)
             if dtype == "Int64":
                 cols_dtypes[name] = "int64"
-            elif name in casted:
-                cols_dtypes[name] = cast_columns[name]
             else:
                 cols.append(name)
 
         for field in pyarrow.Schema.from_pandas(df=dataframe[cols],
                                                 preserve_index=preserve_index):
             name = str(field.name)
-            dtype = str(field.type)
+            dtype = field.type
             cols_dtypes[name] = dtype
             if name not in dataframe.columns:
                 schema.append((name, dtype))
@@ -279,18 +301,22 @@ def _build_schema(dataframe,
             partition_cols = []
 
         pyarrow_schema = Glue._extract_pyarrow_schema(
-            dataframe=dataframe,
-            preserve_index=preserve_index,
-            cast_columns=cast_columns)
+            dataframe=dataframe, preserve_index=preserve_index)
 
         schema_built = []
         partition_cols_types = {}
         for name, dtype in pyarrow_schema:
-            athena_type = Glue.type_pyarrow2athena(dtype)
-            if name in partition_cols:
-                partition_cols_types[name] = athena_type
+            if (cast_columns is not None) and (name in cast_columns.keys()):
+                if name in partition_cols:
+                    partition_cols_types[name] = cast_columns[name]
+                else:
+                    schema_built.append((name, cast_columns[name]))
             else:
-                schema_built.append((name, athena_type))
+                athena_type = Glue.type_pyarrow2athena(dtype)
+                if name in partition_cols:
+                    partition_cols_types[name] = athena_type
+                else:
+                    schema_built.append((name, athena_type))
 
         partition_cols_schema_built = [(name, partition_cols_types[name])
                                        for name in partition_cols]
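After this restructuring, an entry in cast_columns always takes precedence over the pyarrow-inferred type when the Glue schema is built, and type_pyarrow2athena is only consulted for uncast columns. A standalone illustration of that precedence rule (column names and types below are hypothetical):

import pyarrow
from awswrangler.glue import Glue

cast_columns = {"value": "bigint"}
pyarrow_schema = [("value", pyarrow.float64()),
                  ("tags", pyarrow.list_(pyarrow.string()))]

schema_built = []
for name, dtype in pyarrow_schema:
    if (cast_columns is not None) and (name in cast_columns.keys()):
        schema_built.append((name, cast_columns[name]))  # explicit cast wins
    else:
        schema_built.append((name, Glue.type_pyarrow2athena(dtype)))

# schema_built == [("value", "bigint"), ("tags", "array<string>")]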

awswrangler/pandas.py

Lines changed: 24 additions & 10 deletions
@@ -14,7 +14,7 @@
     AthenaQueryError, EmptyS3Object, LineTerminatorNotFound, EmptyDataframe, \
     InvalidSerDe, InvalidCompression
 from awswrangler.utils import calculate_bounders
-from awswrangler import s3
+from awswrangler import s3, glue
 
 logger = logging.getLogger(__name__)
 
@@ -56,6 +56,7 @@ def read_csv(
             parse_dates=False,
             infer_datetime_format=False,
             encoding="utf-8",
+            converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -76,6 +77,7 @@ def read_csv(
         :param parse_dates: Same as pandas.read_csv()
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
+        :param converters: Same as pandas.read_csv()
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
         """
         bucket_name, key_path = self._parse_path(path)
@@ -99,7 +101,8 @@ def read_csv(
                 escapechar=escapechar,
                 parse_dates=parse_dates,
                 infer_datetime_format=infer_datetime_format,
-                encoding=encoding)
+                encoding=encoding,
+                converters=converters)
         else:
             ret = Pandas._read_csv_once(
                 client_s3=client_s3,
@@ -115,7 +118,8 @@ def read_csv(
                 escapechar=escapechar,
                 parse_dates=parse_dates,
                 infer_datetime_format=infer_datetime_format,
-                encoding=encoding)
+                encoding=encoding,
+                converters=converters)
         return ret
 
     @staticmethod
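The new converters argument is forwarded untouched to pandas.read_csv(). A hypothetical call (the bucket, key, and column name are made up):

import ast
import awswrangler

session = awswrangler.Session()  # assumed session setup
df = session.pandas.read_csv(
    path="s3://my-bucket/results.csv",      # made-up path
    converters={"tags": ast.literal_eval})  # rebuild list-typed cells as Python lists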
@@ -135,6 +139,7 @@ def _read_csv_iterator(
             parse_dates=False,
             infer_datetime_format=False,
             encoding="utf-8",
+            converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -156,6 +161,7 @@ def _read_csv_iterator(
         :param parse_dates: Same as pandas.read_csv()
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
+        :param converters: Same as pandas.read_csv()
         :return: Pandas Dataframe
         """
         metadata = s3.S3.head_object_with_retry(client=client_s3,
@@ -181,7 +187,8 @@ def _read_csv_iterator(
                 escapechar=escapechar,
                 parse_dates=parse_dates,
                 infer_datetime_format=infer_datetime_format,
-                encoding=encoding)
+                encoding=encoding,
+                converters=converters)
         else:
             bounders = calculate_bounders(num_items=total_size,
                                           max_size=max_result_size)
@@ -234,7 +241,7 @@ def _read_csv_iterator(
                     lineterminator=lineterminator,
                     dtype=dtype,
                     encoding=encoding,
-                )
+                    converters=converters)
                 yield df
                 if count == 1:  # first chunk
                     names = df.columns
@@ -352,6 +359,7 @@ def _read_csv_once(
             parse_dates=False,
             infer_datetime_format=False,
             encoding=None,
+            converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -372,6 +380,7 @@ def _read_csv_once(
         :param parse_dates: Same as pandas.read_csv()
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
+        :param converters: Same as pandas.read_csv()
         :return: Pandas Dataframe
         """
         buff = BytesIO()
@@ -392,6 +401,7 @@ def _read_csv_once(
             lineterminator=lineterminator,
             dtype=dtype,
             encoding=encoding,
+            converters=converters,
         )
         buff.close()
         return dataframe
@@ -425,12 +435,13 @@ def read_sql_athena(self,
             message_error = f"Query error: {reason}"
             raise AthenaQueryError(message_error)
         else:
-            dtype, parse_timestamps, parse_dates = self._session.athena.get_query_dtype(
+            dtype, parse_timestamps, parse_dates, converters = self._session.athena.get_query_dtype(
                 query_execution_id=query_execution_id)
             path = f"{s3_output}{query_execution_id}.csv"
             ret = self.read_csv(path=path,
                                 dtype=dtype,
                                 parse_dates=parse_timestamps,
+                                converters=converters,
                                 quoting=csv.QUOTE_ALL,
                                 max_result_size=max_result_size)
             if max_result_size is None:
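End to end, array columns in an Athena query result now come back as Python lists rather than raw strings. A hypothetical usage, assuming a table with an array<int> column named tags (the table, query, and database are made up):

import awswrangler

session = awswrangler.Session()  # assumed session setup
df = session.pandas.read_sql_athena(
    sql="SELECT id, tags FROM my_table",  # made-up query
    database="my_database")               # made-up database
# df["tags"] cells are Python lists (e.g. [1, 2, 3]) via ast.literal_eval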
@@ -848,18 +859,21 @@ def write_parquet_dataframe(dataframe,
             if str(dtype) == "Int64":
                 dataframe[name] = dataframe[name].astype("float64")
                 casted_in_pandas.append(name)
-                cast_columns[name] = "int64"
+                cast_columns[name] = "bigint"
                 logger.debug(f"Casting column {name} Int64 to float64")
         table = pyarrow.Table.from_pandas(df=dataframe,
                                           preserve_index=preserve_index,
                                           safe=False)
         if cast_columns:
             for col_name, dtype in cast_columns.items():
                 col_index = table.column_names.index(col_name)
-                table = table.set_column(col_index,
-                                         table.column(col_name).cast(dtype))
+                pyarrow_dtype = glue.Glue.type_athena2pyarrow(dtype)
+                table = table.set_column(
+                    col_index,
+                    table.column(col_name).cast(pyarrow_dtype))
                 logger.debug(
-                    f"Casting column {col_name} ({col_index}) to {dtype}")
+                    f"Casting column {col_name} ({col_index}) to {dtype} ({pyarrow_dtype})"
+                )
         with fs.open(path, "wb") as f:
             parquet.write_table(table,
                                 f,
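Combined with the Glue changes above, writing a dataframe that holds Python lists now yields an Athena array<...> column: pyarrow infers list<item: string> for the tags column below, and type_pyarrow2athena registers it as array<string> in the Glue catalog. A hypothetical sketch (the database and S3 path are made up):

import pandas as pd
import awswrangler

df = pd.DataFrame({"id": [1, 2], "tags": [["a", "b"], ["c"]]})

session = awswrangler.Session()  # assumed session setup
session.pandas.to_parquet(
    dataframe=df,
    database="my_database",           # made-up Glue database
    path="s3://my-bucket/my-table/")  # made-up S3 path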

testing/test_awswrangler/test_cloudwatchlogs.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def logstream(cloudformation_outputs, loggroup):
     if token:
         args["sequenceToken"] = token
     client.put_log_events(**args)
-    sleep(180)
+    sleep(300)
     yield logstream
 
 
