Skip to content

Commit 19f06df

Browse files
committed
Add max_cache_query_inspections to Athena cache.
1 parent c57632c commit 19f06df

File tree

1 file changed

+50
-37
lines changed

awswrangler/athena.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Amazon Athena Module."""
22

33
import csv
4+
import datetime
45
import logging
56
import pprint
67
import re
78
import time
8-
from datetime import datetime, timezone
99
from decimal import Decimal
1010
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
1111

@@ -18,7 +18,6 @@
1818
_logger: logging.Logger = logging.getLogger(__name__)
1919

2020
_QUERY_WAIT_POLLING_DELAY: float = 0.2 # SECONDS
21-
_CACHE_PREVIOUS_QUERY_COUNT: int = 50 # number of past queries to scan for cached results
2221

2322

2423
def get_query_columns_types(query_execution_id: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, str]:
@@ -391,6 +390,7 @@ def read_sql_query(
391390
use_threads: bool = True,
392391
boto3_session: Optional[boto3.Session] = None,
393392
max_cache_seconds: int = 0,
393+
max_cache_query_inspections: int = 50,
394394
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
395395
"""Execute any SQL query on AWS Athena and return the results as a Pandas DataFrame.
396396
@@ -474,13 +474,17 @@ def read_sql_query(
474474
If enabled os.cpu_count() will be used as the max number of threads.
475475
boto3_session : boto3.Session(), optional
476476
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
477-
max_cache_seconds: int
477+
max_cache_seconds : int
478478
Wrangler can look up in Athena's history if this query has been run before.
479479
If so, and its completion time is less than `max_cache_seconds` before now, wrangler
480480
skips query execution and just returns the same results as last time.
481481
If cached results are valid, wrangler ignores the `ctas_approach`, `s3_output`, `encryption`, `kms_key`,
482482
`keep_files` and `ctas_temp_table_name` params.
483483
If reading cached data fails for any reason, execution falls back to the usual query run path.
484+
max_cache_query_inspections : int
485+
Max number of queries that will be inspected from the history to try to find some result to reuse.
486+
The bigger the number of inspections, the bigger will be the latency for non-cached queries.
487+
Only takes effect if max_cache_seconds > 0.
484488
485489
Returns
486490
-------
@@ -497,7 +501,11 @@ def read_sql_query(
497501

498502
# check for cached results
499503
cache_info: Dict[str, Any] = _check_for_cached_results(
500-
sql=sql, session=session, workgroup=workgroup, max_cache_seconds=max_cache_seconds
504+
sql=sql,
505+
session=session,
506+
workgroup=workgroup,
507+
max_cache_seconds=max_cache_seconds,
508+
max_cache_query_inspections=max_cache_query_inspections,
501509
)
502510

503511
if cache_info["has_valid_cache"] is True:
@@ -961,17 +969,17 @@ def _prepare_query_string_for_comparison(query_string: str) -> str:
961969

962970
def _get_last_query_executions(
963971
boto3_session: Optional[boto3.Session] = None, workgroup: Optional[str] = None
964-
) -> List[Dict[str, Any]]:
965-
"""Return the last 50 `query_execution_info`s run by the workgroup in Athena."""
972+
) -> Iterator[List[Dict[str, Any]]]:
973+
"""Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
966974
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
967-
968-
args: Dict[str, Any] = {"MaxResults": _CACHE_PREVIOUS_QUERY_COUNT}
969-
if workgroup:
975+
args: Dict[str, str] = {}
976+
if workgroup is not None:
970977
args["WorkGroup"] = workgroup
971-
query_execution_list: Dict[str, Any] = client_athena.list_query_executions(**args)
972-
query_execution_id_list: List[str] = query_execution_list["QueryExecutionIds"]
973-
execution_data = client_athena.batch_get_query_execution(QueryExecutionIds=query_execution_id_list)
974-
return execution_data.get("QueryExecutions")
978+
paginator = client_athena.get_paginator("get_query_results")
979+
for page in paginator.paginate(**args):
980+
query_execution_id_list: List[str] = page["QueryExecutionIds"]
981+
execution_data = client_athena.batch_get_query_execution(QueryExecutionIds=query_execution_id_list)
982+
yield execution_data.get("QueryExecutions")
975983

976984

977985
def _sort_successful_executions_data(query_executions: List[Dict[str, Any]]):
@@ -1002,37 +1010,42 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> Optional[str]:
10021010

10031011

10041012
def _check_for_cached_results(
1005-
sql: str, session: boto3.Session, workgroup: Optional[str], max_cache_seconds: int
1013+
sql: str, session: boto3.Session, workgroup: Optional[str], max_cache_seconds: int, max_cache_query_inspections: int
10061014
) -> Dict[str, Any]:
10071015
"""
1008-
Check wether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.
1016+
Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.
10091017
10101018
If so, returns a dict with Athena's `query_execution_info` and the data format.
10111019
"""
1012-
if max_cache_seconds > 0:
1013-
last_query_executions = _get_last_query_executions(boto3_session=session, workgroup=workgroup)
1014-
cached_queries = _sort_successful_executions_data(query_executions=last_query_executions)
1015-
current_timestamp = datetime.now(timezone.utc)
1016-
comparable_sql: str = _prepare_query_string_for_comparison(sql)
1017-
1018-
# this could be mapreduced, but it is only 50 items long, tops
1019-
for query_info in cached_queries:
1020-
if (current_timestamp - query_info["Status"]["CompletionDateTime"]).total_seconds() > max_cache_seconds:
1021-
break # pragma: no cover
1022-
1023-
comparison_query: Optional[str]
1024-
if query_info["StatementType"] == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
1025-
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(query_info["Query"])
1026-
if parsed_query is not None:
1027-
comparison_query = _prepare_query_string_for_comparison(query_string=parsed_query)
1020+
num_executions_inspected: int = 0
1021+
if max_cache_seconds > 0: # pylint: disable=too-many-nested-blocks
1022+
for query_executions in _get_last_query_executions(boto3_session=session, workgroup=workgroup):
1023+
cached_queries: List[Dict[str, Any]] = _sort_successful_executions_data(query_executions=query_executions)
1024+
current_timestamp = datetime.datetime.utcnow()
1025+
comparable_sql: str = _prepare_query_string_for_comparison(sql)
1026+
1027+
# this could be mapreduced, but it is only 50 items long, tops
1028+
for query_info in cached_queries:
1029+
if (current_timestamp - query_info["Status"]["CompletionDateTime"]).total_seconds() > max_cache_seconds:
1030+
break # pragma: no cover
1031+
1032+
comparison_query: Optional[str]
1033+
if query_info["StatementType"] == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
1034+
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(query_info["Query"])
1035+
if parsed_query is not None:
1036+
comparison_query = _prepare_query_string_for_comparison(query_string=parsed_query)
1037+
if comparison_query == comparable_sql:
1038+
data_type = "parquet"
1039+
return {"has_valid_cache": True, "data_type": data_type, "query_execution_info": query_info}
1040+
1041+
elif query_info["StatementType"] == "DML" and not query_info["Query"].startswith("INSERT"):
1042+
comparison_query = _prepare_query_string_for_comparison(query_string=query_info["Query"])
10281043
if comparison_query == comparable_sql:
1029-
data_type = "parquet"
1044+
data_type = "csv"
10301045
return {"has_valid_cache": True, "data_type": data_type, "query_execution_info": query_info}
10311046

1032-
elif query_info["StatementType"] == "DML" and not query_info["Query"].startswith("INSERT"):
1033-
comparison_query = _prepare_query_string_for_comparison(query_string=query_info["Query"])
1034-
if comparison_query == comparable_sql:
1035-
data_type = "csv"
1036-
return {"has_valid_cache": True, "data_type": data_type, "query_execution_info": query_info}
1047+
num_executions_inspected += 1
1048+
if num_executions_inspected >= max_cache_query_inspections:
1049+
break
10371050

10381051
return {"has_valid_cache": False}

0 commit comments

Comments (0)