Skip to content

Commit a6b7e92

Browse files
committed
Fixing athena queries for workgroups without encryption #159
1 parent ee33dcc commit a6b7e92

File tree

2 files changed

+38
-31
lines changed

2 files changed

+38
-31
lines changed

awswrangler/athena.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -539,12 +539,13 @@ def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None
539539
def _ensure_workgroup(
540540
session: boto3.Session, workgroup: Optional[str] = None
541541
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
542-
if workgroup:
542+
if workgroup is not None:
543543
res: Dict[str, Any] = get_work_group(workgroup=workgroup, boto3_session=session)
544544
config: Dict[str, Any] = res["WorkGroup"]["Configuration"]["ResultConfiguration"]
545545
wg_s3_output: Optional[str] = config.get("OutputLocation")
546-
wg_encryption: Optional[str] = config["EncryptionConfiguration"].get("EncryptionOption")
547-
wg_kms_key: Optional[str] = config["EncryptionConfiguration"].get("KmsKey")
546+
encrypt_config: Optional[Dict[str, str]] = config.get("EncryptionConfiguration")
547+
wg_encryption: Optional[str] = None if encrypt_config is None else encrypt_config.get("EncryptionOption")
548+
wg_kms_key: Optional[str] = None if encrypt_config is None else encrypt_config.get("KmsKey")
548549
else:
549550
wg_s3_output, wg_encryption, wg_kms_key = None, None, None
550551
return wg_s3_output, wg_encryption, wg_kms_key

testing/test_awswrangler/test_data_lake.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,43 +25,27 @@ def cloudformation_outputs():
2525

2626
@pytest.fixture(scope="module")
2727
def region(cloudformation_outputs):
28-
if "Region" in cloudformation_outputs:
29-
region = cloudformation_outputs["Region"]
30-
else:
31-
raise Exception("You must deploy/update the test infrastructure (CloudFormation)!")
32-
yield region
28+
yield cloudformation_outputs["Region"]
3329

3430

3531
@pytest.fixture(scope="module")
3632
def bucket(cloudformation_outputs):
37-
if "BucketName" in cloudformation_outputs:
38-
bucket = cloudformation_outputs["BucketName"]
39-
else:
40-
raise Exception("You must deploy/update the test infrastructure (CloudFormation)")
41-
yield bucket
33+
yield cloudformation_outputs["BucketName"]
4234

4335

4436
@pytest.fixture(scope="module")
4537
def database(cloudformation_outputs):
46-
if "GlueDatabaseName" in cloudformation_outputs:
47-
database = cloudformation_outputs["GlueDatabaseName"]
48-
else:
49-
raise Exception("You must deploy the test infrastructure using Cloudformation!")
50-
yield database
38+
yield cloudformation_outputs["GlueDatabaseName"]
5139

5240

5341
@pytest.fixture(scope="module")
5442
def kms_key(cloudformation_outputs):
55-
if "KmsKeyArn" in cloudformation_outputs:
56-
key = cloudformation_outputs["KmsKeyArn"]
57-
else:
58-
raise Exception("You must deploy the test infrastructure using Cloudformation!")
59-
yield key
43+
yield cloudformation_outputs["KmsKeyArn"]
6044

6145

6246
@pytest.fixture(scope="module")
63-
def workgroup_secondary(bucket):
64-
wkg_name = "awswrangler_test"
47+
def workgroup0(bucket):
48+
wkg_name = "awswrangler_test_0"
6549
client = boto3.client("athena")
6650
wkgs = client.list_work_groups()
6751
wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
@@ -70,15 +54,36 @@ def workgroup_secondary(bucket):
7054
Name=wkg_name,
7155
Configuration={
7256
"ResultConfiguration": {
73-
"OutputLocation": f"s3://{bucket}/athena_workgroup_secondary/",
57+
"OutputLocation": f"s3://{bucket}/athena_workgroup0/",
7458
"EncryptionConfiguration": {"EncryptionOption": "SSE_S3"},
7559
},
7660
"EnforceWorkGroupConfiguration": True,
7761
"PublishCloudWatchMetricsEnabled": True,
7862
"BytesScannedCutoffPerQuery": 100_000_000,
7963
"RequesterPaysEnabled": False,
8064
},
81-
Description="AWS Data Wrangler Test WorkGroup",
65+
Description="AWS Data Wrangler Test WorkGroup Number 0",
66+
)
67+
yield wkg_name
68+
69+
70+
@pytest.fixture(scope="module")
71+
def workgroup1(bucket):
72+
wkg_name = "awswrangler_test_1"
73+
client = boto3.client("athena")
74+
wkgs = client.list_work_groups()
75+
wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
76+
if wkg_name not in wkgs:
77+
client.create_work_group(
78+
Name=wkg_name,
79+
Configuration={
80+
"ResultConfiguration": {"OutputLocation": f"s3://{bucket}/athena_workgroup1/"},
81+
"EnforceWorkGroupConfiguration": True,
82+
"PublishCloudWatchMetricsEnabled": True,
83+
"BytesScannedCutoffPerQuery": 100_000_000,
84+
"RequesterPaysEnabled": False,
85+
},
86+
Description="AWS Data Wrangler Test WorkGroup Number 1",
8287
)
8388
yield wkg_name
8489

@@ -120,7 +125,7 @@ def test_athena_ctas(bucket, database, kms_key):
120125
wr.s3.delete_objects(path=f"s3://{bucket}/test_athena_ctas_result/")
121126

122127

123-
def test_athena(bucket, database, kms_key, workgroup_secondary):
128+
def test_athena(bucket, database, kms_key, workgroup0, workgroup1):
124129
wr.s3.delete_objects(path=f"s3://{bucket}/test_athena/")
125130
paths = wr.s3.to_parquet(
126131
df=get_df(),
@@ -141,21 +146,22 @@ def test_athena(bucket, database, kms_key, workgroup_secondary):
141146
chunksize=1,
142147
encryption="SSE_KMS",
143148
kms_key=kms_key,
144-
workgroup=workgroup_secondary,
149+
workgroup=workgroup0,
145150
)
146151
for df2 in dfs:
147152
print(df2)
148153
ensure_data_types(df=df2)
149154
df = wr.athena.read_sql_query(
150-
sql="SELECT * FROM __test_athena", database=database, ctas_approach=False, workgroup=workgroup_secondary
155+
sql="SELECT * FROM __test_athena", database=database, ctas_approach=False, workgroup=workgroup1
151156
)
152157
assert len(df.index) == 3
153158
ensure_data_types(df=df)
154159
wr.athena.repair_table(table="__test_athena", database=database)
155160
wr.catalog.delete_table_if_exists(database=database, table="__test_athena")
156161
wr.s3.delete_objects(path=paths)
157162
wr.s3.wait_objects_not_exist(paths=paths)
158-
wr.s3.delete_objects(path=f"s3://{bucket}/athena_workgroup_secondary/")
163+
wr.s3.delete_objects(path=f"s3://{bucket}/athena_workgroup0/")
164+
wr.s3.delete_objects(path=f"s3://{bucket}/athena_workgroup1/")
159165

160166

161167
def test_csv(bucket):

0 commit comments

Comments
 (0)