
Commit 97f8763

Merge pull request #220 from awslabs/write-nested-types
Add support to write nested types (array and struct).
2 parents b748a35 + 08cf244

4 files changed: +85 additions, -11 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
```diff
@@ -138,6 +138,8 @@ testing/*parameters-*.properties
 testing/*requirements*.txt
 testing/coverage/*
 building/*requirements*.txt
+building/arrow
+building/lambda/arrow
 /docs/coverage/
 /docs/build/
 /docs/source/_build/
```

awswrangler/_data_types.py

Lines changed: 57 additions & 10 deletions
```diff
@@ -1,8 +1,9 @@
 """Internal (private) Data Types Module."""
 
 import logging
+import re
 from decimal import Decimal
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Match, Optional, Sequence, Tuple
 
 import pandas as pd  # type: ignore
 import pyarrow as pa  # type: ignore
@@ -139,8 +140,10 @@ def pyarrow2athena(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branc
         return f"decimal({dtype.precision},{dtype.scale})"
     if pa.types.is_list(dtype):
         return f"array<{pyarrow2athena(dtype=dtype.value_type)}>"
-    if pa.types.is_struct(dtype):  # pragma: no cover
-        return f"struct<{', '.join([f'{f.name}: {pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
+    if pa.types.is_struct(dtype):
+        return f"struct<{', '.join([f'{f.name}:{pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
+    if pa.types.is_map(dtype):  # pragma: no cover
+        return f"map<{pyarrow2athena(dtype=dtype.key_type)},{pyarrow2athena(dtype=dtype.item_type)}>"
     if dtype == pa.null():
         raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
     raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
```
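The struct branch drops the space after the field name's colon (Athena DDL uses `name:type`) and loses its `# pragma: no cover` now that nested writes are exercised by tests; the new map branch recurses into key and item types the same way the list branch recurses into its value type. A minimal sketch of the resulting conversions, assuming the module's usual scalar mappings (e.g. `int64` → `bigint`, `int32` → `int`):

```python
import pyarrow as pa

# pyarrow2athena lives in the internal awswrangler._data_types module.
from awswrangler._data_types import pyarrow2athena

# Nested PyArrow types are converted recursively into Athena DDL strings.
print(pyarrow2athena(dtype=pa.list_(pa.int64())))
# array<bigint>
print(pyarrow2athena(dtype=pa.struct([("a", pa.int64()), ("b", pa.list_(pa.string()))])))
# struct<a:bigint, b:array<string>>
print(pyarrow2athena(dtype=pa.map_(pa.string(), pa.int32())))
# map<string,int>
```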
```diff
@@ -167,7 +170,7 @@ def pyarrow2pandas_extension(  # pylint: disable=too-many-branches,too-many-retu
 
 def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-statements
     dtype: pa.DataType, db_type: str
-) -> VisitableType:
+) -> Optional[VisitableType]:
     """Pyarrow to Athena data types conversion."""
     if pa.types.is_int8(dtype):
         return sqlalchemy.types.SmallInteger
@@ -214,7 +217,7 @@ def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-sta
     if pa.types.is_dictionary(dtype):
         return pyarrow2sqlalchemy(dtype=dtype.value_type, db_type=db_type)
     if dtype == pa.null():  # pragma: no cover
-        raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
+        return None
     raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
 
 
```
```diff
@@ -243,12 +246,23 @@ def pyarrow_types_from_pandas(
     else:
         cols.append(name)
 
-    # Filling cols_dtypes and indexes
+    # Filling cols_dtypes
+    for col in cols:
+        _logger.debug("Inferring PyArrow type from column: %s", col)
+        try:
+            schema: pa.Schema = pa.Schema.from_pandas(df=df[[col]], preserve_index=False)
+        except pa.ArrowInvalid as ex:  # pragma: no cover
+            cols_dtypes[col] = process_not_inferred_dtype(ex)
+        else:
+            cols_dtypes[col] = schema.field(col).type
+
+    # Filling indexes
     indexes: List[str] = []
-    for field in pa.Schema.from_pandas(df=df[cols], preserve_index=index):
-        name = str(field.name)
-        cols_dtypes[name] = field.type
-        if (name not in df.columns) and (index is True):
+    if index is True:
+        for field in pa.Schema.from_pandas(df=df[[]], preserve_index=True):
+            name = str(field.name)
+            _logger.debug("Inferring PyArrow type from index: %s", name)
+            cols_dtypes[name] = field.type
             indexes.append(name)
 
     # Merging Index
```
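Schema inference now runs one column at a time, so a single column PyArrow cannot handle falls back to `process_not_inferred_dtype` (added below) instead of failing the whole DataFrame, and index types come from a zero-column slice. A small sketch of that `df[[]]` trick, with a hypothetical index name:

```python
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"c0": [1, 2]}, index=pd.Index(["a", "b"], name="idx"))

# A zero-column slice keeps the index, so the inferred schema holds
# only the index field(s) -- the behavior the new branch relies on.
schema = pa.Schema.from_pandas(df=df[[]], preserve_index=True)
print(schema.names)  # ['idx']
```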
```diff
@@ -261,6 +275,39 @@ def pyarrow_types_from_pandas(
     return columns_types
 
 
+def process_not_inferred_dtype(ex: pa.ArrowInvalid) -> pa.DataType:
+    """Infer data type from PyArrow inference exception."""
+    ex_str = str(ex)
+    _logger.debug("PyArrow was not able to infer data type:\n%s", ex_str)
+    match: Optional[Match] = re.search(
+        pattern="Could not convert (.*) with type (.*): did not recognize "
+        "Python value type when inferring an Arrow data type",
+        string=ex_str,
+    )
+    if match is None:
+        raise ex  # pragma: no cover
+    groups: Optional[Sequence[str]] = match.groups()
+    if groups is None:
+        raise ex  # pragma: no cover
+    if len(groups) != 2:
+        raise ex  # pragma: no cover
+    _logger.debug("groups: %s", groups)
+    type_str: str = groups[1]
+    if type_str == "UUID":
+        return pa.string()
+    raise ex  # pragma: no cover
+
+
+def process_not_inferred_array(ex: pa.ArrowInvalid, values: Any) -> pa.Array:
+    """Infer `pyarrow.array` from PyArrow inference exception."""
+    dtype = process_not_inferred_dtype(ex=ex)
+    if dtype == pa.string():
+        array: pa.Array = pa.array(obj=[str(x) for x in values], type=dtype, safe=True)
+    else:
+        raise ex  # pragma: no cover
+    return array
+
+
 def athena_types_from_pandas(
     df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None, index_left: bool = False
 ) -> Dict[str, str]:
```
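`process_not_inferred_dtype` never looks at the values themselves; it parses PyArrow's error message and currently recognizes only the `UUID` case, mapping it to `pa.string()`. A sketch against a hand-built message of the shape the regex targets (the real text comes from PyArrow and may vary across versions; the UUID literal is illustrative):

```python
import pyarrow as pa

from awswrangler._data_types import process_not_inferred_dtype  # internal helper

# Hand-built message mimicking PyArrow's inference error (illustrative only).
msg = (
    "Could not convert UUID('0bf50b29-21ea-4933-a568-3b8ba5e8f461') with type UUID: "
    "did not recognize Python value type when inferring an Arrow data type"
)
print(process_not_inferred_dtype(pa.ArrowInvalid(msg)))  # string
```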

awswrangler/db.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -185,7 +185,10 @@ def _records2df(
     arrays: List[pa.Array] = []
     for col_values, col_name in zip(tuple(zip(*records)), cols_names):  # Transposing
         if (dtype is None) or (col_name not in dtype):
-            array: pa.Array = pa.array(obj=col_values, safe=True)  # Creating Arrow array
+            try:
+                array: pa.Array = pa.array(obj=col_values, safe=True)  # Creating Arrow array
+            except pa.ArrowInvalid as ex:
+                array = _data_types.process_not_inferred_array(ex, values=col_values)  # Creating Arrow array
         else:
             array = pa.array(obj=col_values, type=dtype[col_name], safe=True)  # Creating Arrow array with dtype
         arrays.append(array)
```
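In `_records2df` the fallback only fires when PyArrow's own inference fails; `process_not_inferred_array` then rebuilds the column as strings. A sketch of the path this covers, assuming a PyArrow build that cannot infer `uuid.UUID` values (the case this commit targets):

```python
import uuid

import pyarrow as pa

from awswrangler import _data_types

col_values = (uuid.uuid4(), uuid.uuid4())
try:
    array = pa.array(obj=col_values, safe=True)  # UUID is not an inferable type here
except pa.ArrowInvalid as ex:
    # Mirrors the new branch: fall back to a string array.
    array = _data_types.process_not_inferred_array(ex, values=col_values)
print(array.type)  # string
```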

testing/test_awswrangler/test_data_lake.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -1203,3 +1203,25 @@ def test_athena_encryption(
     assert len(df2.columns) == 2
     wr.catalog.delete_table_if_exists(database=database, table=table)
     wr.s3.delete_objects(path=paths)
+
+
+def test_athena_nested(bucket, database):
+    table = "test_athena_nested"
+    path = f"s3://{bucket}/{table}/"
+    df = pd.DataFrame(
+        {
+            "c0": [[1, 2, 3], [4, 5, 6]],
+            "c1": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+            "c2": [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]],
+            "c3": [[], [[[[[[[[1]]]]]]]]],
+            "c4": [{"a": 1}, {"a": 1}],
+            "c5": [{"a": {"b": {"c": [1, 2]}}}, {"a": {"b": {"c": [3, 4]}}}],
+        }
+    )
+    paths = wr.s3.to_parquet(
+        df=df, path=path, index=False, use_threads=True, dataset=True, mode="overwrite", database=database, table=table
+    )["paths"]
+    wr.s3.wait_objects_exist(paths=paths)
+    df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {table}", database=database)
+    assert len(df2.index) == 2
+    assert len(df2.columns) == 4
```
