1- from typing import Dict , List , Tuple , Optional , Any , Iterator
1+ from typing import Dict , List , Tuple , Optional , Any , Iterator , Union
22from time import sleep
33import logging
44import re
@@ -41,25 +41,53 @@ def create_athena_bucket(self):
4141 s3_resource .Bucket (s3_output )
4242 return s3_output
4343
44- def run_query (self , query , database , s3_output = None , workgroup = None ) :
44+ def run_query (self , query : str , database : Optional [ str ] = None , s3_output : Optional [ str ] = None , workgroup : Optional [ str ] = None , encryption : Optional [ str ] = None , kms_key : Optional [ str ] = None ) -> str :
4545 """
4646 Run a SQL Query against AWS Athena
47+ P.S All default values will be inherited from the Session()
4748
4849 :param query: SQL query
4950 :param database: AWS Glue/Athena database name
5051 :param s3_output: AWS S3 path
5152 :param workgroup: Athena workgroup (By default uses de Session() workgroup)
53+ :param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
54+ :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
5255 :return: Query execution ID
5356 """
57+ args : Dict [str , Union [str , Dict [str , Union [str , Dict [str , str ]]]]] = {"QueryString" : query }
58+
59+ # s3_output
5460 if s3_output is None :
55- s3_output = self .create_athena_bucket ()
56- if workgroup is None :
57- workgroup = self ._session .athena_workgroup
58- logger .debug (f"Workgroup: { workgroup } " )
59- response = self ._client_athena .start_query_execution (QueryString = query ,
60- QueryExecutionContext = {"Database" : database },
61- ResultConfiguration = {"OutputLocation" : s3_output },
62- WorkGroup = workgroup )
61+ if self ._session .athena_s3_output is not None :
62+ s3_output = self ._session .athena_s3_output
63+ else :
64+ s3_output = self .create_athena_bucket ()
65+ args ["ResultConfiguration" ] = {"OutputLocation" : s3_output }
66+
67+ # encryption
68+ if encryption is not None :
69+ args ["ResultConfiguration" ]["EncryptionConfiguration" ] = {"EncryptionOption" : encryption }
70+ if kms_key is not None :
71+ args ["ResultConfiguration" ]["EncryptionConfiguration" ]["KmsKey" ] = kms_key
72+ elif self ._session .athena_encryption is not None :
73+ args ["ResultConfiguration" ]["EncryptionConfiguration" ] = {"EncryptionOption" : self ._session .athena_encryption }
74+ if self ._session .athena_kms_key is not None :
75+ args ["ResultConfiguration" ]["EncryptionConfiguration" ]["KmsKey" ] = self ._session .athena_kms_key
76+
77+ # database
78+ if database is not None :
79+ args ["QueryExecutionContext" ] = {"Database" : database }
80+ elif self ._session .athena_database is not None :
81+ args ["QueryExecutionContext" ] = {"Database" : self ._session .athena_database }
82+
83+ # workgroup
84+ if workgroup is not None :
85+ args ["WorkGroup" ] = workgroup
86+ elif self ._session .athena_workgroup is not None :
87+ args ["WorkGroup" ] = self ._session .athena_workgroup
88+
89+ logger .debug (f"args: { args } " )
90+ response = self ._client_athena .start_query_execution (** args )
6391 return response ["QueryExecutionId" ]
6492
6593 def wait_query (self , query_execution_id ):
@@ -84,7 +112,7 @@ def wait_query(self, query_execution_id):
84112 raise QueryCancelled (response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
85113 return response
86114
87- def repair_table (self , database , table , s3_output = None , workgroup = None ):
115+ def repair_table (self , table : str , database : Optional [ str ] = None , s3_output : Optional [ str ] = None , workgroup : Optional [ str ] = None , encryption : Optional [ str ] = None , kms_key : Optional [ str ] = None ):
88116 """
89117 Hive's metastore consistency check
90118 "MSCK REPAIR TABLE table;"
@@ -93,15 +121,18 @@ def repair_table(self, database, table, s3_output=None, workgroup=None):
93121 It is possible it will take some time to add all partitions.
94122 If this operation times out, it will be in an incomplete state
95123 where only a few partitions are added to the catalog.
124+ P.S All default values will be inherited from the Session()
96125
97126 :param database: Glue database name
98127 :param table: Glue table name
99128 :param s3_output: AWS S3 path
100129 :param workgroup: Athena workgroup (By default uses de Session() workgroup)
130+ :param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
131+ :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
101132 :return: Query execution ID
102133 """
103134 query = f"MSCK REPAIR TABLE { table } ;"
104- query_id = self .run_query (query = query , database = database , s3_output = s3_output , workgroup = workgroup )
135+ query_id = self .run_query (query = query , database = database , s3_output = s3_output , workgroup = workgroup , encryption = encryption , kms_key = kms_key )
105136 self .wait_query (query_execution_id = query_id )
106137 return query_id
107138
@@ -142,18 +173,20 @@ def get_results(self, query_execution_id: str) -> Iterator[Dict[str, Any]]:
142173 yield row
143174 next_token = res .get ("NextToken" )
144175
145- def query (self , query : str , database : str , s3_output : str = None ,
146- workgroup : str = None ) -> Iterator [Dict [str , Any ]]:
176+ def query (self , query : str , database : Optional [str ] = None , s3_output : Optional [str ] = None , workgroup : Optional [str ] = None , encryption : Optional [str ] = None , kms_key : Optional [str ] = None ) -> Iterator [Dict [str , Any ]]:
147177 """
148178 Run a SQL Query against AWS Athena and return the result as a Iterator of lists
179+ P.S All default values will be inherited from the Session()
149180
150181 :param query: SQL query
151182 :param database: Glue database name
152183 :param s3_output: AWS S3 path
153184 :param workgroup: Athena workgroup (By default uses de Session() workgroup)
185+ :param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
186+ :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
154187 :return: Query execution ID
155188 """
156- query_id : str = self .run_query (query = query , database = database , s3_output = s3_output , workgroup = workgroup )
189+ query_id : str = self .run_query (query = query , database = database , s3_output = s3_output , workgroup = workgroup , encryption = encryption , kms_key = kms_key )
157190 self .wait_query (query_execution_id = query_id )
158191 return self .get_results (query_execution_id = query_id )
159192
0 commit comments