22
33import csv
44import logging
5+ import pprint
56import time
67from decimal import Decimal
78from typing import Any , Dict , Iterator , List , Optional , Tuple , Union
@@ -120,19 +121,49 @@ def start_query_execution(
120121 >>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...')
121122
122123 """
124+ session : boto3 .Session = _utils .ensure_session (session = boto3_session )
125+ wg_config : Dict [str , Union [bool , Optional [str ]]] = _get_workgroup_config (session = session , workgroup = workgroup )
126+ return _start_query_execution (
127+ sql = sql ,
128+ wg_config = wg_config ,
129+ database = database ,
130+ s3_output = s3_output ,
131+ workgroup = workgroup ,
132+ encryption = encryption ,
133+ kms_key = kms_key ,
134+ boto3_session = session ,
135+ )
136+
137+
138+ def _start_query_execution (
139+ sql : str ,
140+ wg_config : Dict [str , Union [Optional [bool ], Optional [str ]]],
141+ database : Optional [str ] = None ,
142+ s3_output : Optional [str ] = None ,
143+ workgroup : Optional [str ] = None ,
144+ encryption : Optional [str ] = None ,
145+ kms_key : Optional [str ] = None ,
146+ boto3_session : Optional [boto3 .Session ] = None ,
147+ ) -> str :
123148 args : Dict [str , Any ] = {"QueryString" : sql }
124149 session : boto3 .Session = _utils .ensure_session (session = boto3_session )
125150
126151 # s3_output
127- if s3_output is None : # pragma: no cover
128- s3_output = create_athena_bucket ( boto3_session = session )
129- args [ "ResultConfiguration" ] = { "OutputLocation" : s3_output }
152+ args [ "ResultConfiguration" ] = {
153+ "OutputLocation" : _get_s3_output ( s3_output = s3_output , wg_config = wg_config , boto3_session = session )
154+ }
130155
131156 # encryption
132- if encryption is not None :
133- args ["ResultConfiguration" ]["EncryptionConfiguration" ] = {"EncryptionOption" : encryption }
134- if kms_key is not None :
135- args ["ResultConfiguration" ]["EncryptionConfiguration" ]["KmsKey" ] = kms_key
157+ if wg_config ["enforced" ] is True :
158+ if wg_config ["encryption" ] is not None :
159+ args ["ResultConfiguration" ]["EncryptionConfiguration" ] = {"EncryptionOption" : wg_config ["encryption" ]}
160+ if wg_config ["kms_key" ] is not None :
161+ args ["ResultConfiguration" ]["EncryptionConfiguration" ]["KmsKey" ] = wg_config ["kms_key" ]
162+ else :
163+ if encryption is not None :
164+ args ["ResultConfiguration" ]["EncryptionConfiguration" ] = {"EncryptionOption" : encryption }
165+ if kms_key is not None :
166+ args ["ResultConfiguration" ]["EncryptionConfiguration" ]["KmsKey" ] = kms_key
136167
137168 # database
138169 if database is not None :
@@ -143,10 +174,25 @@ def start_query_execution(
143174 args ["WorkGroup" ] = workgroup
144175
145176 client_athena : boto3 .client = _utils .client (service_name = "athena" , session = session )
177+ _logger .debug ("args: \n %s" , pprint .pformat (args ))
146178 response = client_athena .start_query_execution (** args )
147179 return response ["QueryExecutionId" ]
148180
149181
182+ def _get_s3_output (
183+ s3_output : Optional [str ], wg_config : Dict [str , Union [bool , Optional [str ]]], boto3_session : boto3 .Session
184+ ) -> str :
185+ if s3_output is None :
186+ _s3_output : Optional [str ] = wg_config ["s3_output" ] # type: ignore
187+ if _s3_output is not None :
188+ s3_output = _s3_output
189+ else :
190+ s3_output = create_athena_bucket (boto3_session = boto3_session )
191+ elif wg_config ["enforced" ] is True :
192+ s3_output = wg_config ["s3_output" ] # type: ignore
193+ return s3_output
194+
195+
150196def wait_query (query_execution_id : str , boto3_session : Optional [boto3 .Session ] = None ) -> Dict [str , Any ]:
151197 """Wait for the query end.
152198
@@ -355,12 +401,14 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
355401
356402 Note
357403 ----
358- If `ctas_approach` is True, `chunksize` will return non deterministic chunks sizes,
359- but it still useful to overcome memory limitation.
404+ Valid encryption modes: [None, 'SSE_S3', 'SSE_KMS'].
405+
406+ `P.S. 'CSE_KMS' is not supported.`
360407
361408 Note
362409 ----
363410 Create the default Athena bucket if it doesn't exist and s3_output is None.
411+
364412 (E.g. s3://aws-athena-query-results-ACCOUNT-REGION/)
365413
366414 Note
@@ -403,9 +451,9 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
403451 workgroup : str, optional
404452 Athena workgroup.
405453 encryption : str, optional
406- None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
454+ Valid values: [ None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported .
407455 kms_key : str, optional
408- For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
456+ For SSE-KMS, this is the KMS key ARN or ID.
409457 use_threads : bool
410458 True to enable concurrent requests, False to disable multiple threads.
411459 If enabled os.cpu_count() will be used as the max number of threads.
@@ -424,31 +472,27 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
424472
425473 """
426474 session : boto3 .Session = _utils .ensure_session (session = boto3_session )
427- wg_s3_output , _ , _ = _ensure_workgroup (session = session , workgroup = workgroup )
428- if s3_output is None :
429- if wg_s3_output is None :
430- _s3_output : str = create_athena_bucket (boto3_session = session )
431- else :
432- _s3_output = wg_s3_output
433- else :
434- _s3_output = s3_output
475+ wg_config : Dict [str , Union [bool , Optional [str ]]] = _get_workgroup_config (session = session , workgroup = workgroup )
476+ _s3_output : str = _get_s3_output (s3_output = s3_output , wg_config = wg_config , boto3_session = session )
435477 _s3_output = _s3_output [:- 1 ] if _s3_output [- 1 ] == "/" else _s3_output
436478 name : str = ""
437479 if ctas_approach is True :
438480 name = f"temp_table_{ pa .compat .guid ()} "
439481 path : str = f"{ _s3_output } /{ name } "
482+ ext_location : str = "\n " if wg_config ["enforced" ] is True else f",\n external_location = '{ path } '\n "
440483 sql = (
441484 f"CREATE TABLE { name } \n "
442485 f"WITH(\n "
443486 f" format = 'Parquet',\n "
444- f" parquet_compression = 'SNAPPY', \n "
445- f" external_location = ' { path } ' \n "
487+ f" parquet_compression = 'SNAPPY'"
488+ f"{ ext_location } "
446489 f") AS\n "
447490 f"{ sql } "
448491 )
449492 _logger .debug ("sql: %s" , sql )
450- query_id : str = start_query_execution (
493+ query_id : str = _start_query_execution (
451494 sql = sql ,
495+ wg_config = wg_config ,
452496 database = database ,
453497 s3_output = _s3_output ,
454498 workgroup = workgroup ,
@@ -466,6 +510,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
466510 if ctas_approach is True :
467511 catalog .delete_table_if_exists (database = database , table = name , boto3_session = session )
468512 manifest_path : str = f"{ _s3_output } /tables/{ query_id } -manifest.csv"
513+ _logger .debug ("manifest_path: %s" , manifest_path )
469514 paths : List [str ] = _extract_ctas_manifest_paths (path = manifest_path , boto3_session = session )
470515 chunked : Union [bool , int ] = False if chunksize is None else chunksize
471516 _logger .debug ("chunked: %s" , chunked )
@@ -560,19 +605,27 @@ def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None
560605 return client_athena .get_work_group (WorkGroup = workgroup )
561606
562607
def _get_workgroup_config(
    session: boto3.Session, workgroup: Optional[str] = None
) -> Dict[str, Union[bool, Optional[str]]]:
    """Collect the workgroup settings that influence query execution.

    Parameters
    ----------
    session : boto3.Session
        Session forwarded to ``get_work_group``.
    workgroup : str, optional
        Workgroup name. When None, no lookup is made and defaults are returned.

    Returns
    -------
    Dict[str, Union[bool, Optional[str]]]
        Dictionary with the keys "enforced", "s3_output", "encryption" and "kms_key".
        Without a workgroup: "enforced" is False and the other values are None.

    """
    # Defaults used when no workgroup is given.
    is_enforced: bool = False
    output_location: Optional[str] = None
    encryption_option: Optional[str] = None
    kms_key_id: Optional[str] = None

    if workgroup is not None:
        response: Dict[str, Any] = get_work_group(workgroup=workgroup, boto3_session=session)
        wg_settings: Dict[str, Any] = response["WorkGroup"]["Configuration"]
        is_enforced = wg_settings["EnforceWorkGroupConfiguration"]
        result_settings: Dict[str, Any] = wg_settings["ResultConfiguration"]
        output_location = result_settings.get("OutputLocation")
        encryption_settings: Optional[Dict[str, str]] = result_settings.get("EncryptionConfiguration")
        # The encryption values stay None when the workgroup has no encryption section.
        if encryption_settings is not None:
            encryption_option = encryption_settings.get("EncryptionOption")
            kms_key_id = encryption_settings.get("KmsKey")

    wg_config: Dict[str, Union[bool, Optional[str]]] = {
        "enforced": is_enforced,
        "s3_output": output_location,
        "encryption": encryption_option,
        "kms_key": kms_key_id,
    }
    _logger.debug("wg_config: \n %s", pprint.pformat(wg_config))
    return wg_config
576629
577630
578631def read_sql_table (
@@ -606,12 +659,14 @@ def read_sql_table(
606659
607660 Note
608661 ----
609- If `ctas_approach` is True, `chunksize` will return non deterministic chunks sizes,
610- but it still useful to overcome memory limitation.
662+ Valid encryption modes: [None, 'SSE_S3', 'SSE_KMS'].
663+
664+ `P.S. 'CSE_KMS' is not supported.`
611665
612666 Note
613667 ----
614668 Create the default Athena bucket if it doesn't exist and s3_output is None.
669+
615670 (E.g. s3://aws-athena-query-results-ACCOUNT-REGION/)
616671
617672 Note
0 commit comments