Skip to content

Commit 79b6c3a

Browse files
authored
Merge pull request #42 from awslabs/athena-workgroup
Add Athena Workgroup
2 parents 6558cbc + 5caab27 commit 79b6c3a

File tree

3 files changed

+86
-24
lines changed

3 files changed

+86
-24
lines changed

awswrangler/athena.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,26 @@ def create_athena_bucket(self):
6363
s3_resource.Bucket(s3_output)
6464
return s3_output
6565

66-
def run_query(self, query, database, s3_output=None):
66+
def run_query(self, query, database, s3_output=None, workgroup=None):
6767
"""
6868
Run a SQL Query against AWS Athena
6969
7070
:param query: SQL query
7171
:param database: AWS Glue/Athena database name
7272
:param s3_output: AWS S3 path
73+
:param workgroup: Athena workgroup (By default uses the Session() workgroup)
7374
:return: Query execution ID
7475
"""
75-
if not s3_output:
76+
if s3_output is None:
7677
s3_output = self.create_athena_bucket()
78+
if workgroup is None:
79+
workgroup = self._session.athena_workgroup
80+
logger.debug(f"Workgroup: {workgroup}")
7781
response = self._client_athena.start_query_execution(
7882
QueryString=query,
7983
QueryExecutionContext={"Database": database},
8084
ResultConfiguration={"OutputLocation": s3_output},
81-
)
85+
WorkGroup=workgroup)
8286
return response["QueryExecutionId"]
8387

8488
def wait_query(self, query_execution_id):
@@ -109,7 +113,7 @@ def wait_query(self, query_execution_id):
109113
response["QueryExecution"]["Status"].get("StateChangeReason"))
110114
return response
111115

112-
def repair_table(self, database, table, s3_output=None):
116+
def repair_table(self, database, table, s3_output=None, workgroup=None):
113117
"""
114118
Hive's metastore consistency check
115119
"MSCK REPAIR TABLE table;"
@@ -122,12 +126,14 @@ def repair_table(self, database, table, s3_output=None):
122126
:param database: Glue database name
123127
:param table: Glue table name
124128
:param s3_output: AWS S3 path
129+
:param workgroup: Athena workgroup (By default uses the Session() workgroup)
125130
:return: Query execution ID
126131
"""
127132
query = f"MSCK REPAIR TABLE {table};"
128133
query_id = self.run_query(query=query,
129134
database=database,
130-
s3_output=s3_output)
135+
s3_output=s3_output,
136+
workgroup=workgroup)
131137
self.wait_query(query_execution_id=query_id)
132138
return query_id
133139

awswrangler/session.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(
4343
spark_session=None,
4444
procs_cpu_bound=os.cpu_count(),
4545
procs_io_bound=os.cpu_count() * PROCS_IO_BOUND_FACTOR,
46+
athena_workgroup="primary",
4647
):
4748
"""
4849
Most parameters inherit from Boto3 or Pyspark.
@@ -59,10 +60,9 @@ def __init__(
5960
:param s3_additional_kwargs: Passed on to s3fs (https://s3fs.readthedocs.io/en/latest/#serverside-encryption)
6061
:param spark_context: Spark Context (pyspark.SparkContext)
6162
:param spark_session: Spark Session (pyspark.sql.SparkSession)
62-
:param procs_cpu_bound: number of processes that can be used in single
63-
node applications for CPU bound case (Default: os.cpu_count())
64-
:param procs_io_bound: number of processes that can be used in single
65-
node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
63+
:param procs_cpu_bound: number of processes that can be used in single node applications for CPU bound case (Default: os.cpu_count())
64+
:param procs_io_bound: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
65+
:param athena_workgroup: Default AWS Athena Workgroup (str)
6666
"""
6767
self._profile_name = (boto3_session.profile_name
6868
if boto3_session else profile_name)
@@ -81,6 +81,7 @@ def __init__(
8181
self._spark_session = spark_session
8282
self._procs_cpu_bound = procs_cpu_bound
8383
self._procs_io_bound = procs_io_bound
84+
self._athena_workgroup = athena_workgroup
8485
self._primitives = None
8586
self._load_new_primitives()
8687
if boto3_session:
@@ -134,6 +135,7 @@ def _load_new_primitives(self):
134135
botocore_config=self._botocore_config,
135136
procs_cpu_bound=self._procs_cpu_bound,
136137
procs_io_bound=self._procs_io_bound,
138+
athena_workgroup=self._athena_workgroup,
137139
)
138140

139141
@property
@@ -184,6 +186,10 @@ def procs_cpu_bound(self):
184186
def procs_io_bound(self):
185187
return self._procs_io_bound
186188

189+
@property
190+
def athena_workgroup(self):
191+
return self._athena_workgroup
192+
187193
@property
188194
def boto3_session(self):
189195
return self._boto3_session
@@ -255,6 +261,7 @@ def __init__(
255261
botocore_config=None,
256262
procs_cpu_bound=None,
257263
procs_io_bound=None,
264+
athena_workgroup=None,
258265
):
259266
"""
260267
Most parameters inherit from Boto3.
@@ -268,10 +275,9 @@ def __init__(
268275
:param botocore_max_retries: Botocore max retries
269276
:param s3_additional_kwargs: Passed on to s3fs (https://s3fs.readthedocs.io/en/latest/#serverside-encryption)
270277
:param botocore_config: Botocore configurations
271-
:param procs_cpu_bound: number of processes that can be used in single
272-
node applications for CPU bound case (Default: os.cpu_count())
273-
:param procs_io_bound: number of processes that can be used in single
274-
node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
278+
:param procs_cpu_bound: number of processes that can be used in single node applications for CPU bound case (Default: os.cpu_count())
279+
:param procs_io_bound: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
280+
:param athena_workgroup: Default AWS Athena Workgroup (str)
275281
"""
276282
self._profile_name = profile_name
277283
self._aws_access_key_id = aws_access_key_id
@@ -283,6 +289,7 @@ def __init__(
283289
self._botocore_config = botocore_config
284290
self._procs_cpu_bound = procs_cpu_bound
285291
self._procs_io_bound = procs_io_bound
292+
self._athena_workgroup = athena_workgroup
286293

287294
@property
288295
def profile_name(self):
@@ -324,20 +331,23 @@ def procs_cpu_bound(self):
324331
def procs_io_bound(self):
325332
return self._procs_io_bound
326333

334+
@property
335+
def athena_workgroup(self):
336+
return self._athena_workgroup
337+
327338
@property
328339
def session(self):
329340
"""
330341
Reconstruct the session from primitives
331342
:return: awswrangler.session.Session
332343
"""
333-
return Session(
334-
profile_name=self._profile_name,
335-
aws_access_key_id=self._aws_access_key_id,
336-
aws_secret_access_key=self._aws_secret_access_key,
337-
aws_session_token=self._aws_session_token,
338-
region_name=self._region_name,
339-
botocore_max_retries=self._botocore_max_retries,
340-
s3_additional_kwargs=self._s3_additional_kwargs,
341-
procs_cpu_bound=self._procs_cpu_bound,
342-
procs_io_bound=self._procs_io_bound,
343-
)
344+
return Session(profile_name=self._profile_name,
345+
aws_access_key_id=self._aws_access_key_id,
346+
aws_secret_access_key=self._aws_secret_access_key,
347+
aws_session_token=self._aws_session_token,
348+
region_name=self._region_name,
349+
botocore_max_retries=self._botocore_max_retries,
350+
s3_additional_kwargs=self._s3_additional_kwargs,
351+
procs_cpu_bound=self._procs_cpu_bound,
352+
procs_io_bound=self._procs_io_bound,
353+
athena_workgroup=self._athena_workgroup)

testing/test_awswrangler/test_athena.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from pprint import pprint
23

34
import pytest
45
import boto3
@@ -37,6 +38,51 @@ def database(cloudformation_outputs):
3738
yield database
3839

3940

41+
@pytest.fixture(scope="module")
42+
def bucket(session, cloudformation_outputs):
43+
if "BucketName" in cloudformation_outputs:
44+
bucket = cloudformation_outputs["BucketName"]
45+
session.s3.delete_objects(path=f"s3://{bucket}/")
46+
else:
47+
raise Exception(
48+
"You must deploy the test infrastructure using Cloudformation!")
49+
yield bucket
50+
session.s3.delete_objects(path=f"s3://{bucket}/")
51+
52+
53+
@pytest.fixture(scope="module")
54+
def workgroup_secondary(bucket):
55+
wkg_name = "awswrangler_test"
56+
client = boto3.client('athena')
57+
wkgs = client.list_work_groups()
58+
wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
59+
if wkg_name not in wkgs:
60+
response = client.create_work_group(
61+
Name=wkg_name,
62+
Configuration={
63+
"ResultConfiguration": {
64+
"OutputLocation":
65+
f"s3://{bucket}/athena_workgroup_secondary/",
66+
"EncryptionConfiguration": {
67+
"EncryptionOption": "SSE_S3",
68+
}
69+
},
70+
"EnforceWorkGroupConfiguration": True,
71+
"PublishCloudWatchMetricsEnabled": True,
72+
"BytesScannedCutoffPerQuery": 100_000_000,
73+
"RequesterPaysEnabled": False
74+
},
75+
Description="AWS Data Wrangler Test WorkGroup")
76+
pprint(response)
77+
yield wkg_name
78+
79+
80+
def test_workgroup_secondary(session, database, workgroup_secondary):
81+
session.athena.run_query(query="SELECT 1",
82+
database=database,
83+
workgroup=workgroup_secondary)
84+
85+
4086
def test_query_cancelled(session, database):
4187
client_athena = boto3.client("athena")
4288
query_execution_id = session.athena.run_query(query="""

0 commit comments

Comments
 (0)