
Commit 55369c3

Fix the Athena cache unit test errors (#1883)

* Isolate the Athena cache tests
* Move the wr fixture to conftest

1 parent 93696bb · commit 55369c3

File tree: 2 files changed (+24, -11 lines)

tests/conftest.py
tests/test_athena_cache.py


tests/conftest.py

Lines changed: 17 additions & 2 deletions
@@ -1,8 +1,11 @@
 from datetime import datetime
+from importlib import reload
+from types import ModuleType
+from typing import Iterator
 
-import boto3  # type: ignore
+import boto3
 import botocore.exceptions
-import pytest  # type: ignore
+import pytest
 
 import awswrangler as wr
 
@@ -373,3 +376,15 @@ def glue_ruleset() -> str:
 @pytest.fixture(scope="session")
 def glue_data_quality_role(cloudformation_outputs):
     return cloudformation_outputs["GlueDataQualityRole"]
+
+
+@pytest.fixture(scope="function", name="wr")
+def awswrangler_import() -> Iterator[ModuleType]:
+    import awswrangler
+
+    awswrangler.config.reset()
+
+    yield reload(awswrangler)
+
+    # Reset for future tests
+    awswrangler.config.reset()
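The new function-scoped fixture is what gives each Athena cache test a pristine awswrangler module: the config is reset, the module is reloaded before the test, and the config is reset again at teardown. A minimal sketch of how a test consumes it, assuming the fixture above is collected from tests/conftest.py (the test name and assertion are illustrative, not part of the commit):

    # Illustrative usage only: pytest injects the freshly reloaded module via the
    # fixture named "wr", replacing the old module-level import in each test.
    def test_cache_config_is_isolated(wr):
        # Mutating module-level config inside one test...
        wr.config.max_local_cache_entries = 1
        assert wr.config.max_local_cache_entries == 1
        # ...no longer leaks into other tests: the fixture resets the config,
        # hands each test a reloaded module, and resets again at teardown.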

tests/test_athena_cache.py

Lines changed: 7 additions & 9 deletions
@@ -4,14 +4,12 @@
 import pandas as pd
 import pytest
 
-import awswrangler as wr
-
 from ._utils import ensure_athena_query_metadata
 
 logging.getLogger("awswrangler").setLevel(logging.DEBUG)
 
 
-def test_athena_cache(path, glue_database, glue_table, workgroup1):
+def test_athena_cache(wr, path, glue_database, glue_table, workgroup1):
     df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
     wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", database=glue_database, table=glue_table)
 
@@ -34,7 +32,7 @@ def test_athena_cache(path, glue_database, glue_table, workgroup1):
 
 
 @pytest.mark.parametrize("data_source", [None, "AwsDataCatalog"])
-def test_cache_query_ctas_approach_true(path, glue_database, glue_table, data_source):
+def test_cache_query_ctas_approach_true(wr, path, glue_database, glue_table, data_source):
     df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
     wr.s3.to_parquet(
         df=df,
@@ -70,7 +68,7 @@ def test_cache_query_ctas_approach_true(path, glue_database, glue_table, data_source):
 
 
 @pytest.mark.parametrize("data_source", [None, "AwsDataCatalog"])
-def test_cache_query_ctas_approach_false(path, glue_database, glue_table, data_source):
+def test_cache_query_ctas_approach_false(wr, path, glue_database, glue_table, data_source):
     df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
     wr.s3.to_parquet(
         df=df,
@@ -105,7 +103,7 @@ def test_cache_query_ctas_approach_false(path, glue_database, glue_table, data_source):
     ensure_athena_query_metadata(df=df3, ctas_approach=False, encrypted=False)
 
 
-def test_cache_query_semicolon(path, glue_database, glue_table):
+def test_cache_query_semicolon(wr, path, glue_database, glue_table):
     df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
     wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", database=glue_database, table=glue_table)
 
@@ -129,7 +127,7 @@ def test_cache_query_semicolon(path, glue_database, glue_table):
     assert df.c0.sum() == df3.c0.sum()
 
 
-def test_local_cache(path, glue_database, glue_table):
+def test_local_cache(wr, path, glue_database, glue_table):
     wr.config.max_local_cache_entries = 1
 
     df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
@@ -160,7 +158,7 @@ def test_local_cache(path, glue_database, glue_table):
     assert second_query_id in wr.athena._read._cache_manager
 
 
-def test_paginated_remote_cache(path, glue_database, glue_table, workgroup1):
+def test_paginated_remote_cache(wr, path, glue_database, glue_table, workgroup1):
     wr.config.max_remote_cache_entries = 100
     df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
     wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", database=glue_database, table=glue_table)
@@ -173,7 +171,7 @@ def test_paginated_remote_cache(path, glue_database, glue_table, workgroup1):
 
 
 @pytest.mark.parametrize("data_source", [None, "AwsDataCatalog"])
-def test_cache_start_query(path, glue_database, glue_table, data_source):
+def test_cache_start_query(wr, path, glue_database, glue_table, data_source):
     df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
     wr.s3.to_parquet(
         df=df,
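Dropping the module-level `import awswrangler as wr` and taking `wr` as a fixture argument is what isolates these tests: both the session config (e.g. `wr.config.max_local_cache_entries`) and the local query cache (`wr.athena._read._cache_manager`) live on the imported module, so state mutated by one test used to be visible to the next. A rough sketch of the kind of leakage the fixture prevents; the test names here are hypothetical, not code from the repository:

    # Hypothetical illustration of cross-test leakage with a single shared import.
    import awswrangler as wr

    def test_shrinks_cache():
        wr.config.max_local_cache_entries = 1  # shrinks the shared LRU cache

    def test_expects_defaults():
        # Still sees max_local_cache_entries == 1 from the previous test, because
        # the module (and its cache manager) was imported once and never reset.
        ...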
