
Commit d8972d2: Adding support for lists to Pandas
1 parent 561f139

4 files changed: 69 additions & 16 deletions
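
For orientation, a minimal sketch of the round trip this commit enables, assuming AWS credentials in the environment plus a Glue database "my_database" and an S3 bucket "my-bucket" (both names hypothetical):

    import pandas
    import awswrangler

    # List columns now survive the trip:
    # DataFrame -> Parquet on S3 -> Glue table -> Athena query.
    session = awswrangler.Session()
    df = pandas.DataFrame({"id": [0, 1], "col_int": [[1, 2], [3, 4, 5]]})
    session.pandas.to_parquet(dataframe=df,
                              database="my_database",
                              path="s3://my-bucket/test/",
                              preserve_index=False,
                              mode="overwrite")
    df2 = session.pandas.read_sql_athena(sql="SELECT id, col_int FROM test",
                                         database="my_database")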

awswrangler/athena.py

Lines changed: 7 additions & 1 deletion

@@ -1,5 +1,6 @@
 from time import sleep
 import logging
+import ast

 from awswrangler.exceptions import UnsupportedType, QueryFailed, QueryCancelled

@@ -35,6 +36,8 @@ def _type_athena2pandas(dtype):
             return "datetime64"
         elif dtype == "date":
             return "date"
+        elif dtype == "array":
+            return "literal_eval"
         else:
             raise UnsupportedType(f"Unsupported Athena type: {dtype}")

@@ -44,18 +47,21 @@ def get_query_dtype(self, query_execution_id):
         dtype = {}
         parse_timestamps = []
         parse_dates = []
+        converters = {}
         for col_name, col_type in cols_metadata.items():
             ptype = Athena._type_athena2pandas(dtype=col_type)
             if ptype in ["datetime64", "date"]:
                 parse_timestamps.append(col_name)
                 if ptype == "date":
                     parse_dates.append(col_name)
+            elif ptype == "literal_eval":
+                converters[col_name] = ast.literal_eval
             else:
                 dtype[col_name] = ptype
         logger.debug(f"dtype: {dtype}")
         logger.debug(f"parse_timestamps: {parse_timestamps}")
         logger.debug(f"parse_dates: {parse_dates}")
-        return dtype, parse_timestamps, parse_dates
+        return dtype, parse_timestamps, parse_dates, converters

     def create_athena_bucket(self):
         """

awswrangler/glue.py

Lines changed: 12 additions & 10 deletions

@@ -47,23 +47,25 @@ def get_table_python_types(self, database, table):

     @staticmethod
     def type_pyarrow2athena(dtype):
-        dtype = str(dtype).lower()
-        if dtype == "int32":
+        dtype_str = str(dtype).lower()
+        if dtype_str == "int32":
             return "int"
-        elif dtype == "int64":
+        elif dtype_str == "int64":
             return "bigint"
-        elif dtype == "float":
+        elif dtype_str == "float":
             return "float"
-        elif dtype == "double":
+        elif dtype_str == "double":
             return "double"
-        elif dtype == "bool":
+        elif dtype_str == "bool":
             return "boolean"
-        elif dtype == "string":
+        elif dtype_str == "string":
             return "string"
-        elif dtype.startswith("timestamp"):
+        elif dtype_str.startswith("timestamp"):
             return "timestamp"
-        elif dtype.startswith("date"):
+        elif dtype_str.startswith("date"):
             return "date"
+        elif dtype_str.startswith("list"):
+            return f"array<{Glue.type_pyarrow2athena(dtype.value_type)}>"
         else:
             raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")

@@ -260,7 +262,7 @@ def _extract_pyarrow_schema(dataframe, preserve_index, cast_columns=None):
         for field in pyarrow.Schema.from_pandas(df=dataframe[cols],
                                                 preserve_index=preserve_index):
             name = str(field.name)
-            dtype = str(field.type)
+            dtype = field.type
             cols_dtypes[name] = dtype
             if name not in dataframe.columns:
                 schema.append((name, dtype))
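
Two details are worth noting here. The list branch recurses on dtype.value_type, so nested lists render as nested DDL (a pyarrow list<list<int64>> becomes array<array<bigint>>), and _extract_pyarrow_schema now keeps field.type as a pyarrow object rather than a string precisely because value_type only exists on the pyarrow type. A standalone sketch of the recursion, trimmed to integer element types for illustration:

    import pyarrow as pa

    def list_to_athena(dtype):
        # Mirrors the recursive branch above: unwrap pyarrow list types layer by layer.
        if str(dtype).lower().startswith("list"):
            return f"array<{list_to_athena(dtype.value_type)}>"
        if str(dtype).lower() == "int64":
            return "bigint"
        raise ValueError(f"unhandled type for this sketch: {dtype}")

    assert list_to_athena(pa.list_(pa.int64())) == "array<bigint>"
    assert list_to_athena(pa.list_(pa.list_(pa.int64()))) == "array<array<bigint>>"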

awswrangler/pandas.py

Lines changed: 16 additions & 5 deletions

@@ -56,6 +56,7 @@ def read_csv(
         parse_dates=False,
         infer_datetime_format=False,
         encoding="utf-8",
+        converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -76,6 +77,7 @@ def read_csv(
         :param parse_dates: Same as pandas.read_csv()
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
+        :param converters: Same as pandas.read_csv()
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
         """
         bucket_name, key_path = self._parse_path(path)
@@ -99,7 +101,8 @@ def read_csv(
                 escapechar=escapechar,
                 parse_dates=parse_dates,
                 infer_datetime_format=infer_datetime_format,
-                encoding=encoding)
+                encoding=encoding,
+                converters=converters)
         else:
             ret = Pandas._read_csv_once(
                 client_s3=client_s3,
@@ -115,7 +118,8 @@ def read_csv(
                 escapechar=escapechar,
                 parse_dates=parse_dates,
                 infer_datetime_format=infer_datetime_format,
-                encoding=encoding)
+                encoding=encoding,
+                converters=converters)
         return ret

     @staticmethod
@@ -135,6 +139,7 @@ def _read_csv_iterator(
         parse_dates=False,
         infer_datetime_format=False,
         encoding="utf-8",
+        converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -156,6 +161,7 @@ def _read_csv_iterator(
         :param parse_dates: Same as pandas.read_csv()
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
+        :param converters: Same as pandas.read_csv()
         :return: Pandas Dataframe
         """
         metadata = s3.S3.head_object_with_retry(client=client_s3,
@@ -181,7 +187,8 @@ def _read_csv_iterator(
                 escapechar=escapechar,
                 parse_dates=parse_dates,
                 infer_datetime_format=infer_datetime_format,
-                encoding=encoding)
+                encoding=encoding,
+                converters=converters)
         else:
             bounders = calculate_bounders(num_items=total_size,
                                           max_size=max_result_size)
@@ -234,7 +241,7 @@ def _read_csv_iterator(
                     lineterminator=lineterminator,
                     dtype=dtype,
                     encoding=encoding,
-                )
+                    converters=converters)
                 yield df
                 if count == 1:  # first chunk
                     names = df.columns
@@ -352,6 +359,7 @@ def _read_csv_once(
         parse_dates=False,
         infer_datetime_format=False,
         encoding=None,
+        converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -372,6 +380,7 @@ def _read_csv_once(
         :param parse_dates: Same as pandas.read_csv()
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
+        :param converters: Same as pandas.read_csv()
         :return: Pandas Dataframe
         """
         buff = BytesIO()
@@ -392,6 +401,7 @@ def _read_csv_once(
             lineterminator=lineterminator,
             dtype=dtype,
             encoding=encoding,
+            converters=converters,
         )
         buff.close()
         return dataframe
@@ -425,12 +435,13 @@ def read_sql_athena(self,
             message_error = f"Query error: {reason}"
             raise AthenaQueryError(message_error)
         else:
-            dtype, parse_timestamps, parse_dates = self._session.athena.get_query_dtype(
+            dtype, parse_timestamps, parse_dates, converters = self._session.athena.get_query_dtype(
                 query_execution_id=query_execution_id)
             path = f"{s3_output}{query_execution_id}.csv"
             ret = self.read_csv(path=path,
                                 dtype=dtype,
                                 parse_dates=parse_timestamps,
+                                converters=converters,
                                 quoting=csv.QUOTE_ALL,
                                 max_result_size=max_result_size)
             if max_result_size is None:
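
All of this plumbing forwards a stock pandas.read_csv() parameter: read_sql_athena simply passes along the converters dict built by get_query_dtype. A self-contained sketch of the effect, with an inline CSV standing in for an Athena result file:

    import ast
    from io import StringIO
    import pandas as pd

    # get_query_dtype maps each array column to ast.literal_eval; read_csv then
    # applies that converter to every raw cell string in the column.
    csv_text = '"id","col_int"\n"0","[1, 2]"\n"1","[3, 4, 5]"\n'
    df = pd.read_csv(StringIO(csv_text), converters={"col_int": ast.literal_eval})
    assert df["col_int"].tolist() == [[1, 2], [3, 4, 5]]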

testing/test_awswrangler/test_pandas.py

Lines changed: 34 additions & 0 deletions

@@ -630,3 +630,37 @@ def test_to_parquet_compressed(session, bucket, database, compression):
     assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
     assert dataframe[dataframe["id"] == 1].iloc[0]["name"] == dataframe2[
         dataframe2["id"] == 1].iloc[0]["name"]
+
+
+def test_to_parquet_lists(session, bucket, database):
+    dataframe = pandas.DataFrame({
+        "id": [0, 1],
+        "col_int": [[1, 2], [3, 4, 5]],
+        "col_float": [[1.0, 2.0, 3.0], [4.0, 5.0]],
+        "col_string": [["foo"], ["boo", "bar"]],
+        "col_timestamp": [[datetime(2019, 1, 1),
+                           datetime(2019, 1, 2)], [datetime(2019, 1, 3)]],
+        "col_date": [[date(2019, 1, 1), date(2019, 1, 2)], [date(2019, 1, 3)]],
+        "col_list_int": [[[1]], [[2, 3], [4, 5, 6]]],
+        "col_list_list_string": [[[["foo"]]], [[["boo", "bar"]]]],
+    })
+    paths = session.pandas.to_parquet(dataframe=dataframe,
+                                      database=database,
+                                      path=f"s3://{bucket}/test/",
+                                      preserve_index=False,
+                                      mode="overwrite",
+                                      procs_cpu_bound=1)
+    assert len(paths) == 1
+    dataframe2 = None
+    for counter in range(10):
+        dataframe2 = session.pandas.read_sql_athena(
+            sql="select id, col_int, col_float, col_list_int from test",
+            database=database)
+        if len(dataframe.index) == len(dataframe2.index):
+            break
+        sleep(2)
+    assert len(dataframe.index) == len(dataframe2.index)
+    assert 4 == len(list(dataframe2.columns))
+    val = dataframe[dataframe["id"] == 0].iloc[0]["col_list_int"]
+    val2 = dataframe2[dataframe2["id"] == 0].iloc[0]["col_list_int"]
+    assert val == val2
