Skip to content

Commit 25a08de

Browse files
authored
fix: database lock (#781)
* first try * fix * add more tests * tests * add store.py * use key value store * tests * refactor * tests * types
1 parent 9ef9736 commit 25a08de

File tree

5 files changed

+1083
-165
lines changed

5 files changed

+1083
-165
lines changed

mapillary_tools/history.py

Lines changed: 87 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
11
from __future__ import annotations
22

3-
import contextlib
4-
import dbm
53
import json
64
import logging
5+
import os
6+
import sqlite3
77
import string
88
import threading
99
import time
1010
import typing as T
11+
from functools import wraps
1112
from pathlib import Path
1213

13-
# dbm modules are dynamically imported, so here we explicitly import dbm.sqlite3 to make sure pyinstaller include it
14-
# Otherwise you will see: ImportError: no dbm clone found; tried ['dbm.sqlite3', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb']
15-
try:
16-
import dbm.sqlite3 # type: ignore
17-
except ImportError:
18-
pass
19-
20-
21-
from . import constants, types
14+
from . import constants, store, types
2215
from .serializer.description import DescriptionJSONSerializer
2316

2417
JSONDict = T.Dict[str, T.Union[str, int, float, None]]
@@ -85,103 +78,140 @@ def write_history(
8578
fp.write(json.dumps(history))
8679

8780

81+
def _retry_on_database_lock_error(fn):
82+
"""
83+
Decorator to retry a function if it raises a sqlite3.OperationalError with
84+
"database is locked" in the message.
85+
"""
86+
87+
@wraps(fn)
88+
def wrapper(*args, **kwargs):
89+
while True:
90+
try:
91+
return fn(*args, **kwargs)
92+
except sqlite3.OperationalError as ex:
93+
if "database is locked" in str(ex).lower():
94+
LOG.warning(f"{str(ex)}")
95+
LOG.info("Retrying in 1 second...")
96+
time.sleep(1)
97+
else:
98+
raise ex
99+
100+
return wrapper
101+
102+
88103
class PersistentCache:
    """
    TTL cache of string values persisted in a store.KeyValueStore (SQLite).

    Each entry is stored as a UTF-8 JSON payload of the form
    {"expires_at": <unix timestamp>, "value": <str>}. Reads tolerate a
    missing database file or missing table; writes create both on demand.
    """

    # Serializes writers within this process; cross-process contention is
    # handled by retrying on SQLite "database is locked" errors.
    _lock: threading.Lock

    def __init__(self, file: str):
        # Path of the SQLite database file backing the cache
        self._file = file
        self._lock = threading.Lock()

    def get(self, key: str) -> str | None:
        """Return the cached value for key, or None if absent or expired."""
        if not self._db_existed():
            return None

        s = time.perf_counter()

        with store.KeyValueStore(self._file, flag="r") as db:
            try:
                raw_payload: bytes | None = db.get(key)  # data retrieved from db[key]
            except Exception as ex:
                # The file may exist while the table has not been created yet
                if self._table_not_found(ex):
                    return None
                raise ex

        if raw_payload is None:
            return None

        data: JSONDict = self._decode(raw_payload)  # JSON dict decoded from db[key]

        if self._is_expired(data):
            return None

        cached_value = data.get("value")  # value in the JSON dict decoded from db[key]

        LOG.debug(
            f"Found file handle for {key} in cache ({(time.perf_counter() - s) * 1000:.0f} ms)"
        )

        return T.cast(str, cached_value)

    @_retry_on_database_lock_error
    def set(self, key: str, value: str, expires_in: int = 3600 * 24 * 2) -> None:
        """Store value under key, expiring after expires_in seconds (default: 2 days)."""
        s = time.perf_counter()

        data = {
            "expires_at": time.time() + expires_in,
            "value": value,
        }

        payload: bytes = json.dumps(data).encode("utf-8")

        with self._lock:
            with store.KeyValueStore(self._file, flag="c") as db:
                db[key] = payload

        LOG.debug(
            f"Cached file handle for {key} ({(time.perf_counter() - s) * 1000:.0f} ms)"
        )

    @_retry_on_database_lock_error
    def clear_expired(self) -> list[str]:
        """Delete all expired entries and return the list of removed keys."""
        expired_keys: list[str] = []

        s = time.perf_counter()

        with self._lock:
            with store.KeyValueStore(self._file, flag="c") as db:
                # Materialize items() before deleting: removing rows while the
                # underlying SELECT cursor is still iterating the same table
                # has unspecified visibility in SQLite.
                for key, raw_payload in list(db.items()):
                    data = self._decode(raw_payload)
                    if self._is_expired(data):
                        del db[key]
                        expired_keys.append(T.cast(str, key))

        LOG.debug(
            f"Cleared {len(expired_keys)} expired entries from the cache ({(time.perf_counter() - s) * 1000:.0f} ms)"
        )

        return expired_keys

    def keys(self) -> list[str]:
        """Return all cache keys (including expired ones) as strings."""
        if not self._db_existed():
            return []

        try:
            with store.KeyValueStore(self._file, flag="r") as db:
                return [key.decode("utf-8") for key in db.keys()]
        except Exception as ex:
            if self._table_not_found(ex):
                return []
            raise ex

    def _is_expired(self, data: JSONDict) -> bool:
        # Entries without a numeric expires_at never expire
        expires_at = data.get("expires_at")
        if isinstance(expires_at, (int, float)):
            # isinstance() above already rules out None, so only the
            # timestamp comparison is needed here.
            return expires_at <= time.time()
        return False

    def _decode(self, raw_payload: bytes) -> JSONDict:
        # Best-effort decode: malformed payloads are logged and treated as empty
        try:
            data = json.loads(raw_payload.decode("utf-8"))
        except json.JSONDecodeError as ex:
            LOG.warning(f"Failed to decode cache value: {ex}")
            return {}

        if not isinstance(data, dict):
            LOG.warning(f"Invalid cache value format: {raw_payload!r}")
            return {}

        return data

    def _db_existed(self) -> bool:
        # Cheap existence check so read paths never create the database file
        return os.path.exists(self._file)

    def _table_not_found(self, ex: Exception) -> bool:
        # A fresh-but-empty database file surfaces as "no such table: ..."
        if isinstance(ex, sqlite3.OperationalError):
            if "no such table" in str(ex):
                return True
        return False

mapillary_tools/store.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""
2+
This module provides a persistent key-value store based on SQLite.
3+
4+
This implementation is mostly copied from dbm.sqlite3 in the Python standard library,
5+
but works for Python >= 3.9, whereas dbm.sqlite3 is only available for Python 3.13.
6+
7+
Source: https://github.com/python/cpython/blob/3.13/Lib/dbm/sqlite3.py
8+
"""
9+
10+
import os
11+
import sqlite3
12+
import sys
13+
from collections.abc import MutableMapping
14+
from contextlib import closing, suppress
15+
from pathlib import Path
16+
17+
BUILD_TABLE = """
18+
CREATE TABLE IF NOT EXISTS Dict (
19+
key BLOB UNIQUE NOT NULL,
20+
value BLOB NOT NULL
21+
)
22+
"""
23+
GET_SIZE = "SELECT COUNT (key) FROM Dict"
24+
LOOKUP_KEY = "SELECT value FROM Dict WHERE key = CAST(? AS BLOB)"
25+
STORE_KV = "REPLACE INTO Dict (key, value) VALUES (CAST(? AS BLOB), CAST(? AS BLOB))"
26+
DELETE_KEY = "DELETE FROM Dict WHERE key = CAST(? AS BLOB)"
27+
ITER_KEYS = "SELECT key FROM Dict"
28+
29+
30+
def _normalize_uri(path):
31+
path = Path(path)
32+
uri = path.absolute().as_uri()
33+
while "//" in uri:
34+
uri = uri.replace("//", "/")
35+
return uri
36+
37+
38+
class KeyValueStore(MutableMapping):
39+
def __init__(self, path, /, *, flag="r", mode=0o666):
40+
"""Open a key-value database and return the object.
41+
42+
The 'path' parameter is the name of the database file.
43+
44+
The optional 'flag' parameter can be one of ...:
45+
'r' (default): open an existing database for read only access
46+
'w': open an existing database for read/write access
47+
'c': create a database if it does not exist; open for read/write access
48+
'n': always create a new, empty database; open for read/write access
49+
50+
The optional 'mode' parameter is the Unix file access mode of the database;
51+
only used when creating a new database. Default: 0o666.
52+
"""
53+
path = os.fsdecode(path)
54+
if flag == "r":
55+
flag = "ro"
56+
elif flag == "w":
57+
flag = "rw"
58+
elif flag == "c":
59+
flag = "rwc"
60+
Path(path).touch(mode=mode, exist_ok=True)
61+
elif flag == "n":
62+
flag = "rwc"
63+
Path(path).unlink(missing_ok=True)
64+
Path(path).touch(mode=mode)
65+
else:
66+
raise ValueError(f"Flag must be one of 'r', 'w', 'c', or 'n', not {flag!r}")
67+
68+
# We use the URI format when opening the database.
69+
uri = _normalize_uri(path)
70+
uri = f"{uri}?mode={flag}"
71+
72+
if sys.version_info >= (3, 12):
73+
# This is the preferred way, but only available in Python 3.10 and newer.
74+
self._cx = sqlite3.connect(uri, autocommit=True, uri=True)
75+
else:
76+
self._cx = sqlite3.connect(uri, uri=True)
77+
78+
# This is an optimization only; it's ok if it fails.
79+
with suppress(sqlite3.OperationalError):
80+
self._cx.execute("PRAGMA journal_mode = wal")
81+
82+
if flag == "rwc":
83+
self._execute(BUILD_TABLE)
84+
85+
def _execute(self, *args, **kwargs):
86+
if sys.version_info >= (3, 12):
87+
return closing(self._cx.execute(*args, **kwargs))
88+
else:
89+
# Use a context manager to commit the changes
90+
with self._cx:
91+
return closing(self._cx.execute(*args, **kwargs))
92+
93+
def __len__(self):
94+
with self._execute(GET_SIZE) as cu:
95+
row = cu.fetchone()
96+
return row[0]
97+
98+
def __getitem__(self, key):
99+
with self._execute(LOOKUP_KEY, (key,)) as cu:
100+
row = cu.fetchone()
101+
if not row:
102+
raise KeyError(key)
103+
return row[0]
104+
105+
def __setitem__(self, key, value):
106+
self._execute(STORE_KV, (key, value))
107+
108+
def __delitem__(self, key):
109+
with self._execute(DELETE_KEY, (key,)) as cu:
110+
if not cu.rowcount:
111+
raise KeyError(key)
112+
113+
def __iter__(self):
114+
with self._execute(ITER_KEYS) as cu:
115+
for row in cu:
116+
yield row[0]
117+
118+
def close(self):
119+
self._cx.close()
120+
121+
def keys(self):
122+
return list(super().keys())
123+
124+
def __enter__(self):
125+
return self
126+
127+
def __exit__(self, *args):
128+
self.close()

mapillary_tools/uploader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1311,7 +1311,7 @@ def _is_uuid(key: str) -> bool:
13111311

13121312

13131313
def _build_upload_cache_path(upload_options: UploadOptions) -> Path:
1314-
# Different python/CLI versions use different cache (dbm) formats.
1314+
# Different python/CLI versions use different cache formats.
13151315
# Separate them to avoid conflicts
13161316
py_version_parts = [str(part) for part in sys.version_info[:3]]
13171317
version = f"py_{'_'.join(py_version_parts)}_{VERSION}"

0 commit comments

Comments
 (0)