Skip to content

Commit dad5e07

Browse files
committed
Handling parquet tinyint on Redshift load. #400
1 parent e47cdae commit dad5e07

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

awswrangler/_data_types.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,8 @@ def athena2redshift( # pylint: disable=too-many-branches,too-many-return-statem
9393
) -> str:
9494
"""Athena to Redshift data types conversion."""
9595
dtype = dtype.lower()
96+
if dtype == "tinyint":
97+
return "SMALLINT"
9698
if dtype == "smallint":
9799
return "SMALLINT"
98100
if dtype in ("int", "integer"):

tests/test_db.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -660,10 +660,11 @@ def test_redshift_copy_unload_kms(
660660
def test_redshift_copy_extras(path, redshift_table, databases_parameters, use_threads, parquet_infer_sampling):
661661
df = pd.DataFrame(
662662
{
663-
"int16": [1, None, 2],
664-
"int32": [1, None, 2],
665-
"int64": [1, None, 2],
666-
"float": [0.0, None, 1.1],
663+
"int8": [-1, None, 2],
664+
"int16": [-1, None, 2],
665+
"int32": [-1, None, 2],
666+
"int64": [-1, None, 2],
667+
"float": [0.0, None, -1.1],
667668
"double": [0.0, None, 1.1],
668669
"decimal": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))],
669670
"string": ["foo", None, "boo"],
@@ -672,6 +673,7 @@ def test_redshift_copy_extras(path, redshift_table, databases_parameters, use_th
672673
"bool": [True, None, False],
673674
}
674675
)
676+
df["int8"] = df["int8"].astype("Int8")
675677
df["int16"] = df["int16"].astype("Int16")
676678
df["int32"] = df["int32"].astype("Int32")
677679
df["int64"] = df["int64"].astype("Int64")
@@ -698,6 +700,7 @@ def test_redshift_copy_extras(path, redshift_table, databases_parameters, use_th
698700
df2 = wr.db.read_sql_table(schema="public", table=redshift_table, con=engine)
699701
assert len(df.columns) == len(df2.columns)
700702
assert len(df.index) * num == len(df2.index)
703+
assert df.int8.sum() * num == df2.int8.sum()
701704
assert df.int16.sum() * num == df2.int16.sum()
702705
assert df.int32.sum() * num == df2.int32.sum()
703706
assert df.int64.sum() * num == df2.int64.sum()

0 commit comments

Comments
 (0)