Commit 42b46f6

Add cache to Athena start query (#1039)
* Add cache to Athena start query
* Fix tests

Co-authored-by: jaidisido <[email protected]>
1 parent 728a2a9 commit 42b46f6
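
In short: before starting a new Athena query, awswrangler now inspects the workgroup's recent query history and, if an equivalent query succeeded recently, reuses its results instead of re-executing. A hedged sketch of how a caller might exercise this from the public API (the exact `wr.athena.start_query_execution` signature after this commit is an assumption; the cache parameter names are taken from `_check_for_cached_results` in the new module below):

# Sketch only: parameter names mirror _check_for_cached_results below; treat
# the exact public signature as an assumption rather than documented API.
import awswrangler as wr

query_execution_id = wr.athena.start_query_execution(
    sql="SELECT col FROM my_table",   # hypothetical query/table
    workgroup="primary",
    max_cache_seconds=900,            # reuse results completed within 15 minutes
    max_cache_query_inspections=50,   # compare at most 50 historical queries
    max_remote_cache_entries=50,      # list at most 50 executions from Athena
)
# On a cache hit, this is the QueryExecutionId of the *previous* run.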

File tree

5 files changed: +337 -221 lines changed


awswrangler/athena/_cache.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
"""Cache Module for Amazon Athena."""
import datetime
import logging
import re
from heapq import heappop, heappush
from typing import Any, Dict, List, Match, NamedTuple, Optional, Tuple, Union

import boto3

from awswrangler import _utils

_logger: logging.Logger = logging.getLogger(__name__)


class _CacheInfo(NamedTuple):
    has_valid_cache: bool
    file_format: Optional[str] = None
    query_execution_id: Optional[str] = None
    query_execution_payload: Optional[Dict[str, Any]] = None


class _LocalMetadataCacheManager:
    def __init__(self) -> None:
        self._cache: Dict[str, Any] = {}
        self._pqueue: List[Tuple[datetime.datetime, str]] = []
        self._max_cache_size = 100

    def update_cache(self, items: List[Dict[str, Any]]) -> None:
        """
        Update the local metadata cache with new query metadata.

        Parameters
        ----------
        items : List[Dict[str, Any]]
            List of query execution metadata which is returned by boto3 `batch_get_query_execution()`.

        Returns
        -------
        None
            None.
        """
        if self._pqueue:
            oldest_item = self._cache[self._pqueue[0][1]]
            items = list(
                filter(lambda x: x["Status"]["SubmissionDateTime"] > oldest_item["Status"]["SubmissionDateTime"], items)
            )

        cache_oversize = len(self._cache) + len(items) - self._max_cache_size
        for _ in range(cache_oversize):
            _, query_execution_id = heappop(self._pqueue)
            del self._cache[query_execution_id]

        for item in items[: self._max_cache_size]:
            heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
            self._cache[item["QueryExecutionId"]] = item

    def sorted_successful_generator(self) -> List[Dict[str, Any]]:
        """
        Sorts the entries in the local cache based on query Completion DateTime.

        This is useful to guarantee LRU caching rules.

        Returns
        -------
        List[Dict[str, Any]]
            Returns successful DDL and DML queries sorted by query completion time.
        """
        filtered: List[Dict[str, Any]] = []
        for query in self._cache.values():
            if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
                filtered.append(query)
        return sorted(filtered, key=lambda e: str(e["Status"]["CompletionDateTime"]), reverse=True)

    def __contains__(self, key: str) -> bool:
        return key in self._cache

    @property
    def max_cache_size(self) -> int:
        """Property max_cache_size."""
        return self._max_cache_size

    @max_cache_size.setter
    def max_cache_size(self, value: int) -> None:
        self._max_cache_size = value


def _parse_select_query_from_possible_ctas(possible_ctas: str) -> Optional[str]:
    """Check if `possible_ctas` is a valid parquet-generating CTAS and returns the full SELECT statement."""
    possible_ctas = possible_ctas.lower()
    parquet_format_regex: str = r"format\s*=\s*\'parquet\'\s*,"
    is_parquet_format: Optional[Match[str]] = re.search(pattern=parquet_format_regex, string=possible_ctas)
    if is_parquet_format is not None:
        unstripped_select_statement_regex: str = r"\s+as\s+\(*(select|with).*"
        unstripped_select_statement_match: Optional[Match[str]] = re.search(
            unstripped_select_statement_regex, possible_ctas, re.DOTALL
        )
        if unstripped_select_statement_match is not None:
            stripped_select_statement_match: Optional[Match[str]] = re.search(
                r"(select|with).*", unstripped_select_statement_match.group(0), re.DOTALL
            )
            if stripped_select_statement_match is not None:
                return stripped_select_statement_match.group(0)
    return None


def _compare_query_string(sql: str, other: str) -> bool:
    comparison_query = _prepare_query_string_for_comparison(query_string=other)
    _logger.debug("sql: %s", sql)
    _logger.debug("comparison_query: %s", comparison_query)
    if sql == comparison_query:
        return True
    return False


def _prepare_query_string_for_comparison(query_string: str) -> str:
    """To use cached data, we need to compare queries. Returns a query string in canonical form."""
    # for now this is a simple complete strip, but it could grow into much more sophisticated
    # query comparison data structures
    query_string = "".join(query_string.split()).strip("()").lower()
    query_string = query_string[:-1] if query_string.endswith(";") else query_string
    return query_string


def _get_last_query_infos(
    max_remote_cache_entries: int,
    boto3_session: Optional[boto3.Session] = None,
    workgroup: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
    client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
    page_size = 50
    args: Dict[str, Union[str, Dict[str, int]]] = {
        "PaginationConfig": {"MaxItems": max_remote_cache_entries, "PageSize": page_size}
    }
    if workgroup is not None:
        args["WorkGroup"] = workgroup
    paginator = client_athena.get_paginator("list_query_executions")
    uncached_ids = []
    for page in paginator.paginate(**args):
        _logger.debug("paginating Athena's queries history...")
        query_execution_id_list: List[str] = page["QueryExecutionIds"]
        for query_execution_id in query_execution_id_list:
            if query_execution_id not in _cache_manager:
                uncached_ids.append(query_execution_id)
    if uncached_ids:
        new_execution_data = []
        for i in range(0, len(uncached_ids), page_size):
            new_execution_data.extend(
                client_athena.batch_get_query_execution(QueryExecutionIds=uncached_ids[i : i + page_size]).get(
                    "QueryExecutions"
                )
            )
        _cache_manager.update_cache(new_execution_data)
    return _cache_manager.sorted_successful_generator()


def _check_for_cached_results(
    sql: str,
    boto3_session: boto3.Session,
    workgroup: Optional[str],
    max_cache_seconds: int,
    max_cache_query_inspections: int,
    max_remote_cache_entries: int,
) -> _CacheInfo:
    """
    Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.

    If so, returns a dict with Athena's `query_execution_info` and the data format.
    """
    if max_cache_seconds <= 0:
        return _CacheInfo(has_valid_cache=False)
    num_executions_inspected: int = 0
    comparable_sql: str = _prepare_query_string_for_comparison(sql)
    current_timestamp: datetime.datetime = datetime.datetime.now(datetime.timezone.utc)
    _logger.debug("current_timestamp: %s", current_timestamp)
    for query_info in _get_last_query_infos(
        max_remote_cache_entries=max_remote_cache_entries,
        boto3_session=boto3_session,
        workgroup=workgroup,
    ):
        query_execution_id: str = query_info["QueryExecutionId"]
        query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
        _logger.debug("query_timestamp: %s", query_timestamp)
        if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
            return _CacheInfo(
                has_valid_cache=False, query_execution_id=query_execution_id, query_execution_payload=query_info
            )
        statement_type: Optional[str] = query_info.get("StatementType")
        if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
            parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
            if parsed_query is not None:
                if _compare_query_string(sql=comparable_sql, other=parsed_query):
                    return _CacheInfo(
                        has_valid_cache=True,
                        file_format="parquet",
                        query_execution_id=query_execution_id,
                        query_execution_payload=query_info,
                    )
        elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
            if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
                return _CacheInfo(
                    has_valid_cache=True,
                    file_format="csv",
                    query_execution_id=query_execution_id,
                    query_execution_payload=query_info,
                )
        num_executions_inspected += 1
        _logger.debug("num_executions_inspected: %s", num_executions_inspected)
        if num_executions_inspected >= max_cache_query_inspections:
            return _CacheInfo(has_valid_cache=False)
    return _CacheInfo(has_valid_cache=False)


_cache_manager = _LocalMetadataCacheManager()
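
To make the rules in this module concrete, here is a small standalone sketch (not part of the commit) that drives the canonicalizer, the CTAS parser, and the cache manager's size-capped eviction with fabricated metadata shaped like `batch_get_query_execution()` entries:

import datetime

from awswrangler.athena._cache import (
    _LocalMetadataCacheManager,
    _parse_select_query_from_possible_ctas,
    _prepare_query_string_for_comparison,
)

# Canonical form ignores whitespace, case, wrapping parentheses and a trailing ';'.
assert (
    _prepare_query_string_for_comparison("SELECT * FROM t;")
    == _prepare_query_string_for_comparison("(select  *\n from t)")
    == "select*fromt"
)

# Only a Parquet-producing CTAS yields a SELECT to compare against; the trailing
# ')' survives here and is stripped later by the canonicalizer.
ctas = "CREATE TABLE tmp WITH (format = 'parquet', external_location = 's3://bucket/path/') AS (SELECT col FROM t)"
print(_parse_select_query_from_possible_ctas(ctas))  # select col from t)

# Size-capped cache: once full, the oldest submission is evicted first.
mgr = _LocalMetadataCacheManager()
mgr.max_cache_size = 2
now = datetime.datetime.now(datetime.timezone.utc)

def _fake(exec_id: str, age_seconds: int) -> dict:
    # Fabricated metadata with the same shape as a batch_get_query_execution() entry.
    ts = now - datetime.timedelta(seconds=age_seconds)
    return {
        "QueryExecutionId": exec_id,
        "StatementType": "DML",
        "Query": "SELECT 1",
        "Status": {"State": "SUCCEEDED", "SubmissionDateTime": ts, "CompletionDateTime": ts},
    }

mgr.update_cache([_fake("q1", 30), _fake("q2", 20)])
mgr.update_cache([_fake("q3", 10)])  # evicts "q1", the oldest entry
print("q1" in mgr, "q3" in mgr)  # False True
print([q["QueryExecutionId"] for q in mgr.sorted_successful_generator()])  # ['q3', 'q2']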

awswrangler/athena/_read.py

Lines changed: 3 additions & 141 deletions
@@ -1,12 +1,10 @@
 """Amazon Athena Module gathering all read_sql_* function."""
 
 import csv
-import datetime
 import logging
-import re
 import sys
 import uuid
-from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 import boto3
 import botocore.exceptions
@@ -21,20 +19,14 @@
     _get_query_metadata,
     _get_s3_output,
     _get_workgroup_config,
-    _LocalMetadataCacheManager,
     _QueryMetadata,
     _start_query_execution,
     _WorkGroupConfig,
 )
 
-_logger: logging.Logger = logging.getLogger(__name__)
-
+from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results
 
-class _CacheInfo(NamedTuple):
-    has_valid_cache: bool
-    file_format: Optional[str] = None
-    query_execution_id: Optional[str] = None
-    query_execution_payload: Optional[Dict[str, Any]] = None
+_logger: logging.Logger = logging.getLogger(__name__)
 
 
 def _extract_ctas_manifest_paths(path: str, boto3_session: Optional[boto3.Session] = None) -> List[str]:
@@ -86,133 +78,6 @@ def _delete_after_iterate(
     )
 
 
-def _prepare_query_string_for_comparison(query_string: str) -> str:
-    """To use cached data, we need to compare queries. Returns a query string in canonical form."""
-    # for now this is a simple complete strip, but it could grow into much more sophisticated
-    # query comparison data structures
-    query_string = "".join(query_string.split()).strip("()").lower()
-    query_string = query_string[:-1] if query_string.endswith(";") else query_string
-    return query_string
-
-
-def _compare_query_string(sql: str, other: str) -> bool:
-    comparison_query = _prepare_query_string_for_comparison(query_string=other)
-    _logger.debug("sql: %s", sql)
-    _logger.debug("comparison_query: %s", comparison_query)
-    if sql == comparison_query:
-        return True
-    return False
-
-
-def _get_last_query_infos(
-    max_remote_cache_entries: int,
-    boto3_session: Optional[boto3.Session] = None,
-    workgroup: Optional[str] = None,
-) -> List[Dict[str, Any]]:
-    """Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
-    client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
-    page_size = 50
-    args: Dict[str, Union[str, Dict[str, int]]] = {
-        "PaginationConfig": {"MaxItems": max_remote_cache_entries, "PageSize": page_size}
-    }
-    if workgroup is not None:
-        args["WorkGroup"] = workgroup
-    paginator = client_athena.get_paginator("list_query_executions")
-    uncached_ids = []
-    for page in paginator.paginate(**args):
-        _logger.debug("paginating Athena's queries history...")
-        query_execution_id_list: List[str] = page["QueryExecutionIds"]
-        for query_execution_id in query_execution_id_list:
-            if query_execution_id not in _cache_manager:
-                uncached_ids.append(query_execution_id)
-    if uncached_ids:
-        new_execution_data = []
-        for i in range(0, len(uncached_ids), page_size):
-            new_execution_data.extend(
-                client_athena.batch_get_query_execution(QueryExecutionIds=uncached_ids[i : i + page_size]).get(
-                    "QueryExecutions"
-                )
-            )
-        _cache_manager.update_cache(new_execution_data)
-    return _cache_manager.sorted_successful_generator()
-
-
-def _parse_select_query_from_possible_ctas(possible_ctas: str) -> Optional[str]:
-    """Check if `possible_ctas` is a valid parquet-generating CTAS and returns the full SELECT statement."""
-    possible_ctas = possible_ctas.lower()
-    parquet_format_regex: str = r"format\s*=\s*\'parquet\'\s*,"
-    is_parquet_format: Optional[Match[str]] = re.search(pattern=parquet_format_regex, string=possible_ctas)
-    if is_parquet_format is not None:
-        unstripped_select_statement_regex: str = r"\s+as\s+\(*(select|with).*"
-        unstripped_select_statement_match: Optional[Match[str]] = re.search(
-            unstripped_select_statement_regex, possible_ctas, re.DOTALL
-        )
-        if unstripped_select_statement_match is not None:
-            stripped_select_statement_match: Optional[Match[str]] = re.search(
-                r"(select|with).*", unstripped_select_statement_match.group(0), re.DOTALL
-            )
-            if stripped_select_statement_match is not None:
-                return stripped_select_statement_match.group(0)
-    return None
-
-
-def _check_for_cached_results(
-    sql: str,
-    boto3_session: boto3.Session,
-    workgroup: Optional[str],
-    max_cache_seconds: int,
-    max_cache_query_inspections: int,
-    max_remote_cache_entries: int,
-) -> _CacheInfo:
-    """
-    Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.
-
-    If so, returns a dict with Athena's `query_execution_info` and the data format.
-    """
-    if max_cache_seconds <= 0:
-        return _CacheInfo(has_valid_cache=False)
-    num_executions_inspected: int = 0
-    comparable_sql: str = _prepare_query_string_for_comparison(sql)
-    current_timestamp: datetime.datetime = datetime.datetime.now(datetime.timezone.utc)
-    _logger.debug("current_timestamp: %s", current_timestamp)
-    for query_info in _get_last_query_infos(
-        max_remote_cache_entries=max_remote_cache_entries,
-        boto3_session=boto3_session,
-        workgroup=workgroup,
-    ):
-        query_execution_id: str = query_info["QueryExecutionId"]
-        query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
-        _logger.debug("query_timestamp: %s", query_timestamp)
-        if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
-            return _CacheInfo(
-                has_valid_cache=False, query_execution_id=query_execution_id, query_execution_payload=query_info
-            )
-        statement_type: Optional[str] = query_info.get("StatementType")
-        if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
-            parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
-            if parsed_query is not None:
-                if _compare_query_string(sql=comparable_sql, other=parsed_query):
-                    return _CacheInfo(
-                        has_valid_cache=True,
-                        file_format="parquet",
-                        query_execution_id=query_execution_id,
-                        query_execution_payload=query_info,
-                    )
-        elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
-            if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
-                return _CacheInfo(
-                    has_valid_cache=True,
-                    file_format="csv",
-                    query_execution_id=query_execution_id,
-                    query_execution_payload=query_info,
-                )
-        num_executions_inspected += 1
-        _logger.debug("num_executions_inspected: %s", num_executions_inspected)
-        if num_executions_inspected >= max_cache_query_inspections:
-            return _CacheInfo(has_valid_cache=False)
-    return _CacheInfo(has_valid_cache=False)
-
-
 def _fetch_parquet_result(
     query_metadata: _QueryMetadata,
     keep_files: bool,
@@ -1114,6 +979,3 @@ def read_sql_table(
         s3_additional_kwargs=s3_additional_kwargs,
         pyarrow_additional_kwargs=pyarrow_additional_kwargs,
     )
-
-
-_cache_manager = _LocalMetadataCacheManager()
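
On the consuming side, `_read.py` (and, per the commit title, the start-query path) now imports these helpers instead of defining them locally. A hedged sketch of the caller-side pattern, assuming only the helper names above; the wrapper function below is hypothetical and not part of this diff:

import boto3

from awswrangler.athena._cache import _check_for_cached_results

def _start_or_reuse(sql: str, workgroup: str, session: boto3.Session) -> str:
    # Hypothetical wrapper: consult the cache before starting a new execution.
    cache_info = _check_for_cached_results(
        sql=sql,
        boto3_session=session,
        workgroup=workgroup,
        max_cache_seconds=900,
        max_cache_query_inspections=50,
        max_remote_cache_entries=50,
    )
    if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
        # Cache hit: the earlier execution's results (CSV or Parquet, per
        # cache_info.file_format) are still available on S3.
        return cache_info.query_execution_id
    # Cache miss: fall through to actually submitting the query (elided).
    raise NotImplementedError("submit a new Athena query execution here")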
