Skip to content

Commit 8c15804

Browse files
authored
Merge pull request #35 from awslabs/spark-to-redshift-date-type
Adding support for Date data type to load dataframes on Redshift
2 parents 7574474 + ce921e1 commit 8c15804

File tree

9 files changed

+499
-320
lines changed

9 files changed

+499
-320
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ session.spark.create_glue_table(dataframe=dataframe,
187187

188188
### General
189189

190-
#### Deleting a bunch of S3 objects (parallel :rocket:)
190+
#### Deleting a bunch of S3 objects (parallel)
191191

192192
```py3
193193
session = awswrangler.Session()

awswrangler/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from awswrangler.glue import Glue # noqa
1111
from awswrangler.redshift import Redshift # noqa
1212
import awswrangler.utils # noqa
13+
import awswrangler.data_types # noqa
1314

1415
if importlib.util.find_spec("pyspark"):
1516
from awswrangler.spark import Spark # noqa

awswrangler/athena.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import logging
33
import ast
44

5-
from awswrangler.exceptions import UnsupportedType, QueryFailed, QueryCancelled
5+
from awswrangler import data_types
6+
from awswrangler.exceptions import QueryFailed, QueryCancelled
67

78
logger = logging.getLogger(__name__)
89

@@ -21,26 +22,6 @@ def get_query_columns_metadata(self, query_execution_id):
2122
col_info = response["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]
2223
return {x["Name"]: x["Type"] for x in col_info}
2324

24-
@staticmethod
25-
def _type_athena2pandas(dtype):
26-
dtype = dtype.lower()
27-
if dtype in ["int", "integer", "bigint", "smallint", "tinyint"]:
28-
return "Int64"
29-
elif dtype in ["float", "double", "real"]:
30-
return "float64"
31-
elif dtype == "boolean":
32-
return "bool"
33-
elif dtype in ["string", "char", "varchar"]:
34-
return "str"
35-
elif dtype == "timestamp":
36-
return "datetime64"
37-
elif dtype == "date":
38-
return "date"
39-
elif dtype == "array":
40-
return "literal_eval"
41-
else:
42-
raise UnsupportedType(f"Unsupported Athena type: {dtype}")
43-
4425
def get_query_dtype(self, query_execution_id):
4526
cols_metadata = self.get_query_columns_metadata(
4627
query_execution_id=query_execution_id)
@@ -49,15 +30,15 @@ def get_query_dtype(self, query_execution_id):
4930
parse_dates = []
5031
converters = {}
5132
for col_name, col_type in cols_metadata.items():
52-
ptype = Athena._type_athena2pandas(dtype=col_type)
53-
if ptype in ["datetime64", "date"]:
33+
pandas_type = data_types.athena2pandas(dtype=col_type)
34+
if pandas_type in ["datetime64", "date"]:
5435
parse_timestamps.append(col_name)
55-
if ptype == "date":
36+
if pandas_type == "date":
5637
parse_dates.append(col_name)
57-
elif ptype == "literal_eval":
38+
elif pandas_type == "literal_eval":
5839
converters[col_name] = ast.literal_eval
5940
else:
60-
dtype[col_name] = ptype
41+
dtype[col_name] = pandas_type
6142
logger.debug(f"dtype: {dtype}")
6243
logger.debug(f"parse_timestamps: {parse_timestamps}")
6344
logger.debug(f"parse_dates: {parse_dates}")

0 commit comments

Comments (0)