Skip to content

Commit 53c0c48

Browse files
fix qmark cache issue in Athena
1 parent: 78522fd · commit: 53c0c48

File tree

4 files changed, with 91 additions and 25 deletions.

awswrangler/athena/_cache.py

Lines changed: 26 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import re
88
import threading
99
from heapq import heappop, heappush
10-
from typing import TYPE_CHECKING, Any, Match, NamedTuple
10+
from typing import TYPE_CHECKING, Match, NamedTuple
1111

1212
import boto3
1313

@@ -23,23 +23,23 @@ class _CacheInfo(NamedTuple):
2323
has_valid_cache: bool
2424
file_format: str | None = None
2525
query_execution_id: str | None = None
26-
query_execution_payload: dict[str, Any] | None = None
26+
query_execution_payload: "QueryExecutionTypeDef" | None = None
2727

2828

2929
class _LocalMetadataCacheManager:
3030
def __init__(self) -> None:
3131
self._lock: threading.Lock = threading.Lock()
32-
self._cache: dict[str, Any] = {}
32+
self._cache: dict[str, "QueryExecutionTypeDef"] = {}
3333
self._pqueue: list[tuple[datetime.datetime, str]] = []
3434
self._max_cache_size = 100
3535

36-
def update_cache(self, items: list[dict[str, Any]]) -> None:
36+
def update_cache(self, items: list["QueryExecutionTypeDef"]) -> None:
3737
"""
3838
Update the local metadata cache with new query metadata.
3939
4040
Parameters
4141
----------
42-
items : List[Dict[str, Any]]
42+
items
4343
List of query execution metadata which is returned by boto3 `batch_get_query_execution()`.
4444
"""
4545
with self._lock:
@@ -62,18 +62,17 @@ def update_cache(self, items: list[dict[str, Any]]) -> None:
6262
heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
6363
self._cache[item["QueryExecutionId"]] = item
6464

65-
def sorted_successful_generator(self) -> list[dict[str, Any]]:
65+
def sorted_successful_generator(self) -> list["QueryExecutionTypeDef"]:
6666
"""
6767
Sorts the entries in the local cache based on query Completion DateTime.
6868
6969
This is useful to guarantee LRU caching rules.
7070
7171
Returns
7272
-------
73-
List[Dict[str, Any]]
7473
Returns successful DDL and DML queries sorted by query completion time.
7574
"""
76-
filtered: list[dict[str, Any]] = []
75+
filtered: list["QueryExecutionTypeDef"] = []
7776
for query in self._cache.values():
7877
if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
7978
filtered.append(query)
@@ -111,13 +110,13 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> str | None:
111110
return None
112111

113112

114-
def _compare_query_string(sql: str, other: str) -> bool:
113+
def _compare_query_string(
114+
sql: str, other: str, sql_params: list[str] | None = None, other_params: list[str] | None = None
115+
) -> bool:
115116
comparison_query = _prepare_query_string_for_comparison(query_string=other)
116117
_logger.debug("sql: %s", sql)
117118
_logger.debug("comparison_query: %s", comparison_query)
118-
if sql == comparison_query:
119-
return True
120-
return False
119+
return sql == comparison_query and sql_params == other_params
121120

122121

123122
def _prepare_query_string_for_comparison(query_string: str) -> str:
@@ -135,7 +134,7 @@ def _get_last_query_infos(
135134
max_remote_cache_entries: int,
136135
boto3_session: boto3.Session | None = None,
137136
workgroup: str | None = None,
138-
) -> list[dict[str, Any]]:
137+
) -> list["QueryExecutionTypeDef"]:
139138
"""Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
140139
client_athena = _utils.client(service_name="athena", session=boto3_session)
141140
page_size = 50
@@ -160,14 +159,15 @@ def _get_last_query_infos(
160159
QueryExecutionIds=uncached_ids[i : i + page_size],
161160
).get("QueryExecutions")
162161
)
163-
_cache_manager.update_cache(new_execution_data) # type: ignore[arg-type]
162+
_cache_manager.update_cache(new_execution_data)
164163
return _cache_manager.sorted_successful_generator()
165164

166165

167166
def _check_for_cached_results(
168167
sql: str,
169168
boto3_session: boto3.Session | None,
170169
workgroup: str | None,
170+
params: list[str] | None = None,
171171
athena_cache_settings: typing.AthenaCacheSettings | None = None,
172172
) -> _CacheInfo:
173173
"""
@@ -207,15 +207,25 @@ def _check_for_cached_results(
207207
if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
208208
parsed_query: str | None = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
209209
if parsed_query is not None:
210-
if _compare_query_string(sql=comparable_sql, other=parsed_query):
210+
if _compare_query_string(
211+
sql=comparable_sql,
212+
other=parsed_query,
213+
sql_params=params,
214+
other_params=query_info.get("ExecutionParameters"),
215+
):
211216
return _CacheInfo(
212217
has_valid_cache=True,
213218
file_format="parquet",
214219
query_execution_id=query_execution_id,
215220
query_execution_payload=query_info,
216221
)
217222
elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
218-
if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
223+
if _compare_query_string(
224+
sql=comparable_sql,
225+
other=query_info["Query"],
226+
sql_params=params,
227+
other_params=query_info.get("ExecutionParameters"),
228+
):
219229
return _CacheInfo(
220230
has_valid_cache=True,
221231
file_format="csv",

awswrangler/athena/_read.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1048,6 +1048,7 @@ def read_sql_query(
10481048
if not client_request_token:
10491049
cache_info: _CacheInfo = _check_for_cached_results(
10501050
sql=sql,
1051+
params=params if paramstyle == "qmark" else None,
10511052
boto3_session=boto3_session,
10521053
workgroup=workgroup,
10531054
athena_cache_settings=athena_cache_settings,

awswrangler/athena/_utils.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
from ._cache import _cache_manager, _LocalMetadataCacheManager
3636

3737
if TYPE_CHECKING:
38+
from mypy_boto3_athena.type_defs import QueryExecutionTypeDef
3839
from mypy_boto3_glue.type_defs import ColumnOutputTypeDef
3940

4041
_QUERY_FINAL_STATES: list[str] = ["FAILED", "SUCCEEDED", "CANCELLED"]
@@ -53,7 +54,7 @@ class _QueryMetadata(NamedTuple):
5354
binaries: list[str]
5455
output_location: str | None
5556
manifest_location: str | None
56-
raw_payload: dict[str, Any]
57+
raw_payload: "QueryExecutionTypeDef"
5758

5859

5960
class _WorkGroupConfig(NamedTuple):
@@ -214,7 +215,7 @@ def _get_query_metadata(
214215
query_execution_id: str,
215216
boto3_session: boto3.Session | None = None,
216217
categories: list[str] | None = None,
217-
query_execution_payload: dict[str, Any] | None = None,
218+
query_execution_payload: "QueryExecutionTypeDef" | None = None,
218219
metadata_cache_manager: _LocalMetadataCacheManager | None = None,
219220
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
220221
execution_params: list[str] | None = None,
@@ -225,12 +226,15 @@ def _get_query_metadata(
225226
if query_execution_payload["Status"]["State"] != "SUCCEEDED":
226227
reason: str = query_execution_payload["Status"]["StateChangeReason"]
227228
raise exceptions.QueryFailed(f"Query error: {reason}")
228-
_query_execution_payload: dict[str, Any] = query_execution_payload
229+
_query_execution_payload = query_execution_payload
229230
else:
230-
_query_execution_payload = _executions.wait_query(
231-
query_execution_id=query_execution_id,
232-
boto3_session=boto3_session,
233-
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
231+
_query_execution_payload = cast(
232+
"QueryExecutionTypeDef",
233+
_executions.wait_query(
234+
query_execution_id=query_execution_id,
235+
boto3_session=boto3_session,
236+
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
237+
),
234238
)
235239
cols_types: dict[str, str] = get_query_columns_types(
236240
query_execution_id=query_execution_id, boto3_session=boto3_session
@@ -266,8 +270,8 @@ def _get_query_metadata(
266270
if "ResultConfiguration" in _query_execution_payload:
267271
output_location = _query_execution_payload["ResultConfiguration"].get("OutputLocation")
268272

269-
athena_statistics: dict[str, int | str] = _query_execution_payload.get("Statistics", {})
270-
manifest_location: str | None = str(athena_statistics.get("DataManifestLocation"))
273+
athena_statistics = _query_execution_payload.get("Statistics", {})
274+
manifest_location: str | None = athena_statistics.get("DataManifestLocation")
271275

272276
if metadata_cache_manager is not None and query_execution_id not in metadata_cache_manager:
273277
metadata_cache_manager.update_cache(items=[_query_execution_payload])

tests/unit/test_athena.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -461,6 +461,57 @@ def test_athena_paramstyle_qmark_parameters(
461461
assert len(df_out) == 1
462462

463463

464+
def test_athena_paramstyle_qmark_with_caching(
465+
path: str,
466+
path2: str,
467+
glue_database: str,
468+
glue_table: str,
469+
workgroup0: str,
470+
ctas_approach: bool,
471+
unload_approach: bool,
472+
) -> None:
473+
wr.s3.to_parquet(
474+
df=get_df(),
475+
path=path,
476+
index=False,
477+
dataset=True,
478+
mode="overwrite",
479+
database=glue_database,
480+
table=glue_table,
481+
partition_cols=["par0", "par1"],
482+
)
483+
484+
df_out = wr.athena.read_sql_query(
485+
sql=f"SELECT * FROM {glue_table} WHERE string = ?",
486+
database=glue_database,
487+
ctas_approach=ctas_approach,
488+
unload_approach=unload_approach,
489+
workgroup=workgroup0,
490+
params=["Washington"],
491+
paramstyle="qmark",
492+
keep_files=False,
493+
s3_output=path2,
494+
athena_cache_settings={"max_cache_seconds": 300}
495+
)
496+
497+
assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Washington"
498+
499+
df_out = wr.athena.read_sql_query(
500+
sql=f"SELECT * FROM {glue_table} WHERE string = ?",
501+
database=glue_database,
502+
ctas_approach=ctas_approach,
503+
unload_approach=unload_approach,
504+
workgroup=workgroup0,
505+
params=["Seattle"],
506+
paramstyle="qmark",
507+
keep_files=False,
508+
s3_output=path2,
509+
athena_cache_settings={"max_cache_seconds": 300}
510+
)
511+
512+
assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Seattle"
513+
514+
464515
def test_read_sql_query_parameter_formatting_respects_prefixes(path, glue_database, glue_table, workgroup0):
465516
wr.s3.to_parquet(
466517
df=get_df(),

0 commit comments

Comments
 (0)