Skip to content

Commit a54a578

Browse files
authored
Merge pull request #212 from awslabs/athena-encryption
Revisiting Athena encryption and workgroup
2 parents be0d89e + b4f6a36 commit a54a578

File tree

5 files changed

+192
-51
lines changed

5 files changed

+192
-51
lines changed

awswrangler/athena.py

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import csv
44
import logging
5+
import pprint
56
import time
67
from decimal import Decimal
78
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -120,19 +121,49 @@ def start_query_execution(
120121
>>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...')
121122
122123
"""
124+
session: boto3.Session = _utils.ensure_session(session=boto3_session)
125+
wg_config: Dict[str, Union[bool, Optional[str]]] = _get_workgroup_config(session=session, workgroup=workgroup)
126+
return _start_query_execution(
127+
sql=sql,
128+
wg_config=wg_config,
129+
database=database,
130+
s3_output=s3_output,
131+
workgroup=workgroup,
132+
encryption=encryption,
133+
kms_key=kms_key,
134+
boto3_session=session,
135+
)
136+
137+
138+
def _start_query_execution(
139+
sql: str,
140+
wg_config: Dict[str, Union[Optional[bool], Optional[str]]],
141+
database: Optional[str] = None,
142+
s3_output: Optional[str] = None,
143+
workgroup: Optional[str] = None,
144+
encryption: Optional[str] = None,
145+
kms_key: Optional[str] = None,
146+
boto3_session: Optional[boto3.Session] = None,
147+
) -> str:
123148
args: Dict[str, Any] = {"QueryString": sql}
124149
session: boto3.Session = _utils.ensure_session(session=boto3_session)
125150

126151
# s3_output
127-
if s3_output is None: # pragma: no cover
128-
s3_output = create_athena_bucket(boto3_session=session)
129-
args["ResultConfiguration"] = {"OutputLocation": s3_output}
152+
args["ResultConfiguration"] = {
153+
"OutputLocation": _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=session)
154+
}
130155

131156
# encryption
132-
if encryption is not None:
133-
args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": encryption}
134-
if kms_key is not None:
135-
args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key
157+
if wg_config["enforced"] is True:
158+
if wg_config["encryption"] is not None:
159+
args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": wg_config["encryption"]}
160+
if wg_config["kms_key"] is not None:
161+
args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = wg_config["kms_key"]
162+
else:
163+
if encryption is not None:
164+
args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": encryption}
165+
if kms_key is not None:
166+
args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key
136167

137168
# database
138169
if database is not None:
@@ -143,10 +174,25 @@ def start_query_execution(
143174
args["WorkGroup"] = workgroup
144175

145176
client_athena: boto3.client = _utils.client(service_name="athena", session=session)
177+
_logger.debug("args: \n%s", pprint.pformat(args))
146178
response = client_athena.start_query_execution(**args)
147179
return response["QueryExecutionId"]
148180

149181

182+
def _get_s3_output(
183+
s3_output: Optional[str], wg_config: Dict[str, Union[bool, Optional[str]]], boto3_session: boto3.Session
184+
) -> str:
185+
if s3_output is None:
186+
_s3_output: Optional[str] = wg_config["s3_output"] # type: ignore
187+
if _s3_output is not None:
188+
s3_output = _s3_output
189+
else:
190+
s3_output = create_athena_bucket(boto3_session=boto3_session)
191+
elif wg_config["enforced"] is True:
192+
s3_output = wg_config["s3_output"] # type: ignore
193+
return s3_output
194+
195+
150196
def wait_query(query_execution_id: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]:
151197
"""Wait for the query end.
152198
@@ -355,12 +401,14 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
355401
356402
Note
357403
----
358-
If `ctas_approach` is True, `chunksize` will return non deterministic chunks sizes,
359-
but it still useful to overcome memory limitation.
404+
Valid encryption modes: [None, 'SSE_S3', 'SSE_KMS'].
405+
406+
`P.S. 'CSE_KMS' is not supported.`
360407
361408
Note
362409
----
363410
Create the default Athena bucket if it doesn't exist and s3_output is None.
411+
364412
(E.g. s3://aws-athena-query-results-ACCOUNT-REGION/)
365413
366414
Note
@@ -403,9 +451,9 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
403451
workgroup : str, optional
404452
Athena workgroup.
405453
encryption : str, optional
406-
None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
454+
Valid values: [None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported.
407455
kms_key : str, optional
408-
For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
456+
For SSE-KMS, this is the KMS key ARN or ID.
409457
use_threads : bool
410458
True to enable concurrent requests, False to disable multiple threads.
411459
If enabled os.cpu_count() will be used as the max number of threads.
@@ -424,31 +472,27 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
424472
425473
"""
426474
session: boto3.Session = _utils.ensure_session(session=boto3_session)
427-
wg_s3_output, _, _ = _ensure_workgroup(session=session, workgroup=workgroup)
428-
if s3_output is None:
429-
if wg_s3_output is None:
430-
_s3_output: str = create_athena_bucket(boto3_session=session)
431-
else:
432-
_s3_output = wg_s3_output
433-
else:
434-
_s3_output = s3_output
475+
wg_config: Dict[str, Union[bool, Optional[str]]] = _get_workgroup_config(session=session, workgroup=workgroup)
476+
_s3_output: str = _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=session)
435477
_s3_output = _s3_output[:-1] if _s3_output[-1] == "/" else _s3_output
436478
name: str = ""
437479
if ctas_approach is True:
438480
name = f"temp_table_{pa.compat.guid()}"
439481
path: str = f"{_s3_output}/{name}"
482+
ext_location: str = "\n" if wg_config["enforced"] is True else f",\n external_location = '{path}'\n"
440483
sql = (
441484
f"CREATE TABLE {name}\n"
442485
f"WITH(\n"
443486
f" format = 'Parquet',\n"
444-
f" parquet_compression = 'SNAPPY',\n"
445-
f" external_location = '{path}'\n"
487+
f" parquet_compression = 'SNAPPY'"
488+
f"{ext_location}"
446489
f") AS\n"
447490
f"{sql}"
448491
)
449492
_logger.debug("sql: %s", sql)
450-
query_id: str = start_query_execution(
493+
query_id: str = _start_query_execution(
451494
sql=sql,
495+
wg_config=wg_config,
452496
database=database,
453497
s3_output=_s3_output,
454498
workgroup=workgroup,
@@ -466,6 +510,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
466510
if ctas_approach is True:
467511
catalog.delete_table_if_exists(database=database, table=name, boto3_session=session)
468512
manifest_path: str = f"{_s3_output}/tables/{query_id}-manifest.csv"
513+
_logger.debug("manifest_path: %s", manifest_path)
469514
paths: List[str] = _extract_ctas_manifest_paths(path=manifest_path, boto3_session=session)
470515
chunked: Union[bool, int] = False if chunksize is None else chunksize
471516
_logger.debug("chunked: %s", chunked)
@@ -560,19 +605,27 @@ def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None
560605
return client_athena.get_work_group(WorkGroup=workgroup)
561606

562607

563-
def _ensure_workgroup(
608+
def _get_workgroup_config(
564609
session: boto3.Session, workgroup: Optional[str] = None
565-
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
610+
) -> Dict[str, Union[bool, Optional[str]]]:
566611
if workgroup is not None:
567612
res: Dict[str, Any] = get_work_group(workgroup=workgroup, boto3_session=session)
613+
enforced: bool = res["WorkGroup"]["Configuration"]["EnforceWorkGroupConfiguration"]
568614
config: Dict[str, Any] = res["WorkGroup"]["Configuration"]["ResultConfiguration"]
569615
wg_s3_output: Optional[str] = config.get("OutputLocation")
570616
encrypt_config: Optional[Dict[str, str]] = config.get("EncryptionConfiguration")
571617
wg_encryption: Optional[str] = None if encrypt_config is None else encrypt_config.get("EncryptionOption")
572618
wg_kms_key: Optional[str] = None if encrypt_config is None else encrypt_config.get("KmsKey")
573619
else:
574-
wg_s3_output, wg_encryption, wg_kms_key = None, None, None
575-
return wg_s3_output, wg_encryption, wg_kms_key
620+
enforced, wg_s3_output, wg_encryption, wg_kms_key = False, None, None, None
621+
wg_config: Dict[str, Union[bool, Optional[str]]] = {
622+
"enforced": enforced,
623+
"s3_output": wg_s3_output,
624+
"encryption": wg_encryption,
625+
"kms_key": wg_kms_key,
626+
}
627+
_logger.debug("wg_config: \n%s", pprint.pformat(wg_config))
628+
return wg_config
576629

577630

578631
def read_sql_table(
@@ -606,12 +659,14 @@ def read_sql_table(
606659
607660
Note
608661
----
609-
If `ctas_approach` is True, `chunksize` will return non deterministic chunks sizes,
610-
but it still useful to overcome memory limitation.
662+
Valid encryption modes: [None, 'SSE_S3', 'SSE_KMS'].
663+
664+
`P.S. 'CSE_KMS' is not supported.`
611665
612666
Note
613667
----
614668
Create the default Athena bucket if it doesn't exist and s3_output is None.
669+
615670
(E.g. s3://aws-athena-query-results-ACCOUNT-REGION/)
616671
617672
Note

awswrangler/emr.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""EMR (Elastic Map Reduce) module."""
22
# pylint: disable=line-too-long
33

4-
import json
54
import logging
5+
import pprint
66
from typing import Any, Dict, List, Optional, Union
77

88
import boto3 # type: ignore
@@ -364,7 +364,7 @@ def _build_cluster_args(**pars): # pylint: disable=too-many-branches,too-many-s
364364
if pars["tags"] is not None:
365365
args["Tags"] = [{"Key": k, "Value": v} for k, v in pars["tags"].items()]
366366

367-
_logger.info("args: \n%s", json.dumps(args, default=str, indent=4))
367+
_logger.debug("args: \n%s", pprint.pformat(args))
368368
return args
369369

370370

@@ -665,7 +665,7 @@ def create_cluster( # pylint: disable=too-many-arguments,too-many-locals,unused
665665
args: Dict[str, Any] = _build_cluster_args(**locals())
666666
client_emr: boto3.client = _utils.client(service_name="emr", session=boto3_session)
667667
response: Dict[str, Any] = client_emr.run_job_flow(**args)
668-
_logger.debug("response: \n%s", json.dumps(response, default=str, indent=4))
668+
_logger.debug("response: \n%s", pprint.pformat(response))
669669
return response["JobFlowId"]
670670

671671

@@ -696,7 +696,7 @@ def get_cluster_state(cluster_id: str, boto3_session: Optional[boto3.Session] =
696696
"""
697697
client_emr: boto3.client = _utils.client(service_name="emr", session=boto3_session)
698698
response: Dict[str, Any] = client_emr.describe_cluster(ClusterId=cluster_id)
699-
_logger.debug("response: \n%s", json.dumps(response, default=str, indent=4))
699+
_logger.debug("response: \n%s", pprint.pformat(response))
700700
return response["Cluster"]["Status"]["State"]
701701

702702

@@ -723,7 +723,7 @@ def terminate_cluster(cluster_id: str, boto3_session: Optional[boto3.Session] =
723723
"""
724724
client_emr: boto3.client = _utils.client(service_name="emr", session=boto3_session)
725725
response: Dict[str, Any] = client_emr.terminate_job_flows(JobFlowIds=[cluster_id])
726-
_logger.debug("response: \n%s", json.dumps(response, default=str, indent=4))
726+
_logger.debug("response: \n%s", pprint.pformat(response))
727727

728728

729729
def submit_steps(
@@ -755,7 +755,7 @@ def submit_steps(
755755
"""
756756
client_emr: boto3.client = _utils.client(service_name="emr", session=boto3_session)
757757
response: Dict[str, Any] = client_emr.add_job_flow_steps(JobFlowId=cluster_id, Steps=steps)
758-
_logger.debug("response: \n%s", json.dumps(response, default=str, indent=4))
758+
_logger.debug("response: \n%s", pprint.pformat(response))
759759
return response["StepIds"]
760760

761761

@@ -807,7 +807,7 @@ def submit_step(
807807
)
808808
client_emr: boto3.client = _utils.client(service_name="emr", session=session)
809809
response: Dict[str, Any] = client_emr.add_job_flow_steps(JobFlowId=cluster_id, Steps=[step])
810-
_logger.debug("response: \n%s", json.dumps(response, default=str, indent=4))
810+
_logger.debug("response: \n%s", pprint.pformat(response))
811811
return response["StepIds"][0]
812812

813813

@@ -898,7 +898,7 @@ def get_step_state(cluster_id: str, step_id: str, boto3_session: Optional[boto3.
898898
"""
899899
client_emr: boto3.client = _utils.client(service_name="emr", session=boto3_session)
900900
response: Dict[str, Any] = client_emr.describe_step(ClusterId=cluster_id, StepId=step_id)
901-
_logger.debug("response: \n%s", json.dumps(response, default=str, indent=4))
901+
_logger.debug("response: \n%s", pprint.pformat(response))
902902
return response["Step"]["Status"]["State"]
903903

904904

@@ -942,7 +942,7 @@ def submit_ecr_credentials_refresh(
942942
)
943943
client_emr: boto3.client = _utils.client(service_name="emr", session=session)
944944
response: Dict[str, Any] = client_emr.add_job_flow_steps(JobFlowId=cluster_id, Steps=[step])
945-
_logger.debug("response: \n%s", json.dumps(response, default=str, indent=4))
945+
_logger.debug("response: \n%s", pprint.pformat(response))
946946
return response["StepIds"][0]
947947

948948

0 commit comments

Comments
 (0)