Commit 5e3a533

Add cast for nested types for wr.s3.to_parquet. #263

1 parent: f6aaf6f
File tree: 2 files changed, +59 −10 lines

awswrangler/_data_types.py (16 additions, 7 deletions)
```diff
@@ -21,7 +21,7 @@
 
 def athena2pyarrow(dtype: str) -> pa.DataType:  # pylint: disable=too-many-return-statements
     """Athena to PyArrow data types conversion."""
-    dtype = dtype.lower()
+    dtype = dtype.lower().replace(" ", "")
     if dtype == "tinyint":
         return pa.int8()
     if dtype == "smallint":
```
```diff
@@ -47,6 +47,12 @@ def athena2pyarrow(dtype: str) -> pa.DataType:  # pylint: disable=too-many-retur
     if dtype.startswith("decimal"):
         precision, scale = dtype.replace("decimal(", "").replace(")", "").split(sep=",")
         return pa.decimal128(precision=int(precision), scale=int(scale))
+    if dtype.startswith("array"):
+        return pa.large_list(athena2pyarrow(dtype=dtype[6:-1]))
+    if dtype.startswith("struct"):
+        return pa.struct([(f.split(":", 1)[0], athena2pyarrow(f.split(":", 1)[1])) for f in dtype[7:-1].split(",")])
+    if dtype.startswith("map"):  # pragma: no cover
+        return pa.map_(athena2pyarrow(dtype[4:-1].split(",", 1)[0]), athena2pyarrow(dtype[4:-1].split(",", 1)[1]))
     raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")  # pragma: no cover
 
 
```
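The new branches lean on fixed-width prefixes: after the `.replace(" ", "")` normalization above, `dtype[6:-1]` strips the `array<`/`>` wrapper, `dtype[7:-1]` strips `struct<`/`>`, and `dtype[4:-1]` strips `map<`/`>`, with the inner type parsed recursively. A standalone sketch of the same idea, trimmed to three scalar types (the names here are illustrative, not the library's):

```python
import pyarrow as pa

def athena2pyarrow_sketch(dtype: str) -> pa.DataType:
    """Mirror of the new recursive branches, reduced to a few scalars."""
    dtype = dtype.lower().replace(" ", "")  # normalization added by this commit
    scalars = {"bigint": pa.int64(), "double": pa.float64(), "string": pa.string()}
    if dtype in scalars:
        return scalars[dtype]
    if dtype.startswith("array"):   # "array<bigint>"[6:-1] -> "bigint"
        return pa.large_list(athena2pyarrow_sketch(dtype[6:-1]))
    if dtype.startswith("struct"):  # "struct<a:bigint,b:double>"[7:-1] -> "a:bigint,b:double"
        return pa.struct(
            [(f.split(":", 1)[0], athena2pyarrow_sketch(f.split(":", 1)[1])) for f in dtype[7:-1].split(",")]
        )
    raise ValueError(f"unsupported in sketch: {dtype}")

print(athena2pyarrow_sketch("array<bigint>"))               # large_list<item: int64>
print(athena2pyarrow_sketch("struct<a:bigint, b:double>"))  # struct<a: int64, b: double>
```

One caveat worth noting: the struct branch splits its body on every comma, so it handles flat structs like the ones exercised in the tests below; a struct whose field types themselves contain commas (e.g. a struct nested inside a struct) would likely need a bracket-aware split.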
```diff
@@ -77,8 +83,6 @@ def athena2pandas(dtype: str) -> str:  # pylint: disable=too-many-branches,too-m
         return "decimal"
     if dtype in ("binary", "varbinary"):
         return "bytes"
-    if dtype == "array":  # pragma: no cover
-        return "list"
     raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")  # pragma: no cover
 
 
```
```diff
@@ -143,9 +147,9 @@ def pyarrow2athena(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branc
     if pa.types.is_list(dtype):
         return f"array<{pyarrow2athena(dtype=dtype.value_type)}>"
     if pa.types.is_struct(dtype):
-        return f"struct<{', '.join([f'{f.name}:{pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
+        return f"struct<{','.join([f'{f.name}:{pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
     if pa.types.is_map(dtype):  # pragma: no cover
-        return f"map<{pyarrow2athena(dtype=dtype.key_type)},{pyarrow2athena(dtype=dtype.item_type)}>"
+        return f"map<{pyarrow2athena(dtype=dtype.key_type)}, {pyarrow2athena(dtype=dtype.item_type)}>"
     if dtype == pa.null():
         raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
     raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
```
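The separator tweaks here look cosmetic but pair with the normalization elsewhere in this commit: struct fields are now joined with a bare `,`, which is exactly the space-free form that `athena2pyarrow` and the `casts` handling normalize to. A quick check, assuming the private `awswrangler._data_types` module keeps these names (it is internal API and may change):

```python
import pyarrow as pa
from awswrangler._data_types import athena2pyarrow, pyarrow2athena  # internal module

t = pa.struct([("a", pa.int64()), ("b", pa.float64())])
print(pyarrow2athena(dtype=t))  # struct<a:bigint,b:double>  (no space after the comma)

# The emitted string feeds straight back through the new athena2pyarrow branches.
print(athena2pyarrow(dtype="struct<a:bigint,b:double>"))  # struct<a: int64, b: double>
```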
```diff
@@ -321,7 +325,7 @@ def athena_types_from_pandas(
     athena_columns_types: Dict[str, str] = {}
     for k, v in pa_columns_types.items():
         if v is None:
-            athena_columns_types[k] = casts[k]
+            athena_columns_types[k] = casts[k].replace(" ", "")
         else:
             athena_columns_types[k] = pyarrow2athena(dtype=v)
     _logger.debug("athena_columns_types: %s", athena_columns_types)
```
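The effect of `.replace(" ", "")` on user-supplied casts: a `dtype` entry written with readable spacing collapses to the same canonical string the library generates itself, which should keep generated and user-supplied type strings byte-identical when schemas are compared. A minimal illustration:

```python
# Both spellings collapse to one canonical Athena type string.
casts = {"c2": "struct<a:bigint, b:double>"}  # user-supplied, with a space
athena_columns_types = {k: v.replace(" ", "") for k, v in casts.items()}
print(athena_columns_types)  # {'c2': 'struct<a:bigint,b:double>'}
```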
```diff
@@ -384,7 +388,12 @@ def athena_types_from_pyarrow_schema(
 def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd.DataFrame:
     """Cast columns in a Pandas DataFrame."""
     for col, athena_type in dtype.items():
-        if col in df.columns:
+        if (
+            (col in df.columns)
+            and (not athena_type.startswith("array"))
+            and (not athena_type.startswith("struct"))
+            and (not athena_type.startswith("map"))
+        ):
             pandas_type: str = athena2pandas(dtype=athena_type)
             if pandas_type == "datetime64":
                 df[col] = pd.to_datetime(df[col])
```
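The widened guard exists because pandas has no native array/struct/map dtypes to cast to; nested casts are instead honored when the Arrow schema is built (via the new `athena2pyarrow` branches above), and only scalar types still pass through pandas-level casting. A small sketch of the filtering, with illustrative names:

```python
# Columns with nested Athena types are left alone here; scalar columns
# still get the pandas-level cast (datetime64, Int64, etc.).
def is_nested(athena_type: str) -> bool:
    return athena_type.startswith(("array", "struct", "map"))

dtype = {"c0": "array<double>", "c1": "bigint", "c2": "struct<a:bigint>"}
pandas_castable = [col for col, t in dtype.items() if not is_nested(t)]
print(pandas_castable)  # ['c1']
```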

testing/test_awswrangler/test_data_lake.py (43 additions, 3 deletions)
```diff
@@ -1295,9 +1295,7 @@ def test_athena_encryption(
     assert len(df2.columns) == 2
 
 
-def test_athena_nested(bucket, database):
-    table = "test_athena_nested"
-    path = f"s3://{bucket}/{table}/"
+def test_athena_nested(path, database, table):
     df = pd.DataFrame(
         {
             "c0": [[1, 2, 3], [4, 5, 6]],
```
```diff
@@ -2142,3 +2140,45 @@ def test_to_parquet_reverse_partitions(database, table, path, partition_cols):
     assert df.c0.sum() == df2.c0.sum()
     assert df.c1.sum() == df2.c1.sum()
     assert df.c2.sum() == df2.c2.sum()
+
+
+def test_to_parquet_nested_append(database, table, path):
+    df = pd.DataFrame(
+        {
+            "c0": [[1, 2, 3], [4, 5, 6]],
+            "c1": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+            "c2": [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]],
+            "c3": [[], [[[[[[[[1]]]]]]]]],
+            "c4": [{"a": 1}, {"a": 1}],
+            "c5": [{"a": {"b": {"c": [1, 2]}}}, {"a": {"b": {"c": [3, 4]}}}],
+        }
+    )
+    paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table)["paths"]
+    wr.s3.wait_objects_exist(paths=paths, use_threads=False)
+    df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {table}", database=database)
+    assert len(df2.index) == 2
+    assert len(df2.columns) == 4
+    paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table)["paths"]
+    wr.s3.wait_objects_exist(paths=paths, use_threads=False)
+    df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {table}", database=database)
+    assert len(df2.index) == 4
+    assert len(df2.columns) == 4
+
+
+def test_to_parquet_nested_cast(database, table, path):
+    df = pd.DataFrame({"c0": [[1, 2, 3], [4, 5, 6]], "c1": [[], []], "c2": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]})
+    paths = wr.s3.to_parquet(
+        df=df,
+        path=path,
+        dataset=True,
+        database=database,
+        table=table,
+        dtype={"c0": "array<double>", "c1": "array<string>", "c2": "struct<a:bigint, b:double>"},
+    )["paths"]
+    wr.s3.wait_objects_exist(paths=paths, use_threads=False)
+    df = pd.DataFrame({"c0": [[1, 2, 3], [4, 5, 6]], "c1": [["a"], ["b"]], "c2": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]})
+    paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table)["paths"]
+    wr.s3.wait_objects_exist(paths=paths, use_threads=False)
+    df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c2 FROM {table}", database=database)
+    assert len(df2.index) == 4
+    assert len(df2.columns) == 2
```
