Skip to content

Commit 31ca0a6

Browse files
committed
General tests refactoring.
1 parent 8f505cb commit 31ca0a6

File tree

8 files changed

+766
-817
lines changed

8 files changed

+766
-817
lines changed

tests/conftest.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
from datetime import datetime
2+
3+
import boto3
4+
import pytest
5+
6+
import awswrangler as wr
7+
8+
from ._utils import extract_cloudformation_outputs, get_time_str_with_random_suffix, path_generator
9+
10+
11+
@pytest.fixture(scope="session")
def cloudformation_outputs():
    """Session-wide dict of the test CloudFormation stack's outputs."""
    outputs = extract_cloudformation_outputs()
    return outputs
14+
15+
16+
@pytest.fixture(scope="session")
def region(cloudformation_outputs):
    """AWS region name, read from the stack outputs."""
    region_name = cloudformation_outputs["Region"]
    return region_name
19+
20+
21+
@pytest.fixture(scope="session")
def bucket(cloudformation_outputs):
    """Name of the S3 bucket provisioned for the test suite."""
    bucket_name = cloudformation_outputs["BucketName"]
    return bucket_name
24+
25+
26+
@pytest.fixture(scope="session")
def glue_database(cloudformation_outputs):
    """Name of the Glue database provisioned for the test suite."""
    database_name = cloudformation_outputs["GlueDatabaseName"]
    return database_name
29+
30+
31+
@pytest.fixture(scope="session")
def kms_key(cloudformation_outputs):
    """Full ARN of the test KMS key."""
    key_arn = cloudformation_outputs["KmsKeyArn"]
    return key_arn
34+
35+
36+
@pytest.fixture(scope="session")
def kms_key_id(kms_key):
    """Bare key id: the portion of the KMS key ARN after the first '/'."""
    return kms_key.split("/", maxsplit=1)[1]
39+
40+
41+
@pytest.fixture(scope="session")
def loggroup(cloudformation_outputs):
    """Seed the stack's log stream with five events and yield the log group name.

    Writes five numbered events to the CloudWatch log stream declared by the
    stack, then polls with a Logs Insights query until at least five events
    are visible, so tests that query the group always see a populated stream.
    """
    import time  # local import: only needed for the polling backoff below

    loggroup_name = cloudformation_outputs["LogGroupName"]
    logstream_name = cloudformation_outputs["LogStream"]
    client = boto3.client("logs")
    response = client.describe_log_streams(logGroupName=loggroup_name, logStreamNamePrefix=logstream_name)
    # A sequence token exists only once the stream has already received events.
    token = response["logStreams"][0].get("uploadSequenceToken")
    events = [{"timestamp": int(1000 * datetime.now().timestamp()), "message": str(i)} for i in range(5)]
    args = {"logGroupName": loggroup_name, "logStreamName": logstream_name, "logEvents": events}
    if token:
        args["sequenceToken"] = token
    try:
        client.put_log_events(**args)
    except client.exceptions.DataAlreadyAcceptedException:
        pass  # Concurrency: another session already pushed this exact batch.
    # Poll until the events are queryable. Sleep between attempts: the original
    # looped with no backoff, hammering the CloudWatch API in a tight busy-loop.
    while True:
        results = wr.cloudwatch.run_query(log_group_names=[loggroup_name], query="fields @timestamp | limit 5")
        if len(results) >= 5:
            break
        time.sleep(1)
    yield loggroup_name
63+
64+
65+
@pytest.fixture(scope="session")
def workgroup0(bucket):
    """Athena workgroup with unencrypted result output; created only if absent."""
    name = "aws_data_wrangler_0"
    client = boto3.client("athena")
    existing = {wg["Name"] for wg in client.list_work_groups()["WorkGroups"]}
    if name not in existing:
        configuration = {
            "ResultConfiguration": {"OutputLocation": f"s3://{bucket}/athena_workgroup0/"},
            "EnforceWorkGroupConfiguration": True,
            "PublishCloudWatchMetricsEnabled": True,
            "BytesScannedCutoffPerQuery": 100_000_000,
            "RequesterPaysEnabled": False,
        }
        client.create_work_group(
            Name=name,
            Configuration=configuration,
            Description="AWS Data Wrangler Test WorkGroup Number 0",
        )
    return name
84+
85+
86+
@pytest.fixture(scope="session")
def workgroup1(bucket):
    """Athena workgroup with SSE-S3 encrypted result output; created only if absent."""
    name = "aws_data_wrangler_1"
    client = boto3.client("athena")
    existing = {wg["Name"] for wg in client.list_work_groups()["WorkGroups"]}
    if name not in existing:
        result_configuration = {
            "OutputLocation": f"s3://{bucket}/athena_workgroup1/",
            "EncryptionConfiguration": {"EncryptionOption": "SSE_S3"},
        }
        client.create_work_group(
            Name=name,
            Configuration={
                "ResultConfiguration": result_configuration,
                "EnforceWorkGroupConfiguration": True,
                "PublishCloudWatchMetricsEnabled": True,
                "BytesScannedCutoffPerQuery": 100_000_000,
                "RequesterPaysEnabled": False,
            },
            Description="AWS Data Wrangler Test WorkGroup Number 1",
        )
    return name
108+
109+
110+
@pytest.fixture(scope="session")
def workgroup2(bucket, kms_key):
    """Athena workgroup with SSE-KMS results and enforcement disabled; created only if absent."""
    name = "aws_data_wrangler_2"
    client = boto3.client("athena")
    existing = {wg["Name"] for wg in client.list_work_groups()["WorkGroups"]}
    if name not in existing:
        result_configuration = {
            "OutputLocation": f"s3://{bucket}/athena_workgroup2/",
            "EncryptionConfiguration": {"EncryptionOption": "SSE_KMS", "KmsKey": kms_key},
        }
        client.create_work_group(
            Name=name,
            Configuration={
                "ResultConfiguration": result_configuration,
                # Clients may override this workgroup's settings (enforcement off).
                "EnforceWorkGroupConfiguration": False,
                "PublishCloudWatchMetricsEnabled": True,
                "BytesScannedCutoffPerQuery": 100_000_000,
                "RequesterPaysEnabled": False,
            },
            Description="AWS Data Wrangler Test WorkGroup Number 2",
        )
    return name
132+
133+
134+
@pytest.fixture(scope="session")
def workgroup3(bucket, kms_key):
    """Athena workgroup with SSE-KMS results and enforcement enabled; created only if absent."""
    name = "aws_data_wrangler_3"
    client = boto3.client("athena")
    existing = {wg["Name"] for wg in client.list_work_groups()["WorkGroups"]}
    if name not in existing:
        result_configuration = {
            "OutputLocation": f"s3://{bucket}/athena_workgroup3/",
            "EncryptionConfiguration": {"EncryptionOption": "SSE_KMS", "KmsKey": kms_key},
        }
        client.create_work_group(
            Name=name,
            Configuration={
                "ResultConfiguration": result_configuration,
                "EnforceWorkGroupConfiguration": True,
                "PublishCloudWatchMetricsEnabled": True,
                "BytesScannedCutoffPerQuery": 100_000_000,
                "RequesterPaysEnabled": False,
            },
            Description="AWS Data Wrangler Test WorkGroup Number 3",
        )
    return name
156+
157+
158+
@pytest.fixture(scope="session")
def databases_parameters(cloudformation_outputs):
    """Connection parameters for the PostgreSQL, MySQL and Redshift test databases."""
    # NOTE(review): PostgreSQL is configured on port 3306 (MySQL's default port);
    # presumably the stack really exposes it there — confirm against the template.
    parameters = {
        "postgresql": {
            "host": cloudformation_outputs["PostgresqlAddress"],
            "port": 3306,
            "schema": "public",
            "database": "postgres",
        },
        "mysql": {
            "host": cloudformation_outputs["MysqlAddress"],
            "port": 3306,
            "schema": "test",
            "database": "test",
        },
        "redshift": {
            "host": cloudformation_outputs["RedshiftAddress"],
            "port": cloudformation_outputs["RedshiftPort"],
            "identifier": cloudformation_outputs["RedshiftIdentifier"],
            "schema": "public",
            "database": "test",
            "role": cloudformation_outputs["RedshiftRole"],
        },
        "password": cloudformation_outputs["DatabasesPassword"],
        "user": "test",
    }
    return parameters
178+
179+
180+
@pytest.fixture(scope="session")
def redshift_external_schema(cloudformation_outputs, databases_parameters, glue_database):
    """Create (if missing) a Redshift external schema backed by the Glue catalog and return its name."""
    region = cloudformation_outputs.get("Region")
    role = databases_parameters["redshift"]["role"]
    sql = f"""
    CREATE EXTERNAL SCHEMA IF NOT EXISTS aws_data_wrangler_external FROM data catalog
    DATABASE '{glue_database}'
    IAM_ROLE '{role}'
    REGION '{region}';
    """
    engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
    with engine.connect() as con:
        con.execute(sql)
    return "aws_data_wrangler_external"
193+
194+
195+
@pytest.fixture(scope="function")
def glue_table(glue_database):
    """Yield a unique Glue table name, dropping any table by that name before and after the test."""
    name = f"tbl_{get_time_str_with_random_suffix()}"
    print(f"Table name: {name}")
    wr.catalog.delete_table_if_exists(database=glue_database, table=name)
    # try/finally so the teardown drop also runs when the fixture generator is
    # closed early (e.g. interrupted run), not only on the normal resume path.
    try:
        yield name
    finally:
        wr.catalog.delete_table_if_exists(database=glue_database, table=name)
202+
203+
204+
@pytest.fixture(scope="function")
def glue_table2(glue_database):
    """Second unique Glue table name for tests needing two tables; dropped before and after the test."""
    name = f"tbl_{get_time_str_with_random_suffix()}"
    print(f"Table name: {name}")
    wr.catalog.delete_table_if_exists(database=glue_database, table=name)
    # try/finally so the teardown drop also runs when the fixture generator is
    # closed early (e.g. interrupted run), not only on the normal resume path.
    try:
        yield name
    finally:
        wr.catalog.delete_table_if_exists(database=glue_database, table=name)
211+
212+
213+
@pytest.fixture(scope="function")
def path(bucket):
    """Per-test unique S3 path under the test bucket."""
    generator = path_generator(bucket)
    yield from generator
216+
217+
218+
@pytest.fixture(scope="function")
def path2(bucket):
    """Second per-test unique S3 path, for tests that need two distinct paths."""
    generator = path_generator(bucket)
    yield from generator
221+
222+
223+
@pytest.fixture(scope="function")
def path3(bucket):
    """Third per-test unique S3 path, for tests that need three distinct paths."""
    generator = path_generator(bucket)
    yield from generator

0 commit comments

Comments
 (0)