Skip to content

Commit dad9922

Browse files
authored
fix: KeyError & add lock to athena cache manager (#2299)
* Fix KeyError & add lock to athena cache manager * [skip ci] Mypy fix
1 parent bceb9cd commit dad9922

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

awswrangler/athena/_cache.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import datetime
33
import logging
44
import re
5+
import threading
56
from heapq import heappop, heappush
67
from typing import TYPE_CHECKING, Any, Dict, List, Match, NamedTuple, Optional, Tuple, Union
78

@@ -24,6 +25,7 @@ class _CacheInfo(NamedTuple):
2425

2526
class _LocalMetadataCacheManager:
2627
def __init__(self) -> None:
28+
self._lock: threading.Lock = threading.Lock()
2729
self._cache: Dict[str, Any] = {}
2830
self._pqueue: List[Tuple[datetime.datetime, str]] = []
2931
self._max_cache_size = 100
@@ -42,20 +44,25 @@ def update_cache(self, items: List[Dict[str, Any]]) -> None:
4244
None
4345
None.
4446
"""
45-
if self._pqueue:
46-
oldest_item = self._cache[self._pqueue[0][1]]
47-
items = list(
48-
filter(lambda x: x["Status"]["SubmissionDateTime"] > oldest_item["Status"]["SubmissionDateTime"], items)
49-
)
47+
with self._lock:
48+
if self._pqueue:
49+
oldest_item = self._cache.get(self._pqueue[0][1])
50+
if oldest_item:
51+
items = list(
52+
filter(
53+
lambda x: x["Status"]["SubmissionDateTime"] > oldest_item["Status"]["SubmissionDateTime"], # type: ignore[arg-type]
54+
items,
55+
)
56+
)
5057

51-
cache_oversize = len(self._cache) + len(items) - self._max_cache_size
52-
for _ in range(cache_oversize):
53-
_, query_execution_id = heappop(self._pqueue)
54-
del self._cache[query_execution_id]
58+
cache_oversize = len(self._cache) + len(items) - self._max_cache_size
59+
for _ in range(cache_oversize):
60+
_, query_execution_id = heappop(self._pqueue)
61+
del self._cache[query_execution_id]
5562

56-
for item in items[: self._max_cache_size]:
57-
heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
58-
self._cache[item["QueryExecutionId"]] = item
63+
for item in items[: self._max_cache_size]:
64+
heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
65+
self._cache[item["QueryExecutionId"]] = item
5966

6067
def sorted_successful_generator(self) -> List[Dict[str, Any]]:
6168
"""

0 commit comments

Comments
 (0)