Skip to content

Commit f100a78

Browse files
committed
Add more Athena defaults to Session()
1 parent 7279ff1 commit f100a78

File tree

4 files changed

+179
-62
lines changed

4 files changed

+179
-62
lines changed

awswrangler/athena.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Tuple, Optional, Any, Iterator
1+
from typing import Dict, List, Tuple, Optional, Any, Iterator, Union
22
from time import sleep
33
import logging
44
import re
@@ -41,25 +41,53 @@ def create_athena_bucket(self):
4141
s3_resource.Bucket(s3_output)
4242
return s3_output
4343

44-
def run_query(self, query, database, s3_output=None, workgroup=None):
44+
def run_query(self, query: str, database: Optional[str] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None) -> str:
4545
"""
4646
Run a SQL Query against AWS Athena
47+
P.S. All default values will be inherited from the Session()
4748
4849
:param query: SQL query
4950
:param database: AWS Glue/Athena database name
5051
:param s3_output: AWS S3 path
5152
:param workgroup: Athena workgroup (By default uses the Session() workgroup)
53+
:param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
54+
:param kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
5255
:return: Query execution ID
5356
"""
57+
args: Dict[str, Union[str, Dict[str, Union[str, Dict[str, str]]]]] = {"QueryString": query}
58+
59+
# s3_output
5460
if s3_output is None:
55-
s3_output = self.create_athena_bucket()
56-
if workgroup is None:
57-
workgroup = self._session.athena_workgroup
58-
logger.debug(f"Workgroup: {workgroup}")
59-
response = self._client_athena.start_query_execution(QueryString=query,
60-
QueryExecutionContext={"Database": database},
61-
ResultConfiguration={"OutputLocation": s3_output},
62-
WorkGroup=workgroup)
61+
if self._session.athena_s3_output is not None:
62+
s3_output = self._session.athena_s3_output
63+
else:
64+
s3_output = self.create_athena_bucket()
65+
args["ResultConfiguration"] = {"OutputLocation": s3_output}
66+
67+
# encryption
68+
if encryption is not None:
69+
args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": encryption}
70+
if kms_key is not None:
71+
args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key
72+
elif self._session.athena_encryption is not None:
73+
args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": self._session.athena_encryption}
74+
if self._session.athena_kms_key is not None:
75+
args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = self._session.athena_kms_key
76+
77+
# database
78+
if database is not None:
79+
args["QueryExecutionContext"] = {"Database": database}
80+
elif self._session.athena_database is not None:
81+
args["QueryExecutionContext"] = {"Database": self._session.athena_database}
82+
83+
# workgroup
84+
if workgroup is not None:
85+
args["WorkGroup"] = workgroup
86+
elif self._session.athena_workgroup is not None:
87+
args["WorkGroup"] = self._session.athena_workgroup
88+
89+
logger.debug(f"args: {args}")
90+
response = self._client_athena.start_query_execution(**args)
6391
return response["QueryExecutionId"]
6492

6593
def wait_query(self, query_execution_id):
@@ -84,7 +112,7 @@ def wait_query(self, query_execution_id):
84112
raise QueryCancelled(response["QueryExecution"]["Status"].get("StateChangeReason"))
85113
return response
86114

87-
def repair_table(self, database, table, s3_output=None, workgroup=None):
115+
def repair_table(self, table: str, database: Optional[str] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None):
88116
"""
89117
Hive's metastore consistency check
90118
"MSCK REPAIR TABLE table;"
@@ -93,15 +121,18 @@ def repair_table(self, database, table, s3_output=None, workgroup=None):
93121
It is possible it will take some time to add all partitions.
94122
If this operation times out, it will be in an incomplete state
95123
where only a few partitions are added to the catalog.
124+
P.S. All default values will be inherited from the Session()
96125
97126
:param database: Glue database name
98127
:param table: Glue table name
99128
:param s3_output: AWS S3 path
100129
:param workgroup: Athena workgroup (By default uses the Session() workgroup)
130+
:param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
131+
:param kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
101132
:return: Query execution ID
102133
"""
103134
query = f"MSCK REPAIR TABLE {table};"
104-
query_id = self.run_query(query=query, database=database, s3_output=s3_output, workgroup=workgroup)
135+
query_id = self.run_query(query=query, database=database, s3_output=s3_output, workgroup=workgroup, encryption=encryption, kms_key=kms_key)
105136
self.wait_query(query_execution_id=query_id)
106137
return query_id
107138

@@ -142,18 +173,20 @@ def get_results(self, query_execution_id: str) -> Iterator[Dict[str, Any]]:
142173
yield row
143174
next_token = res.get("NextToken")
144175

145-
def query(self, query: str, database: str, s3_output: str = None,
146-
workgroup: str = None) -> Iterator[Dict[str, Any]]:
176+
def query(self, query: str, database: Optional[str] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None) -> Iterator[Dict[str, Any]]:
147177
"""
148178
Run a SQL Query against AWS Athena and return the result as a Iterator of lists
179+
P.S. All default values will be inherited from the Session()
149180
150181
:param query: SQL query
151182
:param database: Glue database name
152183
:param s3_output: AWS S3 path
153184
:param workgroup: Athena workgroup (By default uses the Session() workgroup)
185+
:param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
186+
:param kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
154187
:return: Query execution ID
155188
"""
156-
query_id: str = self.run_query(query=query, database=database, s3_output=s3_output, workgroup=workgroup)
189+
query_id: str = self.run_query(query=query, database=database, s3_output=s3_output, workgroup=workgroup, encryption=encryption, kms_key=kms_key)
157190
self.wait_query(query_execution_id=query_id)
158191
return self.get_results(query_execution_id=query_id)
159192

awswrangler/pandas.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Tuple, Optional, Any
1+
from typing import Dict, List, Tuple, Optional, Any, Union
22
from io import BytesIO, StringIO
33
import multiprocessing as mp
44
import logging
@@ -32,7 +32,6 @@ def _get_bounders(dataframe, num_partitions):
3232

3333

3434
class Pandas:
35-
3635
VALID_CSV_SERDES = ["OpenCSVSerDe", "LazySimpleSerDe"]
3736
VALID_CSV_COMPRESSIONS = [None]
3837
VALID_PARQUET_COMPRESSIONS = [None, "snappy", "gzip"]
@@ -61,7 +60,7 @@ def read_csv(
6160
quotechar='"',
6261
quoting=csv.QUOTE_MINIMAL,
6362
escapechar=None,
64-
parse_dates=False,
63+
parse_dates: Union[bool, Dict, List] = False,
6564
infer_datetime_format=False,
6665
encoding="utf-8",
6766
converters=None,
@@ -153,7 +152,7 @@ def _read_csv_iterator(
153152
quotechar='"',
154153
quoting=csv.QUOTE_MINIMAL,
155154
escapechar=None,
156-
parse_dates=False,
155+
parse_dates: Union[bool, Dict, List] = False,
157156
infer_datetime_format=False,
158157
encoding="utf-8",
159158
converters=None,
@@ -365,7 +364,7 @@ def _read_csv_once(
365364
quotechar='"',
366365
quoting=0,
367366
escapechar=None,
368-
parse_dates=False,
367+
parse_dates: Union[bool, Dict, List] = False,
369368
infer_datetime_format=False,
370369
encoding=None,
371370
converters=None,
@@ -446,20 +445,25 @@ def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], Lis
446445
logger.debug(f"parse_dates: {parse_dates}")
447446
return dtype, parse_timestamps, parse_dates, converters
448447

449-
def read_sql_athena(self, sql, database, s3_output=None, max_result_size=None):
448+
def read_sql_athena(self, sql, database=None, s3_output=None, max_result_size=None, workgroup=None,
449+
encryption=None, kms_key=None):
450450
"""
451451
Executes any SQL query on AWS Athena and return a Dataframe of the result.
452452
P.S. If max_result_size is passed, then a iterator of Dataframes is returned.
453+
P.P.S. All default values will be inherited from the Session()
453454
454455
:param sql: SQL Query
455456
:param database: Glue/Athena Database
456457
:param s3_output: AWS S3 path
457458
:param max_result_size: Max number of bytes on each request to S3
459+
:param workgroup: The name of the workgroup in which the query is being started. (By default uses the Session() workgroup)
460+
:param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
461+
:param kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
458462
:return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
459463
"""
460464
if not s3_output:
461465
s3_output = self._session.athena.create_athena_bucket()
462-
query_execution_id = self._session.athena.run_query(query=sql, database=database, s3_output=s3_output)
466+
query_execution_id = self._session.athena.run_query(query=sql, database=database, s3_output=s3_output, workgroup=workgroup, encryption=encryption, kms_key=kms_key)
463467
query_response = self._session.athena.wait_query(query_execution_id=query_execution_id)
464468
if query_response["QueryExecution"]["Status"]["State"] in ["FAILED", "CANCELLED"]:
465469
reason = query_response["QueryExecution"]["Status"]["StateChangeReason"]
@@ -497,7 +501,7 @@ def to_csv(
497501
path,
498502
sep=",",
499503
serde="OpenCSVSerDe",
500-
database=None,
504+
database: Optional[str] = None,
501505
table=None,
502506
partition_cols=None,
503507
preserve_index=True,
@@ -544,7 +548,7 @@ def to_csv(
544548
def to_parquet(self,
545549
dataframe,
546550
path,
547-
database=None,
551+
database: Optional[str] = None,
548552
table=None,
549553
partition_cols=None,
550554
preserve_index=True,
@@ -766,7 +770,7 @@ def _data_to_s3_dataset_writer(dataframe,
766770
for keys, subgroup in dataframe.groupby(partition_cols):
767771
subgroup = subgroup.drop(partition_cols, axis="columns")
768772
if not isinstance(keys, tuple):
769-
keys = (keys, )
773+
keys = (keys,)
770774
subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)])
771775
prefix = "/".join([path, subdir])
772776
object_path = Pandas._data_to_s3_object_writer(dataframe=subgroup,

0 commit comments

Comments
 (0)