
Commit 75649ae

Fixing casting with char and varchar lengths
1 parent 9c4a4d3 commit 75649ae

2 files changed: +57 -2 lines changed

awswrangler/_data_types.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ def athena2pyarrow(dtype: str) -> pa.DataType:  # pylint: disable=too-many-retur
         return pa.float64()
     if dtype == "boolean":
         return pa.bool_()
-    if dtype in ("string", "char", "varchar", "array", "row", "map"):
+    if (dtype == "string") or dtype.startswith("char") or dtype.startswith("varchar"):
         return pa.string()
     if dtype == "timestamp":
         return pa.timestamp(unit="ns")
@@ -66,7 +66,7 @@ def athena2pandas(dtype: str) -> str:  # pylint: disable=too-many-branches,too-m
         return "float64"
     if dtype == "boolean":
         return "boolean"
-    if dtype in ("string", "char", "varchar"):
+    if (dtype == "string") or dtype.startswith("char") or dtype.startswith("varchar"):
         return "string"
     if dtype in ("timestamp", "timestamp with time zone"):
         return "datetime64"

testing/test_awswrangler/test_data_lake.py

Lines changed: 55 additions & 0 deletions
@@ -48,6 +48,21 @@ def kms_key(cloudformation_outputs):
     yield cloudformation_outputs["KmsKeyArn"]
 
 
+@pytest.fixture(scope="module")
+def external_schema(cloudformation_outputs, database):
+    region = cloudformation_outputs.get("Region")
+    sql = f"""
+    CREATE EXTERNAL SCHEMA IF NOT EXISTS aws_data_wrangler_external FROM data catalog
+    DATABASE '{database}'
+    IAM_ROLE '{cloudformation_outputs["RedshiftRole"]}'
+    REGION '{region}';
+    """
+    engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-redshift")
+    with engine.connect() as con:
+        con.execute(sql)
+    yield "aws_data_wrangler_external"
+
+
 @pytest.fixture(scope="module")
 def workgroup0(bucket):
     wkg_name = "awswrangler_test_0"
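
For context: the new external_schema fixture registers the test's Glue database as a Redshift Spectrum external schema, so tables cataloged in Glue become queryable from the Redshift cluster. A sketch of what that enables, assuming wr.db.read_sql_query is available alongside the read_sql_table call used in the test below (names mirror the fixture and the test):

    import awswrangler as wr

    # Same connection alias the fixture uses; the table is created by the test below.
    engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
    df = wr.db.read_sql_query(
        sql="SELECT * FROM aws_data_wrangler_external.test_parquet_char_length",
        con=engine,
    )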
@@ -957,3 +972,43 @@ def test_csv_compress(bucket, compression):
     for df3 in dfs:
         assert len(df3.columns) == 10
     wr.s3.delete_objects(path=path)
+
+
+def test_parquet_char_length(bucket, database, external_schema):
+    path = f"s3://{bucket}/test_parquet_char_length/"
+    table = "test_parquet_char_length"
+
+    df = pd.DataFrame({
+        "id": [1, 2],
+        "cchar": ["foo", "boo"],
+        "date": [datetime.date(2020, 1, 1), datetime.date(2020, 1, 2)]
+    })
+    wr.s3.to_parquet(
+        df=df,
+        path=path,
+        dataset=True,
+        database=database,
+        table=table,
+        mode="overwrite",
+        partition_cols=["date"],
+        dtype={'cchar': 'char(3)'}
+    )
+
+    df2 = wr.s3.read_parquet(path, dataset=True)
+    assert len(df2.index) == 2
+    assert len(df2.columns) == 3
+    assert df2.id.sum() == 3
+
+    df2 = wr.athena.read_sql_table(table=table, database=database)
+    assert len(df2.index) == 2
+    assert len(df2.columns) == 3
+    assert df2.id.sum() == 3
+
+    engine = wr.catalog.get_engine("aws-data-wrangler-redshift")
+    df2 = wr.db.read_sql_table(con=engine, table=table, schema=external_schema)
+    assert len(df2.index) == 2
+    assert len(df2.columns) == 3
+    assert df2.id.sum() == 3
+
+    wr.s3.delete_objects(path=path)
+    assert wr.catalog.delete_table_if_exists(database=database, table=table) is True
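
A condensed, user-facing sketch of the round trip the test covers: write a dataset with an explicit char(N) column type, then read it back through Athena. Bucket and database names are placeholders; before this commit the read step could not map "char(3)" back to an Arrow/pandas type.

    import pandas as pd
    import awswrangler as wr

    df = pd.DataFrame({"id": [1, 2], "cchar": ["foo", "boo"]})
    wr.s3.to_parquet(
        df=df,
        path="s3://my-bucket/char_demo/",  # placeholder bucket
        dataset=True,
        database="my_glue_db",             # placeholder Glue database
        table="char_demo",
        mode="overwrite",
        dtype={"cchar": "char(3)"},        # Glue/Athena type with a length
    )
    df2 = wr.athena.read_sql_table(table="char_demo", database="my_glue_db")
    print(df2.dtypes)  # "cchar" comes back as a pandas string column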
