
Commit 8c96c1a

changeable key/value in-db blob size limit (#19669)
* changeable key/value in-db blob size limit
* tests
* test_get_keys_both_disk_and_db
* test_get_keys_values_both_disk_and_db
* fixup schema, like moving the index from the blob to the hash
1 parent 9944c39 commit 8c96c1a
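
For orientation, here is a minimal sketch of the new placement rule, assuming only what the diff below shows: the ids table always records the blob's sha256 hash, blobs no larger than prefer_file_kv_blob_length (default 4096) stay inline in the blob column, and larger blobs leave that column NULL and are written to a zstd-compressed file addressed by the hash. The helper name plan_kv_storage is hypothetical; in the commit itself the decision lives in _use_file_for_new_kv_blob and add_kvid.

    # Illustrative sketch only; mirrors the cutoff logic added in data_store.py below.
    from hashlib import sha256
    from typing import Optional

    def plan_kv_storage(
        blob: bytes, prefer_file_kv_blob_length: int = 4096
    ) -> tuple[bytes, Optional[bytes], bool]:
        """Return (hash, table_blob, use_file) the way add_kvid decides placement.

        The hash always goes into the ids table; table_blob is None when the blob
        is large enough to be stored as a compressed file on disk instead.
        """
        blob_hash = sha256(blob).digest()
        # Same comparison as _use_file_for_new_kv_blob in the diff.
        use_file = len(blob) > prefer_file_kv_blob_length
        table_blob = None if use_file else blob
        return blob_hash, table_blob, use_file

    assert plan_kv_storage(b"x" * 4096)[2] is False  # at the limit: stays in the database
    assert plan_kv_storage(b"x" * 4097)[2] is True   # over the limit: goes to a file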

File tree

2 files changed, +155 -38 lines changed


chia/_tests/core/data_layer/test_data_store.py

Lines changed: 99 additions & 1 deletion
@@ -12,6 +12,7 @@
 import time
 from collections.abc import Awaitable
 from dataclasses import dataclass
+from hashlib import sha256
 from pathlib import Path
 from random import Random
 from typing import Any, BinaryIO, Callable, Optional
@@ -60,7 +61,7 @@
     "root": ["tree_id", "generation", "node_hash", "status"],
     "subscriptions": ["tree_id", "url", "ignore_till", "num_consecutive_failures", "from_wallet"],
     "schema": ["version_id", "applied_at"],
-    "ids": ["kv_id", "blob", "store_id"],
+    "ids": ["kv_id", "hash", "blob", "store_id"],
     "nodes": ["store_id", "hash", "root_hash", "generation", "idx"],
 }

@@ -2173,3 +2174,100 @@ async def test_get_existing_hashes(
     not_existing_hashes = [bytes32(i.to_bytes(32, byteorder="big")) for i in range(num_keys)]
     result = await data_store.get_existing_hashes(existing_hashes + not_existing_hashes, store_id)
     assert result == set(existing_hashes)
+
+
+@pytest.mark.anyio
+@pytest.mark.parametrize(argnames="size_offset", argvalues=[-1, 0, 1])
+async def test_basic_key_value_db_vs_disk_cutoff(
+    data_store: DataStore,
+    store_id: bytes32,
+    seeded_random: random.Random,
+    size_offset: int,
+) -> None:
+    size = data_store.prefer_file_kv_blob_length + size_offset
+
+    blob = bytes(seeded_random.getrandbits(8) for _ in range(size))
+    blob_hash = bytes32(sha256(blob).digest())
+    async with data_store.db_wrapper.writer() as writer:
+        await data_store.add_kvid(blob=blob, store_id=store_id, writer=writer)
+
+    file_exists = data_store.get_key_value_path(store_id=store_id, blob_hash=blob_hash).exists()
+    async with data_store.db_wrapper.writer() as writer:
+        async with writer.execute(
+            "SELECT blob FROM ids WHERE hash = :blob_hash",
+            {"blob_hash": blob_hash},
+        ) as cursor:
+            row = await cursor.fetchone()
+            assert row is not None
+            db_blob: Optional[bytes] = row["blob"]
+
+    if size_offset <= 0:
+        assert not file_exists
+        assert db_blob == blob
+    else:
+        assert file_exists
+        assert db_blob is None
+
+
+@pytest.mark.anyio
+@pytest.mark.parametrize(argnames="size_offset", argvalues=[-1, 0, 1])
+@pytest.mark.parametrize(argnames="limit_change", argvalues=[-2, -1, 1, 2])
+async def test_changing_key_value_db_vs_disk_cutoff(
+    data_store: DataStore,
+    store_id: bytes32,
+    seeded_random: random.Random,
+    size_offset: int,
+    limit_change: int,
+) -> None:
+    size = data_store.prefer_file_kv_blob_length + size_offset
+
+    blob = bytes(seeded_random.getrandbits(8) for _ in range(size))
+    async with data_store.db_wrapper.writer() as writer:
+        kv_id = await data_store.add_kvid(blob=blob, store_id=store_id, writer=writer)
+
+    data_store.prefer_file_kv_blob_length += limit_change
+    retrieved_blob = await data_store.get_blob_from_kvid(kv_id=kv_id, store_id=store_id)
+
+    assert blob == retrieved_blob
+
+
+@pytest.mark.anyio
+async def test_get_keys_both_disk_and_db(
+    data_store: DataStore,
+    store_id: bytes32,
+    seeded_random: random.Random,
+) -> None:
+    inserted_keys: set[bytes] = set()
+
+    for size_offset in [-1, 0, 1]:
+        size = data_store.prefer_file_kv_blob_length + size_offset
+
+        blob = bytes(seeded_random.getrandbits(8) for _ in range(size))
+        await data_store.insert(key=blob, value=b"", store_id=store_id, status=Status.COMMITTED)
+        inserted_keys.add(blob)
+
+    retrieved_keys = set(await data_store.get_keys(store_id=store_id))
+
+    assert retrieved_keys == inserted_keys
+
+
+@pytest.mark.anyio
+async def test_get_keys_values_both_disk_and_db(
+    data_store: DataStore,
+    store_id: bytes32,
+    seeded_random: random.Random,
+) -> None:
+    inserted_keys_values: dict[bytes, bytes] = {}
+
+    for size_offset in [-1, 0, 1]:
+        size = data_store.prefer_file_kv_blob_length + size_offset
+
+        key = bytes(seeded_random.getrandbits(8) for _ in range(size))
+        value = bytes(seeded_random.getrandbits(8) for _ in range(size))
+        await data_store.insert(key=key, value=value, store_id=store_id, status=Status.COMMITTED)
+        inserted_keys_values[key] = value
+
+    terminal_nodes = await data_store.get_keys_values(store_id=store_id)
+    retrieved_keys_values = {node.key: node.value for node in terminal_nodes}
+
+    assert retrieved_keys_values == inserted_keys_values

chia/data_layer/data_store.py

Lines changed: 56 additions & 37 deletions
@@ -80,6 +80,7 @@ class DataStore:
     recent_merkle_blobs: LRUCache[bytes32, MerkleBlob]
     merkle_blobs_path: Path
     key_value_blobs_path: Path
+    prefer_file_kv_blob_length: int = 4096

     @classmethod
     @contextlib.asynccontextmanager
@@ -156,6 +157,7 @@ async def managed(
                     """
                     CREATE TABLE IF NOT EXISTS ids(
                         kv_id INTEGER PRIMARY KEY,
+                        hash BLOB NOT NULL CHECK(length(store_id) == 32),
                         blob BLOB,
                         store_id BLOB NOT NULL CHECK(length(store_id) == 32)
                     )
@@ -175,7 +177,7 @@ async def managed(
                 )
                 await writer.execute(
                     """
-                    CREATE UNIQUE INDEX IF NOT EXISTS ids_blob_index ON ids(blob, store_id)
+                    CREATE UNIQUE INDEX IF NOT EXISTS ids_hash_index ON ids(hash, store_id)
                     """
                 )
                 await writer.execute(
@@ -562,20 +564,17 @@ async def insert_root_from_merkle_blob(

         return await self._insert_root(store_id, root_hash, status)

-    def _kvid_blob_is_file(self, blob: bytes) -> bool:
-        return len(blob) >= len(bytes32.zeros)
+    def _use_file_for_new_kv_blob(self, blob: bytes) -> bool:
+        return len(blob) > self.prefer_file_kv_blob_length

     async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KeyOrValueId]:
-        if self._kvid_blob_is_file(blob):
-            table_blob = sha256(blob).digest()
-        else:
-            table_blob = blob
+        blob_hash = bytes32(sha256(blob).digest())

         async with self.db_wrapper.reader() as reader:
             cursor = await reader.execute(
-                "SELECT kv_id FROM ids WHERE blob = ? AND store_id = ?",
+                "SELECT kv_id FROM ids WHERE hash = ? AND store_id = ?",
                 (
-                    table_blob,
+                    blob_hash,
                     store_id,
                 ),
             )
@@ -586,19 +585,15 @@ async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KeyOrValueI

         return KeyOrValueId(row[0])

-    def get_blob_from_table_blob(self, table_blob: bytes, store_id: bytes32) -> bytes:
-        if not self._kvid_blob_is_file(table_blob):
-            return table_blob
-
-        blob_hash = bytes32(table_blob)
+    def get_blob_from_file(self, blob_hash: bytes32, store_id: bytes32) -> bytes:
         # TODO: seems that zstd needs hinting
         # TODO: consider file-system based locking of either the file or the store directory
         return zstd.decompress(self.get_key_value_path(store_id=store_id, blob_hash=blob_hash).read_bytes())  # type: ignore[no-any-return]

     async def get_blob_from_kvid(self, kv_id: KeyOrValueId, store_id: bytes32) -> Optional[bytes]:
         async with self.db_wrapper.reader() as reader:
             cursor = await reader.execute(
-                "SELECT blob FROM ids WHERE kv_id = ? AND store_id = ?",
+                "SELECT hash, blob FROM ids WHERE kv_id = ? AND store_id = ?",
                 (
                     kv_id,
                     store_id,
@@ -609,7 +604,12 @@ async def get_blob_from_kvid(self, kv_id: KeyOrValueId, store_id: bytes32) -> Op
         if row is None:
             return None

-        return self.get_blob_from_table_blob(bytes(row[0]), store_id)
+        blob: bytes = row["blob"]
+        if blob is not None:
+            return blob
+
+        blob_hash = bytes32(row["hash"])
+        return self.get_blob_from_file(blob_hash, store_id)

     async def get_terminal_node(self, kid: KeyId, vid: ValueId, store_id: bytes32) -> TerminalNode:
         key = await self.get_blob_from_kvid(kid.raw, store_id)
@@ -620,15 +620,17 @@ async def get_terminal_node(self, kid: KeyId, vid: ValueId, store_id: bytes32) -
         return TerminalNode(hash=leaf_hash(key, value), key=key, value=value)

     async def add_kvid(self, blob: bytes, store_id: bytes32, writer: aiosqlite.Connection) -> KeyOrValueId:
-        is_file = self._kvid_blob_is_file(blob)
-        if is_file:
-            table_blob = sha256(blob).digest()
+        use_file = self._use_file_for_new_kv_blob(blob)
+        blob_hash = bytes32(sha256(blob).digest())
+        if use_file:
+            table_blob = None
         else:
             table_blob = blob
         try:
             row = await writer.execute_insert(
-                "INSERT INTO ids (blob, store_id) VALUES (?, ?)",
+                "INSERT INTO ids (hash, blob, store_id) VALUES (?, ?, ?)",
                 (
+                    blob_hash,
                     table_blob,
                     store_id,
                 ),
@@ -644,14 +646,12 @@ async def add_kvid(self, blob: bytes, store_id: bytes32, writer: aiosqlite.Conne

         if row is None:
             raise Exception("Internal error")
-        kv_id = KeyOrValueId(row[0])
-        if is_file:
-            blob_hash = bytes32(table_blob)
+        if use_file:
             path = self.get_key_value_path(store_id=store_id, blob_hash=blob_hash)
             path.parent.mkdir(parents=True, exist_ok=True)
             # TODO: consider file-system based locking of either the file or the store directory
             path.write_bytes(zstd.compress(blob))
-        return kv_id
+        return KeyOrValueId(row[0])

     async def add_key_value(
         self, key: bytes, value: bytes, store_id: bytes32, writer: aiosqlite.Connection
@@ -1050,10 +1050,22 @@ async def get_internal_nodes(self, store_id: bytes32, root_hash: Optional[bytes3
         return internal_nodes

     def get_terminal_node_from_table_blobs(
-        self, kid: KeyId, vid: ValueId, table_blobs: dict[KeyOrValueId, bytes], store_id: bytes32
+        self,
+        kid: KeyId,
+        vid: ValueId,
+        table_blobs: dict[KeyOrValueId, tuple[bytes32, Optional[bytes]]],
+        store_id: bytes32,
     ) -> TerminalNode:
-        key = self.get_blob_from_table_blob(table_blobs[KeyOrValueId(kid.raw)], store_id)
-        value = self.get_blob_from_table_blob(table_blobs[KeyOrValueId(vid.raw)], store_id)
+        key = table_blobs[KeyOrValueId(kid.raw)][1]
+        if key is None:
+            key_hash = table_blobs[KeyOrValueId(kid.raw)][0]
+            key = self.get_blob_from_file(key_hash, store_id)
+
+        value = table_blobs[KeyOrValueId(vid.raw)][1]
+        if value is None:
+            value_hash = table_blobs[KeyOrValueId(vid.raw)][0]
+            value = self.get_blob_from_file(value_hash, store_id)
+
         return TerminalNode(hash=leaf_hash(key, value), key=key, value=value)

     async def get_keys_values(
@@ -1277,8 +1289,10 @@ async def get_keys(
             table_blobs = await self.get_table_blobs(raw_key_ids, store_id)
             keys: list[bytes] = []
             for kid in kv_ids.keys():
-                key = self.get_blob_from_table_blob(table_blobs[KeyOrValueId(kid.raw)], store_id)
-                keys.append(key)
+                blob_hash, blob = table_blobs[KeyOrValueId(kid.raw)]
+                if blob is None:
+                    blob = self.get_blob_from_file(blob_hash, store_id)
+                keys.append(blob)

         return keys

@@ -1653,22 +1667,25 @@ async def get_nodes_for_file(

     async def get_table_blobs(
         self, kv_ids_iter: Iterable[KeyOrValueId], store_id: bytes32
-    ) -> dict[KeyOrValueId, bytes]:
-        result: dict[KeyOrValueId, bytes] = {}
+    ) -> dict[KeyOrValueId, tuple[bytes32, Optional[bytes]]]:
+        result: dict[KeyOrValueId, tuple[bytes32, Optional[bytes]]] = {}
         batch_size = min(500, SQLITE_MAX_VARIABLE_NUMBER - 10)
         kv_ids = list(dict.fromkeys(kv_ids_iter))

         async with self.db_wrapper.reader() as reader:
             for i in range(0, len(kv_ids), batch_size):
                 chunk = kv_ids[i : i + batch_size]
                 placeholders = ",".join(["?"] * len(chunk))
-                query = (
-                    f"SELECT blob, kv_id FROM ids WHERE store_id = ? AND kv_id IN ({placeholders}) LIMIT {len(chunk)}"
-                )
+                query = f"""
+                    SELECT hash, blob, kv_id
+                    FROM ids
+                    WHERE store_id = ? AND kv_id IN ({placeholders})
+                    LIMIT {len(chunk)}
+                    """

                 async with reader.execute(query, (store_id, *chunk)) as cursor:
                     rows = await cursor.fetchall()
-                    result.update({row["kv_id"]: row["blob"] for row in rows})
+                    result.update({row["kv_id"]: (row["hash"], row["blob"]) for row in rows})

         if len(result) != len(kv_ids):
             raise Exception("Cannot retrieve all the requested kv_ids")
@@ -1713,8 +1730,10 @@ async def write_tree_to_file(
                 blobs = []
                 for raw_id in (node.value1, node.value2):
                     id = KeyOrValueId.from_bytes(raw_id)
-                    id_table_blob = table_blobs[id]
-                    blobs.append(self.get_blob_from_table_blob(id_table_blob, store_id))
+                    blob_hash, blob = table_blobs[id]
+                    if blob is None:
+                        blob = self.get_blob_from_file(blob_hash, store_id)
+                    blobs.append(blob)
                 to_write = bytes(SerializedNode(True, blobs[0], blobs[1]))
             else:
                 to_write = bytes(node)
0 commit comments