|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import contextlib |
4 | | -import dbm |
5 | 3 | import json |
6 | 4 | import logging |
| 5 | +import os |
| 6 | +import sqlite3 |
7 | 7 | import string |
8 | 8 | import threading |
9 | 9 | import time |
10 | 10 | import typing as T |
| 11 | +from functools import wraps |
11 | 12 | from pathlib import Path |
12 | 13 |
|
13 | | -# dbm modules are dynamically imported, so here we explicitly import dbm.sqlite3 to make sure pyinstaller include it |
14 | | -# Otherwise you will see: ImportError: no dbm clone found; tried ['dbm.sqlite3', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb'] |
15 | | -try: |
16 | | - import dbm.sqlite3 # type: ignore |
17 | | -except ImportError: |
18 | | - pass |
19 | | - |
20 | | - |
21 | | -from . import constants, types |
| 14 | +from . import constants, store, types |
22 | 15 | from .serializer.description import DescriptionJSONSerializer |
23 | 16 |
|
24 | 17 | JSONDict = T.Dict[str, T.Union[str, int, float, None]] |
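
The deleted comment block records a real packaging pitfall: dbm.open() picks its backend dynamically at runtime, so bundlers that only trace static imports miss dbm.sqlite3 (new in Python 3.13) and fail with the quoted ImportError. With this commit importing sqlite3 directly, the workaround can go. For projects that stay on dbm, the pattern it used looks like this (a sketch for reference, not part of this commit):

    # Static reference so a bundler such as PyInstaller packages the
    # backend that dbm.open() would otherwise import only at runtime.
    try:
        import dbm.sqlite3  # noqa: F401  (Python 3.13+ only)
    except ImportError:
        pass  # older interpreters fall back to dbm.gnu / dbm.ndbm / dbm.dumb
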
@@ -85,103 +78,140 @@ def write_history( |
85 | 78 | fp.write(json.dumps(history)) |
86 | 79 |
|
87 | 80 |
|
| 81 | +def _retry_on_database_lock_error(fn): |
| 82 | + """ |
| 84 | + Decorator that retries a function, once per second and indefinitely, when it |
| 85 | + raises a sqlite3.OperationalError with "database is locked" in the message. |
| 85 | + """ |
| 86 | + |
| 87 | + @wraps(fn) |
| 88 | + def wrapper(*args, **kwargs): |
| 89 | + while True: |
| 90 | + try: |
| 91 | + return fn(*args, **kwargs) |
| 92 | + except sqlite3.OperationalError as ex: |
| 93 | + if "database is locked" in str(ex).lower(): |
| 94 | + LOG.warning(str(ex)) |
| 95 | + LOG.info("Retrying in 1 second...") |
| 96 | + time.sleep(1) |
| 97 | + else: |
| 98 | + raise |
| 99 | + |
| 100 | + return wrapper |
| 101 | + |
| 102 | + |
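
A minimal sketch of the decorator's behavior, assuming it is applied to a function defined in this module (flaky_write and its call counter are hypothetical). Note that the loop retries indefinitely, once per second, and re-raises any other OperationalError unchanged:

    import sqlite3

    attempts = {"n": 0}

    @_retry_on_database_lock_error
    def flaky_write() -> str:
        # Hypothetical operation: the database is locked on the first
        # two calls, then the write goes through.
        attempts["n"] += 1
        if attempts["n"] < 3:
            raise sqlite3.OperationalError("database is locked")
        return "ok"

    assert flaky_write() == "ok"  # succeeds after two one-second retries
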
88 | 103 | class PersistentCache: |
89 | | - _lock: contextlib.nullcontext | threading.Lock |
| 104 | + _lock: threading.Lock |
90 | 105 |
|
91 | 106 | def __init__(self, file: str): |
92 | | - # SQLite3 backend supports concurrent access without a lock |
93 | | - if dbm.whichdb(file) == "dbm.sqlite3": |
94 | | - self._lock = contextlib.nullcontext() |
95 | | - else: |
96 | | - self._lock = threading.Lock() |
97 | 107 | self._file = file |
| 108 | + self._lock = threading.Lock() |
98 | 109 |
|
99 | 110 | def get(self, key: str) -> str | None: |
| 111 | + if not self._db_existed(): |
| 112 | + return None |
| 113 | + |
100 | 114 | s = time.perf_counter() |
101 | 115 |
|
102 | | - with self._lock: |
103 | | - with dbm.open(self._file, flag="c") as db: |
104 | | - value: bytes | None = db.get(key) |
| 116 | + with store.KeyValueStore(self._file, flag="r") as db: |
| 117 | + try: |
| 118 | + raw_payload: bytes | None = db.get(key) # data retrieved from db[key] |
| 119 | + except Exception as ex: |
| 120 | + if self._table_not_found(ex): |
| 121 | + return None |
| 122 | + raise |
105 | 123 |
|
106 | | - if value is None: |
| 124 | + if raw_payload is None: |
107 | 125 | return None |
108 | 126 |
|
109 | | - payload = self._decode(value) |
| 127 | + data: JSONDict = self._decode(raw_payload) # JSON dict decoded from db[key] |
110 | 128 |
|
111 | | - if self._is_expired(payload): |
| 129 | + if self._is_expired(data): |
112 | 130 | return None |
113 | 131 |
|
114 | | - file_handle = payload.get("file_handle") |
| 132 | + cached_value = data.get("value") # value in the JSON dict decoded from db[key] |
115 | 133 |
|
116 | 134 | LOG.debug( |
117 | 135 | f"Found file handle for {key} in cache ({(time.perf_counter() - s) * 1000:.0f} ms)" |
118 | 136 | ) |
119 | 137 |
|
120 | | - return T.cast(str, file_handle) |
| 138 | + return T.cast(str, cached_value) |
121 | 139 |
|
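
Two guards keep the read path free of side effects: _db_existed() returns early before a database file is ever created, and _table_not_found() absorbs the window where the file exists but store.KeyValueStore has not created its table yet. A cold-start sketch (path hypothetical):

    cache = PersistentCache("/tmp/example_cache.db")  # nothing on disk yet

    assert cache.get("any-key") is None  # no file: early None, no file created
    assert cache.keys() == []            # keys() takes the same early exit
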
122 | | - def set(self, key: str, file_handle: str, expires_in: int = 3600 * 24 * 2) -> None: |
| 140 | + @_retry_on_database_lock_error |
| 141 | + def set(self, key: str, value: str, expires_in: int = 3600 * 24 * 2) -> None: |
123 | 142 | s = time.perf_counter() |
124 | 143 |
|
125 | | - payload = { |
| 144 | + data = { |
126 | 145 | "expires_at": time.time() + expires_in, |
127 | | - "file_handle": file_handle, |
| 146 | + "value": value, |
128 | 147 | } |
129 | 148 |
|
130 | | - value: bytes = json.dumps(payload).encode("utf-8") |
| 149 | + payload: bytes = json.dumps(data).encode("utf-8") |
131 | 150 |
|
132 | 151 | with self._lock: |
133 | | - with dbm.open(self._file, flag="c") as db: |
134 | | - db[key] = value |
| 152 | + with store.KeyValueStore(self._file, flag="c") as db: |
| 153 | + db[key] = payload |
135 | 154 |
|
136 | 155 | LOG.debug( |
137 | 156 | f"Cached file handle for {key} ({(time.perf_counter() - s) * 1000:.0f} ms)" |
138 | 157 | ) |
139 | 158 |
|
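
For reference, a sketch of the payload that set() serializes under the key, with time.time() pinned to a made-up value and a hypothetical key/value pair:

    import json

    # set("photo.jpg", "fh-123", expires_in=60) with time.time() == 1_000_000.0
    # stores this byte string under the key "photo.jpg":
    payload = json.dumps(
        {"expires_at": 1_000_000.0 + 60, "value": "fh-123"}
    ).encode("utf-8")
    # b'{"expires_at": 1000060.0, "value": "fh-123"}'
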
| 159 | + @_retry_on_database_lock_error |
140 | 160 | def clear_expired(self) -> list[str]: |
141 | | - s = time.perf_counter() |
142 | | - |
143 | 161 | expired_keys: list[str] = [] |
144 | 162 |
|
145 | | - with self._lock: |
146 | | - with dbm.open(self._file, flag="c") as db: |
147 | | - if hasattr(db, "items"): |
148 | | - items: T.Iterable[tuple[str | bytes, bytes]] = db.items() |
149 | | - else: |
150 | | - items = ((key, db[key]) for key in db.keys()) |
| 163 | + s = time.perf_counter() |
151 | 164 |
|
152 | | - for key, value in items: |
153 | | - payload = self._decode(value) |
154 | | - if self._is_expired(payload): |
| 165 | + with self._lock: |
| 166 | + with store.KeyValueStore(self._file, flag="c") as db: |
| 167 | + for key, raw_payload in db.items(): |
| 168 | + data = self._decode(raw_payload) |
| 169 | + if self._is_expired(data): |
155 | 170 | del db[key] |
156 | 171 | expired_keys.append(T.cast(str, key)) |
157 | 172 |
|
158 | | - if expired_keys: |
159 | | - LOG.debug( |
160 | | - f"Cleared {len(expired_keys)} expired entries from the cache ({(time.perf_counter() - s) * 1000:.0f} ms)" |
161 | | - ) |
| 173 | + LOG.debug( |
| 174 | + f"Cleared {len(expired_keys)} expired entries from the cache ({(time.perf_counter() - s) * 1000:.0f} ms)" |
| 175 | + ) |
162 | 176 |
|
163 | 177 | return expired_keys |
164 | 178 |
|
165 | | - def keys(self): |
166 | | - with self._lock: |
167 | | - with dbm.open(self._file, flag="c") as db: |
168 | | - return db.keys() |
| 179 | + def keys(self) -> list[str]: |
| 180 | + if not self._db_existed(): |
| 181 | + return [] |
169 | 182 |
|
170 | | - def _is_expired(self, payload: JSONDict) -> bool: |
171 | | - expires_at = payload.get("expires_at") |
| 183 | + try: |
| 184 | + with store.KeyValueStore(self._file, flag="r") as db: |
| 185 | + return [key.decode("utf-8") for key in db.keys()] |
| 186 | + except Exception as ex: |
| 187 | + if self._table_not_found(ex): |
| 188 | + return [] |
| 189 | + raise |
| 190 | + |
| 191 | + def _is_expired(self, data: JSONDict) -> bool: |
| 192 | + expires_at = data.get("expires_at") |
172 | 193 | if isinstance(expires_at, (int, float)): |
173 | 194 | return expires_at is None or expires_at <= time.time() |
174 | 195 | return False |
175 | 196 |
|
176 | | - def _decode(self, value: bytes) -> JSONDict: |
| 197 | + def _decode(self, raw_payload: bytes) -> JSONDict: |
177 | 198 | try: |
178 | | - payload = json.loads(value.decode("utf-8")) |
| 199 | + data = json.loads(raw_payload.decode("utf-8")) |
179 | 200 | except json.JSONDecodeError as ex: |
180 | 201 | LOG.warning(f"Failed to decode cache value: {ex}") |
181 | 202 | return {} |
182 | 203 |
|
183 | | - if not isinstance(payload, dict): |
184 | | - LOG.warning(f"Invalid cache value format: {payload}") |
| 204 | + if not isinstance(data, dict): |
| 205 | + LOG.warning(f"Invalid cache value format: {raw_payload!r}") |
185 | 206 | return {} |
186 | 207 |
|
187 | | - return payload |
| 208 | + return data |
| 209 | + |
| 210 | + def _db_existed(self) -> bool: |
| 211 | + return os.path.exists(self._file) |
| 212 | + |
| 213 | + def _table_not_found(self, ex: Exception) -> bool: |
| 214 | + if isinstance(ex, sqlite3.OperationalError): |
| 215 | + if "no such table" in str(ex): |
| 216 | + return True |
| 217 | + return False |
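
Taken together, a round-trip sketch of the reworked cache (path and keys hypothetical; assumes store.KeyValueStore creates its backing table on first write, which is what flag="c" implies here):

    cache = PersistentCache("/tmp/example_cache.db")

    cache.set("photo.jpg", "fh-123")                 # default TTL: two days
    cache.set("stale.jpg", "fh-999", expires_in=-1)  # expired on arrival

    assert cache.get("photo.jpg") == "fh-123"
    assert cache.get("stale.jpg") is None            # expired entries read as misses

    removed = cache.clear_expired()                  # drops expired rows, returns their keys
    assert cache.keys() == ["photo.jpg"]

The in-process threading.Lock still serializes writers within one process; the retry decorator on set() and clear_expired() covers "database is locked" contention from other processes sharing the same file.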