Skip to content

Commit cf42d6b

Browse files
committed
Fix nested struct append.
1 parent 0563164 commit cf42d6b

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed

awswrangler/_data_types.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import re
66
from decimal import Decimal
7-
from typing import Any, Callable, Dict, List, Match, Optional, Sequence, Tuple, Union
7+
from typing import Any, Callable, Dict, Iterator, List, Match, Optional, Sequence, Tuple, Union
88

99
import numpy as np
1010
import pandas as pd
@@ -189,9 +189,35 @@ def pyarrow2timestream(dtype: pa.DataType) -> str: # pylint: disable=too-many-b
189189
raise exceptions.UnsupportedType(f"Unsupported Amazon Timestream measure type: {dtype}")
190190

191191

192+
def _split_fields(s: str) -> Iterator[str]:
193+
counter: int = 0
194+
last: int = 0
195+
for i, x in enumerate(s):
196+
if x == "<":
197+
counter += 1
198+
elif x == ">":
199+
counter -= 1
200+
elif x == "," and counter == 0:
201+
yield s[last:i]
202+
last = i + 1
203+
yield s[last:]
204+
205+
206+
def _split_struct(s: str) -> List[str]:
    """Return the top-level fields of a ``struct`` type body as a list."""
    return list(_split_fields(s))
208+
209+
210+
def _split_map(s: str) -> List[str]:
    """Split a ``map`` type body into its key and value type strings.

    Raises
    ------
    RuntimeError
        If the body does not contain exactly two top-level fields.
    """
    key_and_value: List[str] = list(_split_fields(s=s))
    if len(key_and_value) != 2:
        raise RuntimeError(f"Invalid map fields: {s}")
    return key_and_value
215+
216+
192217
def athena2pyarrow(dtype: str) -> pa.DataType: # pylint: disable=too-many-return-statements
193218
"""Athena to PyArrow data types conversion."""
194219
dtype = dtype.lower().replace(" ", "")
220+
print(dtype)
195221
if dtype == "tinyint":
196222
return pa.int8()
197223
if dtype == "smallint":
@@ -220,9 +246,10 @@ def athena2pyarrow(dtype: str) -> pa.DataType: # pylint: disable=too-many-retur
220246
if dtype.startswith("array") is True:
221247
return pa.list_(value_type=athena2pyarrow(dtype=dtype[6:-1]), list_size=-1)
222248
if dtype.startswith("struct") is True:
223-
return pa.struct([(f.split(":", 1)[0], athena2pyarrow(f.split(":", 1)[1])) for f in dtype[7:-1].split(",")])
249+
return pa.struct([(f.split(":", 1)[0], athena2pyarrow(f.split(":", 1)[1])) for f in _split_struct(dtype[7:-1])])
224250
if dtype.startswith("map") is True:
225-
return pa.map_(athena2pyarrow(dtype[4:-1].split(",", 1)[0]), athena2pyarrow(dtype[4:-1].split(",", 1)[1]))
251+
parts: List[str] = _split_map(s=dtype[4:-1])
252+
return pa.map_(athena2pyarrow(parts[0]), athena2pyarrow(parts[1]))
226253
raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")
227254

228255

@@ -491,6 +518,7 @@ def pyarrow_schema_from_pandas(
491518
) -> pa.Schema:
492519
"""Extract the related Pyarrow Schema from any Pandas DataFrame."""
493520
casts: Dict[str, str] = {} if dtype is None else dtype
521+
_logger.debug("casts: %s", casts)
494522
ignore: List[str] = [] if ignore_cols is None else ignore_cols
495523
ignore_plus = ignore + list(casts.keys())
496524
columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(

tests/test_athena_parquet.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111

1212
import awswrangler as wr
13+
from awswrangler._data_types import _split_fields
1314

1415
from ._utils import ensure_data_types, get_df, get_df_cast, get_df_list
1516

@@ -674,3 +675,35 @@ def test_cast_decimal(path, glue_table, glue_database):
674675
assert df2["c1"].iloc[0] == Decimal((0, (1, 0, 0, 1), -1))
675676
assert df2["c2"].iloc[0] == Decimal((0, (1, 0, 0, 1), -1))
676677
assert df2["c3"].iloc[0] == "100.1"
678+
679+
680+
def test_splits():
    """Check that _split_fields only splits on commas at nesting depth zero."""
    cases = [
        (
            "a:struct<id:string,name:string>,b:struct<id:string,name:string>",
            ["a:struct<id:string,name:string>", "b:struct<id:string,name:string>"],
        ),
        (
            "a:struct<a:struct<id:string,name:string>,b:struct<id:string,name:string>>,b:struct<a:struct<id:string,name:string>,b:struct<id:string,name:string>>",  # noqa
            [
                "a:struct<a:struct<id:string,name:string>,b:struct<id:string,name:string>>",
                "b:struct<a:struct<id:string,name:string>,b:struct<id:string,name:string>>",
            ],
        ),
        (
            "a:struct<id:string,name:string>,b:struct<id:string,name:string>,c:struct<id:string,name:string>,d:struct<id:string,name:string>",  # noqa
            [
                "a:struct<id:string,name:string>",
                "b:struct<id:string,name:string>",
                "c:struct<id:string,name:string>",
                "d:struct<id:string,name:string>",
            ],
        ),
    ]
    for raw, expected in cases:
        assert list(_split_fields(raw)) == expected
695+
696+
697+
def test_to_parquet_nested_structs(glue_database, glue_table, path):
    """Write a struct-of-structs column twice and verify the append round-trips through Athena."""
    # Single row whose c1 column is a list of nested structs.
    df = pd.DataFrame(
        {
            "c0": [1],
            "c1": [[{"a": {"id": "0", "name": "foo", "amount": 1}, "b": {"id": "1", "name": "boo", "amount": 2}}]],
        }
    )
    # First round creates the table; second round appends the same row.
    for expected_rows in (1, 2):
        wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table)
        result = wr.athena.read_sql_query(sql=f"SELECT * FROM {glue_table}", database=glue_database)
        assert result.shape == (expected_rows, 2)

0 commit comments

Comments
 (0)