22import logging
33import ast
44
5- from awswrangler .exceptions import UnsupportedType , QueryFailed , QueryCancelled
5+ from awswrangler import data_types
6+ from awswrangler .exceptions import QueryFailed , QueryCancelled
67
78logger = logging .getLogger (__name__ )
89
@@ -21,26 +22,6 @@ def get_query_columns_metadata(self, query_execution_id):
2122 col_info = response ["ResultSet" ]["ResultSetMetadata" ]["ColumnInfo" ]
2223 return {x ["Name" ]: x ["Type" ] for x in col_info }
2324
24- @staticmethod
25- def _type_athena2pandas (dtype ):
26- dtype = dtype .lower ()
27- if dtype in ["int" , "integer" , "bigint" , "smallint" , "tinyint" ]:
28- return "Int64"
29- elif dtype in ["float" , "double" , "real" ]:
30- return "float64"
31- elif dtype == "boolean" :
32- return "bool"
33- elif dtype in ["string" , "char" , "varchar" ]:
34- return "str"
35- elif dtype == "timestamp" :
36- return "datetime64"
37- elif dtype == "date" :
38- return "date"
39- elif dtype == "array" :
40- return "literal_eval"
41- else :
42- raise UnsupportedType (f"Unsupported Athena type: { dtype } " )
43-
4425 def get_query_dtype (self , query_execution_id ):
4526 cols_metadata = self .get_query_columns_metadata (
4627 query_execution_id = query_execution_id )
@@ -49,15 +30,15 @@ def get_query_dtype(self, query_execution_id):
4930 parse_dates = []
5031 converters = {}
5132 for col_name , col_type in cols_metadata .items ():
52- ptype = Athena . _type_athena2pandas (dtype = col_type )
53- if ptype in ["datetime64" , "date" ]:
33+ pandas_type = data_types . athena2pandas (dtype = col_type )
34+ if pandas_type in ["datetime64" , "date" ]:
5435 parse_timestamps .append (col_name )
55- if ptype == "date" :
36+ if pandas_type == "date" :
5637 parse_dates .append (col_name )
57- elif ptype == "literal_eval" :
38+ elif pandas_type == "literal_eval" :
5839 converters [col_name ] = ast .literal_eval
5940 else :
60- dtype [col_name ] = ptype
41+ dtype [col_name ] = pandas_type
6142 logger .debug (f"dtype: { dtype } " )
6243 logger .debug (f"parse_timestamps: { parse_timestamps } " )
6344 logger .debug (f"parse_dates: { parse_dates } " )
0 commit comments