Skip to content

Commit 8cb0b79

Browse files
authored
Add local metadata caching (#504)
1 parent e1d7570 commit 8cb0b79

File tree

6 files changed

+268
-85
lines changed

6 files changed

+268
-85
lines changed

awswrangler/_config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class _ConfigArg(NamedTuple):
2929
"database": _ConfigArg(dtype=str, nullable=True),
3030
"max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False),
3131
"max_cache_seconds": _ConfigArg(dtype=int, nullable=False),
32+
"max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False),
33+
"max_local_cache_entries": _ConfigArg(dtype=int, nullable=False),
3234
"s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True),
3335
"workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True),
3436
# Endpoints URLs
@@ -226,6 +228,35 @@ def max_cache_seconds(self) -> int:
226228
def max_cache_seconds(self, value: int) -> None:
227229
self._set_config_value(key="max_cache_seconds", value=value)
228230

231+
@property
232+
def max_local_cache_entries(self) -> int:
233+
"""Property max_local_cache_entries."""
234+
return cast(int, self["max_local_cache_entries"])
235+
236+
@max_local_cache_entries.setter
237+
def max_local_cache_entries(self, value: int) -> None:
238+
try:
239+
max_remote_cache_entries = cast(int, self["max_remote_cache_entries"])
240+
except AttributeError:
241+
max_remote_cache_entries = 50
242+
if value < max_remote_cache_entries:
243+
_logger.warning(
244+
"max_remote_cache_entries shouldn't be greater than max_local_cache_entries. "
245+
"Therefore max_remote_cache_entries will be set to %s as well.",
246+
value,
247+
)
248+
self._set_config_value(key="max_remote_cache_entries", value=value)
249+
self._set_config_value(key="max_local_cache_entries", value=value)
250+
251+
@property
252+
def max_remote_cache_entries(self) -> int:
253+
"""Property max_remote_cache_entries."""
254+
return cast(int, self["max_remote_cache_entries"])
255+
256+
@max_remote_cache_entries.setter
257+
def max_remote_cache_entries(self, value: int) -> None:
258+
self._set_config_value(key="max_remote_cache_entries", value=value)
259+
229260
@property
230261
def s3_block_size(self) -> int:
231262
"""Property s3_block_size."""

awswrangler/athena/_read.py

Lines changed: 90 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
_get_query_metadata,
2121
_get_s3_output,
2222
_get_workgroup_config,
23+
_LocalMetadataCacheManager,
2324
_QueryMetadata,
2425
_start_query_execution,
2526
_WorkGroupConfig,
@@ -96,33 +97,37 @@ def _compare_query_string(sql: str, other: str) -> bool:
9697
return False
9798

9899

99-
def _get_last_query_executions(
100-
boto3_session: Optional[boto3.Session] = None, workgroup: Optional[str] = None
101-
) -> Iterator[List[Dict[str, Any]]]:
100+
def _get_last_query_infos(
101+
max_remote_cache_entries: int,
102+
boto3_session: Optional[boto3.Session] = None,
103+
workgroup: Optional[str] = None,
104+
) -> List[Dict[str, Any]]:
102105
"""Return a list of `query_execution_info`s run by the workgroup in Athena."""
103106
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
104-
args: Dict[str, Union[str, Dict[str, int]]] = {"PaginationConfig": {"MaxItems": 50, "PageSize": 50}}
107+
page_size = 50
108+
args: Dict[str, Union[str, Dict[str, int]]] = {
109+
"PaginationConfig": {"MaxItems": max_remote_cache_entries, "PageSize": page_size}
110+
}
105111
if workgroup is not None:
106112
args["WorkGroup"] = workgroup
107113
paginator = client_athena.get_paginator("list_query_executions")
114+
uncached_ids = []
108115
for page in paginator.paginate(**args):
109116
_logger.debug("paginating Athena's queries history...")
110117
query_execution_id_list: List[str] = page["QueryExecutionIds"]
111-
execution_data = client_athena.batch_get_query_execution(QueryExecutionIds=query_execution_id_list)
112-
yield execution_data.get("QueryExecutions")
113-
114-
115-
def _sort_successful_executions_data(query_executions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
116-
"""
117-
Sorts `_get_last_query_executions`'s results based on query Completion DateTime.
118-
119-
This is useful to guarantee LRU caching rules.
120-
"""
121-
filtered: List[Dict[str, Any]] = []
122-
for query in query_executions:
123-
if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
124-
filtered.append(query)
125-
return sorted(filtered, key=lambda e: str(e["Status"]["CompletionDateTime"]), reverse=True)
118+
for query_execution_id in query_execution_id_list:
119+
if query_execution_id not in _cache_manager:
120+
uncached_ids.append(query_execution_id)
121+
if uncached_ids:
122+
new_execution_data = []
123+
for i in range(0, len(uncached_ids), page_size):
124+
new_execution_data.extend(
125+
client_athena.batch_get_query_execution(QueryExecutionIds=uncached_ids[i : i + page_size]).get(
126+
"QueryExecutions"
127+
)
128+
)
129+
_cache_manager.update_cache(new_execution_data)
130+
return _cache_manager.sorted_successful_generator()
126131

127132

128133
def _parse_select_query_from_possible_ctas(possible_ctas: str) -> Optional[str]:
@@ -150,6 +155,7 @@ def _check_for_cached_results(
150155
workgroup: Optional[str],
151156
max_cache_seconds: int,
152157
max_cache_query_inspections: int,
158+
max_remote_cache_entries: int,
153159
) -> _CacheInfo:
154160
"""
155161
Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.
@@ -162,45 +168,41 @@ def _check_for_cached_results(
162168
comparable_sql: str = _prepare_query_string_for_comparison(sql)
163169
current_timestamp: datetime.datetime = datetime.datetime.now(datetime.timezone.utc)
164170
_logger.debug("current_timestamp: %s", current_timestamp)
165-
for query_executions in _get_last_query_executions(boto3_session=boto3_session, workgroup=workgroup):
166-
_logger.debug("len(query_executions): %s", len(query_executions))
167-
cached_queries: List[Dict[str, Any]] = _sort_successful_executions_data(query_executions=query_executions)
168-
_logger.debug("len(cached_queries): %s", len(cached_queries))
169-
for query_info in cached_queries:
170-
query_execution_id: str = query_info["QueryExecutionId"]
171-
query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
172-
_logger.debug("query_timestamp: %s", query_timestamp)
173-
174-
if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
175-
return _CacheInfo(
176-
has_valid_cache=False, query_execution_id=query_execution_id, query_execution_payload=query_info
177-
)
178-
179-
statement_type: Optional[str] = query_info.get("StatementType")
180-
if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
181-
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
182-
if parsed_query is not None:
183-
if _compare_query_string(sql=comparable_sql, other=parsed_query):
184-
return _CacheInfo(
185-
has_valid_cache=True,
186-
file_format="parquet",
187-
query_execution_id=query_execution_id,
188-
query_execution_payload=query_info,
189-
)
190-
elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
191-
if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
171+
for query_info in _get_last_query_infos(
172+
max_remote_cache_entries=max_remote_cache_entries,
173+
boto3_session=boto3_session,
174+
workgroup=workgroup,
175+
):
176+
query_execution_id: str = query_info["QueryExecutionId"]
177+
query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
178+
_logger.debug("query_timestamp: %s", query_timestamp)
179+
if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
180+
return _CacheInfo(
181+
has_valid_cache=False, query_execution_id=query_execution_id, query_execution_payload=query_info
182+
)
183+
statement_type: Optional[str] = query_info.get("StatementType")
184+
if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
185+
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
186+
if parsed_query is not None:
187+
if _compare_query_string(sql=comparable_sql, other=parsed_query):
192188
return _CacheInfo(
193189
has_valid_cache=True,
194-
file_format="csv",
190+
file_format="parquet",
195191
query_execution_id=query_execution_id,
196192
query_execution_payload=query_info,
197193
)
198-
199-
num_executions_inspected += 1
200-
_logger.debug("num_executions_inspected: %s", num_executions_inspected)
201-
if num_executions_inspected >= max_cache_query_inspections:
202-
return _CacheInfo(has_valid_cache=False)
203-
194+
elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
195+
if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
196+
return _CacheInfo(
197+
has_valid_cache=True,
198+
file_format="csv",
199+
query_execution_id=query_execution_id,
200+
query_execution_payload=query_info,
201+
)
202+
num_executions_inspected += 1
203+
_logger.debug("num_executions_inspected: %s", num_executions_inspected)
204+
if num_executions_inspected >= max_cache_query_inspections:
205+
return _CacheInfo(has_valid_cache=False)
204206
return _CacheInfo(has_valid_cache=False)
205207

206208

@@ -302,6 +304,7 @@ def _resolve_query_with_cache(
302304
boto3_session=session,
303305
categories=categories,
304306
query_execution_payload=cache_info.query_execution_payload,
307+
metadata_cache_manager=_cache_manager,
305308
)
306309
if cache_info.file_format == "parquet":
307310
return _fetch_parquet_result(
@@ -380,6 +383,7 @@ def _resolve_query_without_cache_ctas(
380383
query_execution_id=query_id,
381384
boto3_session=boto3_session,
382385
categories=categories,
386+
metadata_cache_manager=_cache_manager,
383387
)
384388
except exceptions.QueryFailed as ex:
385389
msg: str = str(ex)
@@ -439,6 +443,7 @@ def _resolve_query_without_cache_regular(
439443
query_execution_id=query_id,
440444
boto3_session=boto3_session,
441445
categories=categories,
446+
metadata_cache_manager=_cache_manager,
442447
)
443448
return _fetch_csv_result(
444449
query_metadata=query_metadata,
@@ -532,6 +537,8 @@ def read_sql_query(
532537
boto3_session: Optional[boto3.Session] = None,
533538
max_cache_seconds: int = 0,
534539
max_cache_query_inspections: int = 50,
540+
max_remote_cache_entries: int = 50,
541+
max_local_cache_entries: int = 100,
535542
data_source: Optional[str] = None,
536543
params: Optional[Dict[str, Any]] = None,
537544
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
@@ -678,6 +685,15 @@ def read_sql_query(
678685
Max number of queries that will be inspected from the history to try to find some result to reuse.
679686
The larger the number of inspections, the greater the latency for non-cached queries.
680687
Only takes effect if max_cache_seconds > 0.
688+
max_remote_cache_entries : int
689+
Max number of queries that will be retrieved from AWS for cache inspection.
690+
The larger the number of retrieved entries, the greater the latency for non-cached queries.
691+
Only takes effect if max_cache_seconds > 0. The default value is 50.
692+
max_local_cache_entries : int
693+
Max number of queries for which metadata will be cached locally. This will reduce the latency and also
694+
enables keeping more than `max_remote_cache_entries` available for the cache. This value should not be
695+
smaller than max_remote_cache_entries.
696+
Only takes effect if max_cache_seconds > 0. The default value is 100.
681697
data_source : str, optional
682698
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
683699
params: Dict[str, any], optional
@@ -718,12 +734,17 @@ def read_sql_query(
718734
for key, value in params.items():
719735
sql = sql.replace(f":{key};", str(value))
720736

737+
if max_remote_cache_entries > max_local_cache_entries:
738+
max_remote_cache_entries = max_local_cache_entries
739+
740+
_cache_manager.max_cache_size = max_local_cache_entries
721741
cache_info: _CacheInfo = _check_for_cached_results(
722742
sql=sql,
723743
boto3_session=session,
724744
workgroup=workgroup,
725745
max_cache_seconds=max_cache_seconds,
726746
max_cache_query_inspections=max_cache_query_inspections,
747+
max_remote_cache_entries=max_remote_cache_entries,
727748
)
728749
_logger.debug("cache_info:\n%s", cache_info)
729750
if cache_info.has_valid_cache is True:
@@ -774,6 +795,8 @@ def read_sql_table(
774795
boto3_session: Optional[boto3.Session] = None,
775796
max_cache_seconds: int = 0,
776797
max_cache_query_inspections: int = 50,
798+
max_remote_cache_entries: int = 50,
799+
max_local_cache_entries: int = 100,
777800
data_source: Optional[str] = None,
778801
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
779802
"""Extract the full table AWS Athena and return the results as a Pandas DataFrame.
@@ -914,6 +937,15 @@ def read_sql_table(
914937
Max number of queries that will be inspected from the history to try to find some result to reuse.
915938
The larger the number of inspections, the greater the latency for non-cached queries.
916939
Only takes effect if max_cache_seconds > 0.
940+
max_remote_cache_entries : int
941+
Max number of queries that will be retrieved from AWS for cache inspection.
942+
The larger the number of retrieved entries, the greater the latency for non-cached queries.
943+
Only takes effect if max_cache_seconds > 0. The default value is 50.
944+
max_local_cache_entries : int
945+
Max number of queries for which metadata will be cached locally. This will reduce the latency and also
946+
enables keeping more than `max_remote_cache_entries` available for the cache. This value should not be
947+
smaller than max_remote_cache_entries.
948+
Only takes effect if max_cache_seconds > 0. The default value is 100.
917949
data_source : str, optional
918950
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
919951
@@ -947,4 +979,9 @@ def read_sql_table(
947979
boto3_session=boto3_session,
948980
max_cache_seconds=max_cache_seconds,
949981
max_cache_query_inspections=max_cache_query_inspections,
982+
max_remote_cache_entries=max_remote_cache_entries,
983+
max_local_cache_entries=max_local_cache_entries,
950984
)
985+
986+
987+
_cache_manager = _LocalMetadataCacheManager()

awswrangler/athena/_utils.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Utilities Module for Amazon Athena."""
22
import csv
3+
import datetime
34
import logging
45
import pprint
56
import time
67
import warnings
78
from decimal import Decimal
8-
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Union, cast
9+
from heapq import heappop, heappush
10+
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union, cast
911

1012
import boto3
1113
import botocore.exceptions
@@ -39,6 +41,71 @@ class _WorkGroupConfig(NamedTuple):
3941
kms_key: Optional[str]
4042

4143

44+
class _LocalMetadataCacheManager:
45+
def __init__(self) -> None:
46+
self._cache: Dict[str, Any] = dict()
47+
self._pqueue: List[Tuple[datetime.datetime, str]] = []
48+
self._max_cache_size = 100
49+
50+
def update_cache(self, items: List[Dict[str, Any]]) -> None:
51+
"""
52+
Update the local metadata cache with new query metadata.
53+
54+
Parameters
55+
----------
56+
items : List[Dict[str, Any]]
57+
List of query execution metadata which is returned by boto3 `batch_get_query_execution()`.
58+
59+
Returns
60+
-------
61+
None
62+
None.
63+
"""
64+
if self._pqueue:
65+
oldest_item = self._cache[self._pqueue[0][1]]
66+
items = list(
67+
filter(lambda x: x["Status"]["SubmissionDateTime"] > oldest_item["Status"]["SubmissionDateTime"], items)
68+
)
69+
70+
cache_oversize = len(self._cache) + len(items) - self._max_cache_size
71+
for _ in range(cache_oversize):
72+
_, query_execution_id = heappop(self._pqueue)
73+
del self._cache[query_execution_id]
74+
75+
for item in items[: self._max_cache_size]:
76+
heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
77+
self._cache[item["QueryExecutionId"]] = item
78+
79+
def sorted_successful_generator(self) -> List[Dict[str, Any]]:
80+
"""
81+
Sorts the entries in the local cache based on query Completion DateTime.
82+
83+
This is useful to guarantee LRU caching rules.
84+
85+
Returns
86+
-------
87+
List[Dict[str, Any]]
88+
Returns successful DDL and DML queries sorted by query completion time.
89+
"""
90+
filtered: List[Dict[str, Any]] = []
91+
for query in self._cache.values():
92+
if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
93+
filtered.append(query)
94+
return sorted(filtered, key=lambda e: str(e["Status"]["CompletionDateTime"]), reverse=True)
95+
96+
def __contains__(self, key: str) -> bool:
97+
return key in self._cache
98+
99+
@property
100+
def max_cache_size(self) -> int:
101+
"""Property max_cache_size."""
102+
return self._max_cache_size
103+
104+
@max_cache_size.setter
105+
def max_cache_size(self, value: int) -> None:
106+
self._max_cache_size = value
107+
108+
42109
def _get_s3_output(s3_output: Optional[str], wg_config: _WorkGroupConfig, boto3_session: boto3.Session) -> str:
43110
if wg_config.enforced and wg_config.s3_output is not None:
44111
return wg_config.s3_output
@@ -171,6 +238,7 @@ def _get_query_metadata( # pylint: disable=too-many-statements
171238
boto3_session: boto3.Session,
172239
categories: Optional[List[str]] = None,
173240
query_execution_payload: Optional[Dict[str, Any]] = None,
241+
metadata_cache_manager: Optional[_LocalMetadataCacheManager] = None,
174242
) -> _QueryMetadata:
175243
"""Get query metadata."""
176244
if (query_execution_payload is not None) and (query_execution_payload["Status"]["State"] in _QUERY_FINAL_STATES):
@@ -224,6 +292,8 @@ def _get_query_metadata( # pylint: disable=too-many-statements
224292
athena_statistics: Dict[str, Union[int, str]] = _query_execution_payload.get("Statistics", {})
225293
manifest_location: Optional[str] = str(athena_statistics.get("DataManifestLocation"))
226294

295+
if metadata_cache_manager is not None and query_execution_id not in metadata_cache_manager:
296+
metadata_cache_manager.update_cache(items=[_query_execution_payload])
227297
query_metadata: _QueryMetadata = _QueryMetadata(
228298
execution_id=query_execution_id,
229299
dtype=dtype,

0 commit comments

Comments
 (0)