@@ -16,17 +16,18 @@ class Athena:
1616 def __init__ (self , session ):
1717 self ._session = session
1818 self ._client_athena = session .boto3_session .client (
19- service_name = "athena" , config = session .botocore_config )
19+ service_name = "athena" , config = session .botocore_config
20+ )
2021
2122 def get_query_columns_metadata (self , query_execution_id ):
2223 response = self ._client_athena .get_query_results (
23- QueryExecutionId = query_execution_id , MaxResults = 1 )
24+ QueryExecutionId = query_execution_id , MaxResults = 1
25+ )
2426 col_info = response ["ResultSet" ]["ResultSetMetadata" ]["ColumnInfo" ]
2527 return {x ["Name" ]: x ["Type" ] for x in col_info }
2628
2729 def get_query_dtype (self , query_execution_id ):
28- cols_metadata = self .get_query_columns_metadata (
29- query_execution_id = query_execution_id )
30+ cols_metadata = self .get_query_columns_metadata (query_execution_id = query_execution_id )
3031 logger .debug (f"cols_metadata: { cols_metadata } " )
3132 dtype = {}
3233 parse_timestamps = []
@@ -53,10 +54,11 @@ def create_athena_bucket(self):
5354
5455 :return: Bucket s3 path (E.g. s3://aws-athena-query-results-ACCOUNT-REGION/)
5556 """
56- account_id = (self ._session .boto3_session .client (
57- service_name = "sts" ,
58- config = self ._session .botocore_config ).get_caller_identity ().get (
59- "Account" ))
57+ account_id = (
58+ self ._session .boto3_session .client (
59+ service_name = "sts" , config = self ._session .botocore_config
60+ ).get_caller_identity ().get ("Account" )
61+ )
6062 session_region = self ._session .boto3_session .region_name
6163 s3_output = f"s3://aws-athena-query-results-{ account_id } -{ session_region } /"
6264 s3_resource = self ._session .boto3_session .resource ("s3" )
@@ -82,7 +84,8 @@ def run_query(self, query, database, s3_output=None, workgroup=None):
8284 QueryString = query ,
8385 QueryExecutionContext = {"Database" : database },
8486 ResultConfiguration = {"OutputLocation" : s3_output },
85- WorkGroup = workgroup )
87+ WorkGroup = workgroup
88+ )
8689 return response ["QueryExecutionId" ]
8790
8891 def wait_query (self , query_execution_id ):
@@ -93,24 +96,20 @@ def wait_query(self, query_execution_id):
9396 :return: Query response
9497 """
9598 final_states = ["FAILED" , "SUCCEEDED" , "CANCELLED" ]
96- response = self ._client_athena .get_query_execution (
97- QueryExecutionId = query_execution_id )
99+ response = self ._client_athena .get_query_execution (QueryExecutionId = query_execution_id )
98100 state = response ["QueryExecution" ]["Status" ]["State" ]
99101 while state not in final_states :
100102 sleep (QUERY_WAIT_POLLING_DELAY )
101- response = self ._client_athena .get_query_execution (
102- QueryExecutionId = query_execution_id )
103+ response = self ._client_athena .get_query_execution (QueryExecutionId = query_execution_id )
103104 state = response ["QueryExecution" ]["Status" ]["State" ]
104105 logger .debug (f"state: { state } " )
105106 logger .debug (
106107 f"StateChangeReason: { response ['QueryExecution' ]['Status' ].get ('StateChangeReason' )} "
107108 )
108109 if state == "FAILED" :
109- raise QueryFailed (
110- response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
110+ raise QueryFailed (response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
111111 elif state == "CANCELLED" :
112- raise QueryCancelled (
113- response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
112+ raise QueryCancelled (response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
114113 return response
115114
116115 def repair_table (self , database , table , s3_output = None , workgroup = None ):
@@ -130,17 +129,17 @@ def repair_table(self, database, table, s3_output=None, workgroup=None):
130129 :return: Query execution ID
131130 """
132131 query = f"MSCK REPAIR TABLE { table } ;"
133- query_id = self .run_query (query = query ,
134- database = database ,
135- s3_output = s3_output ,
136- workgroup = workgroup )
132+ query_id = self .run_query (
133+ query = query , database = database , s3_output = s3_output , workgroup = workgroup
134+ )
137135 self .wait_query (query_execution_id = query_id )
138136 return query_id
139137
140138 @staticmethod
141139 def _normalize_name (name ):
142- name = "" .join (c for c in unicodedata .normalize ("NFD" , name )
143- if unicodedata .category (c ) != "Mn" )
140+ name = "" .join (
141+ c for c in unicodedata .normalize ("NFD" , name ) if unicodedata .category (c ) != "Mn"
142+ )
144143 name = name .replace (" " , "_" )
145144 name = name .replace ("-" , "_" )
146145 name = name .replace ("." , "_" )
0 commit comments