Skip to content

Commit cf0690b

Browse files
authored
Merge pull request #162 from awslabs/athena-workgroup-encryption
Fixing athena queries for workgroups without encryption
2 parents f6c523a + a6b7e92 commit cf0690b

File tree

2 files changed

+38
-31
lines changed

2 files changed

+38
-31
lines changed

awswrangler/athena.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -547,12 +547,13 @@ def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None
 547  547   def _ensure_workgroup(
 548  548       session: boto3.Session, workgroup: Optional[str] = None
 549  549   ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
 550      -     if workgroup:
      550 +     if workgroup is not None:
 551  551           res: Dict[str, Any] = get_work_group(workgroup=workgroup, boto3_session=session)
 552  552           config: Dict[str, Any] = res["WorkGroup"]["Configuration"]["ResultConfiguration"]
 553  553           wg_s3_output: Optional[str] = config.get("OutputLocation")
 554      -         wg_encryption: Optional[str] = config["EncryptionConfiguration"].get("EncryptionOption")
 555      -         wg_kms_key: Optional[str] = config["EncryptionConfiguration"].get("KmsKey")
      554 +         encrypt_config: Optional[Dict[str, str]] = config.get("EncryptionConfiguration")
      555 +         wg_encryption: Optional[str] = None if encrypt_config is None else encrypt_config.get("EncryptionOption")
      556 +         wg_kms_key: Optional[str] = None if encrypt_config is None else encrypt_config.get("KmsKey")
 556  557       else:
 557  558           wg_s3_output, wg_encryption, wg_kms_key = None, None, None
 558  559       return wg_s3_output, wg_encryption, wg_kms_key

testing/test_awswrangler/test_data_lake.py

Lines changed: 34 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -26,43 +26,27 @@ def cloudformation_outputs():
  26   26
  27   27   @pytest.fixture(scope="module")
  28   28   def region(cloudformation_outputs):
  29      -     if "Region" in cloudformation_outputs:
  30      -         region = cloudformation_outputs["Region"]
  31      -     else:
  32      -         raise Exception("You must deploy/update the test infrastructure (CloudFormation)!")
  33      -     yield region
       29 +     yield cloudformation_outputs["Region"]
  34   30
  35   31
  36   32   @pytest.fixture(scope="module")
  37   33   def bucket(cloudformation_outputs):
  38      -     if "BucketName" in cloudformation_outputs:
  39      -         bucket = cloudformation_outputs["BucketName"]
  40      -     else:
  41      -         raise Exception("You must deploy/update the test infrastructure (CloudFormation)")
  42      -     yield bucket
       34 +     yield cloudformation_outputs["BucketName"]
  43   35
  44   36
  45   37   @pytest.fixture(scope="module")
  46   38   def database(cloudformation_outputs):
  47      -     if "GlueDatabaseName" in cloudformation_outputs:
  48      -         database = cloudformation_outputs["GlueDatabaseName"]
  49      -     else:
  50      -         raise Exception("You must deploy the test infrastructure using Cloudformation!")
  51      -     yield database
       39 +     yield cloudformation_outputs["GlueDatabaseName"]
  52   40
  53   41
  54   42   @pytest.fixture(scope="module")
  55   43   def kms_key(cloudformation_outputs):
  56      -     if "KmsKeyArn" in cloudformation_outputs:
  57      -         key = cloudformation_outputs["KmsKeyArn"]
  58      -     else:
  59      -         raise Exception("You must deploy the test infrastructure using Cloudformation!")
  60      -     yield key
       44 +     yield cloudformation_outputs["KmsKeyArn"]
  61   45
  62   46
  63   47   @pytest.fixture(scope="module")
  64      - def workgroup_secondary(bucket):
  65      -     wkg_name = "awswrangler_test"
       48 + def workgroup0(bucket):
       49 +     wkg_name = "awswrangler_test_0"
  66   50       client = boto3.client("athena")
  67   51       wkgs = client.list_work_groups()
  68   52       wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
@@ -71,15 +55,36 @@ def workgroup_secondary(bucket):
  71   55               Name=wkg_name,
  72   56               Configuration={
  73   57                   "ResultConfiguration": {
  74      -                     "OutputLocation": f"s3://{bucket}/athena_workgroup_secondary/",
       58 +                     "OutputLocation": f"s3://{bucket}/athena_workgroup0/",
  75   59                       "EncryptionConfiguration": {"EncryptionOption": "SSE_S3"},
  76   60                   },
  77   61                   "EnforceWorkGroupConfiguration": True,
  78   62                   "PublishCloudWatchMetricsEnabled": True,
  79   63                   "BytesScannedCutoffPerQuery": 100_000_000,
  80   64                   "RequesterPaysEnabled": False,
  81   65               },
  82      -             Description="AWS Data Wrangler Test WorkGroup",
       66 +             Description="AWS Data Wrangler Test WorkGroup Number 0",
       67 +         )
       68 +     yield wkg_name
       69 +
       70 +
       71 + @pytest.fixture(scope="module")
       72 + def workgroup1(bucket):
       73 +     wkg_name = "awswrangler_test_1"
       74 +     client = boto3.client("athena")
       75 +     wkgs = client.list_work_groups()
       76 +     wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
       77 +     if wkg_name not in wkgs:
       78 +         client.create_work_group(
       79 +             Name=wkg_name,
       80 +             Configuration={
       81 +                 "ResultConfiguration": {"OutputLocation": f"s3://{bucket}/athena_workgroup1/"},
       82 +                 "EnforceWorkGroupConfiguration": True,
       83 +                 "PublishCloudWatchMetricsEnabled": True,
       84 +                 "BytesScannedCutoffPerQuery": 100_000_000,
       85 +                 "RequesterPaysEnabled": False,
       86 +             },
       87 +             Description="AWS Data Wrangler Test WorkGroup Number 1",
  83   88           )
  84   89       yield wkg_name
  85   90
@@ -121,7 +126,7 @@ def test_athena_ctas(bucket, database, kms_key):
 121  126       wr.s3.delete_objects(path=f"s3://{bucket}/test_athena_ctas_result/")
 122  127
 123  128
 124      - def test_athena(bucket, database, kms_key, workgroup_secondary):
      129 + def test_athena(bucket, database, kms_key, workgroup0, workgroup1):
 125  130       wr.s3.delete_objects(path=f"s3://{bucket}/test_athena/")
 126  131       paths = wr.s3.to_parquet(
 127  132           df=get_df(),
@@ -142,21 +147,22 @@ def test_athena(bucket, database, kms_key, workgroup_secondary):
 142  147           chunksize=1,
 143  148           encryption="SSE_KMS",
 144  149           kms_key=kms_key,
 145      -         workgroup=workgroup_secondary,
      150 +         workgroup=workgroup0,
 146  151       )
 147  152       for df2 in dfs:
 148  153           print(df2)
 149  154           ensure_data_types(df=df2)
 150  155       df = wr.athena.read_sql_query(
 151      -         sql="SELECT * FROM __test_athena", database=database, ctas_approach=False, workgroup=workgroup_secondary
      156 +         sql="SELECT * FROM __test_athena", database=database, ctas_approach=False, workgroup=workgroup1
 152  157       )
 153  158       assert len(df.index) == 3
 154  159       ensure_data_types(df=df)
 155  160       wr.athena.repair_table(table="__test_athena", database=database)
 156  161       wr.catalog.delete_table_if_exists(database=database, table="__test_athena")
 157  162       wr.s3.delete_objects(path=paths)
 158  163       wr.s3.wait_objects_not_exist(paths=paths)
 159      -     wr.s3.delete_objects(path=f"s3://{bucket}/athena_workgroup_secondary/")
      164 +     wr.s3.delete_objects(path=f"s3://{bucket}/athena_workgroup0/")
      165 +     wr.s3.delete_objects(path=f"s3://{bucket}/athena_workgroup1/")
 160  166
 161  167
 162  168   def test_csv(bucket):

0 commit comments

Comments
 (0)