1+ from typing import Dict , List , Tuple , Optional , Any , Iterator
12from time import sleep
23import logging
34import ast
45import re
56import unicodedata
67
7- from awswrangler import data_types
8+ from awswrangler . data_types import athena2python , athena2pandas
89from awswrangler .exceptions import QueryFailed , QueryCancelled
910
1011logger = logging .getLogger (__name__ )
@@ -30,7 +31,7 @@ def get_query_dtype(self, query_execution_id):
3031 parse_dates = []
3132 converters = {}
3233 for col_name , col_type in cols_metadata .items ():
33- pandas_type = data_types . athena2pandas (dtype = col_type )
34+ pandas_type = athena2pandas (dtype = col_type )
3435 if pandas_type in ["datetime64" , "date" ]:
3536 parse_timestamps .append (col_name )
3637 if pandas_type == "date" :
@@ -122,6 +123,58 @@ def repair_table(self, database, table, s3_output=None, workgroup=None):
122123 self .wait_query (query_execution_id = query_id )
123124 return query_id
124125
126+ @staticmethod
127+ def _rows2row (rows : List [Dict [str , List [Dict [str , str ]]]],
128+ python_types : List [Tuple [str , Optional [type ]]]) -> Iterator [Dict [str , Any ]]:
129+ for row in rows :
130+ vals_varchar : List [Optional [str ]] = [x ["VarCharValue" ] if x else None for x in row ["Data" ]]
131+ data : Dict [str , Any ] = {}
132+ for (name , ptype ), val in zip (python_types , vals_varchar ):
133+ if ptype is not None :
134+ data [name ] = ptype (val )
135+ else :
136+ data [name ] = None
137+ yield data
138+
139+ def get_results (self , query_execution_id : str ) -> Iterator [Dict [str , Any ]]:
140+ """
141+ Get a query results and return a list of rows
142+ :param query_execution_id: Query execution ID
143+ :return: Iterator os lists
144+ """
145+ res : Dict = self ._client_athena .get_query_results (QueryExecutionId = query_execution_id )
146+ cols_info : List [Dict ] = res ["ResultSet" ]["ResultSetMetadata" ]["ColumnInfo" ]
147+ athena_types : List [Tuple [str , str ]] = [(x ["Label" ], x ["Type" ]) for x in cols_info ]
148+ logger .info (f"athena_types: { athena_types } " )
149+ python_types : List [Tuple [str , Optional [type ]]] = [(n , athena2python (dtype = t )) for n , t in athena_types ]
150+ logger .info (f"python_types: { python_types } " )
151+ rows : List [Dict [str , List [Dict [str , str ]]]] = res ["ResultSet" ]["Rows" ][1 :]
152+ for row in Athena ._rows2row (rows = rows , python_types = python_types ):
153+ yield row
154+ next_token : Optional [str ] = res .get ("NextToken" )
155+ while next_token is not None :
156+ logger .info (f"next_token: { next_token } " )
157+ res = self ._client_athena .get_query_results (QueryExecutionId = query_execution_id , NextToken = next_token )
158+ rows = res ["ResultSet" ]["Rows" ]
159+ for row in Athena ._rows2row (rows = rows , python_types = python_types ):
160+ yield row
161+ next_token = res .get ("NextToken" )
162+
163+ def query (self , query : str , database : str , s3_output : str = None ,
164+ workgroup : str = None ) -> Iterator [Dict [str , Any ]]:
165+ """
166+ Run a SQL Query against AWS Athena and return the result as a Iterator of lists
167+
168+ :param query: SQL query
169+ :param database: Glue database name
170+ :param s3_output: AWS S3 path
171+ :param workgroup: Athena workgroup (By default uses de Session() workgroup)
172+ :return: Query execution ID
173+ """
174+ query_id : str = self .run_query (query = query , database = database , s3_output = s3_output , workgroup = workgroup )
175+ self .wait_query (query_execution_id = query_id )
176+ return self .get_results (query_execution_id = query_id )
177+
125178 @staticmethod
126179 def _normalize_name (name ):
127180 name = "" .join (c for c in unicodedata .normalize ("NFD" , name ) if unicodedata .category (c ) != "Mn" )
0 commit comments