Skip to content

Commit 476ab8c

Browse files
fchirica and altendky authored
Optimize DL functions for get keys/values. (#19609)
* Optimize DL functions for get keys/values.
* Update chia/data_layer/data_store.py (Co-authored-by: Kyle Altendorf <[email protected]>)
* Address review comment.
* Address review comments.

---------

Co-authored-by: Kyle Altendorf <[email protected]>
1 parent 213dc58 commit 476ab8c

File tree

2 files changed

+75
-30
lines changed

2 files changed

+75
-30
lines changed

chia/_tests/core/data_layer/test_data_rpc.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3003,18 +3003,18 @@ async def test_pagination_rpcs(
30033003
"total_pages": 1,
30043004
"total_bytes": 8,
30053005
"diff": [
3006-
{"type": "DELETE", "key": key6.hex(), "value": value6.hex()},
30073006
{"type": "INSERT", "key": key6.hex(), "value": new_value.hex()},
3007+
{"type": "DELETE", "key": key6.hex(), "value": value6.hex()},
30083008
],
30093009
}
30103010
assert diff_res == diff_reference
30113011

3012-
with pytest.raises(Exception, match="Can't find keys"):
3012+
with pytest.raises(Exception, match="Cannot find merkle blob"):
30133013
await data_rpc_api.get_keys(
30143014
{"id": store_id.hex(), "page": 0, "max_page_size": 100, "root_hash": bytes32([0] * 31 + [1]).hex()}
30153015
)
30163016

3017-
with pytest.raises(Exception, match="Can't find keys and values"):
3017+
with pytest.raises(Exception, match="Cannot find merkle blob"):
30183018
await data_rpc_api.get_keys_values(
30193019
{"id": store_id.hex(), "page": 0, "max_page_size": 100, "root_hash": bytes32([0] * 31 + [1]).hex()}
30203020
)

chia/data_layer/data_store.py

Lines changed: 72 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import shutil
88
import sqlite3
99
from collections import defaultdict
10-
from collections.abc import AsyncIterator, Awaitable, Sequence
10+
from collections.abc import AsyncIterator, Awaitable, Iterable, Sequence
1111
from contextlib import asynccontextmanager
1212
from dataclasses import dataclass, replace
1313
from hashlib import sha256
@@ -659,6 +659,33 @@ async def get_terminal_node_by_hash(
659659
kid, vid = merkle_blob.get_node_by_hash(node_hash)
660660
return await self.get_terminal_node(kid, vid, store_id)
661661

662+
async def get_terminal_nodes_by_hashes(
663+
self,
664+
node_hashes: list[bytes32],
665+
store_id: bytes32,
666+
root_hash: Union[bytes32, Unspecified] = unspecified,
667+
) -> list[TerminalNode]:
668+
resolved_root_hash: Optional[bytes32]
669+
if root_hash is unspecified:
670+
root = await self.get_tree_root(store_id=store_id)
671+
resolved_root_hash = root.node_hash
672+
else:
673+
resolved_root_hash = root_hash
674+
675+
merkle_blob = await self.get_merkle_blob(store_id=store_id, root_hash=resolved_root_hash)
676+
kv_ids: list[tuple[KeyId, ValueId]] = []
677+
for node_hash in node_hashes:
678+
kid, vid = merkle_blob.get_node_by_hash(node_hash)
679+
kv_ids.append((kid, vid))
680+
kv_ids_unpacked = (KeyOrValueId(id.raw) for kv_id in kv_ids for id in kv_id)
681+
table_blobs = await self.get_table_blobs(kv_ids_unpacked, store_id)
682+
683+
terminal_nodes: list[TerminalNode] = []
684+
for kid, vid in kv_ids:
685+
terminal_nodes.append(self.get_terminal_node_from_table_blobs(kid, vid, table_blobs, store_id))
686+
687+
return terminal_nodes
688+
662689
async def get_first_generation(self, node_hash: bytes32, store_id: bytes32) -> Optional[int]:
663690
async with self.db_wrapper.reader() as reader:
664691
cursor = await reader.execute(
@@ -1003,6 +1030,13 @@ async def get_internal_nodes(self, store_id: bytes32, root_hash: Optional[bytes3
10031030

10041031
return internal_nodes
10051032

1033+
def get_terminal_node_from_table_blobs(
1034+
self, kid: KeyId, vid: ValueId, table_blobs: dict[KeyOrValueId, bytes], store_id: bytes32
1035+
) -> TerminalNode:
1036+
key = self.get_blob_from_table_blob(table_blobs[KeyOrValueId(kid.raw)], store_id)
1037+
value = self.get_blob_from_table_blob(table_blobs[KeyOrValueId(vid.raw)], store_id)
1038+
return TerminalNode(hash=leaf_hash(key, value), key=key, value=value)
1039+
10061040
async def get_keys_values(
10071041
self,
10081042
store_id: bytes32,
@@ -1022,11 +1056,12 @@ async def get_keys_values(
10221056
return []
10231057

10241058
kv_ids = merkle_blob.get_keys_values()
1059+
kv_ids_unpacked = (KeyOrValueId(id.raw) for pair in kv_ids.items() for id in pair)
1060+
table_blobs = await self.get_table_blobs(kv_ids_unpacked, store_id)
10251061

10261062
terminal_nodes: list[TerminalNode] = []
10271063
for kid, vid in kv_ids.items():
1028-
terminal_node = await self.get_terminal_node(kid, vid, store_id)
1029-
terminal_nodes.append(terminal_node)
1064+
terminal_nodes.append(self.get_terminal_node_from_table_blobs(kid, vid, table_blobs, store_id))
10301065

10311066
return terminal_nodes
10321067

@@ -1053,9 +1088,11 @@ async def get_keys_values_compressed(
10531088
return KeysValuesCompressed({}, {}, {}, resolved_root_hash)
10541089

10551090
kv_ids = merkle_blob.get_keys_values()
1056-
for kid, vid in kv_ids.items():
1057-
node = await self.get_terminal_node(kid, vid, store_id)
1091+
kv_ids_unpacked = (KeyOrValueId(id.raw) for pair in kv_ids.items() for id in pair)
1092+
table_blobs = await self.get_table_blobs(kv_ids_unpacked, store_id)
10581093

1094+
for kid, vid in kv_ids.items():
1095+
node = self.get_terminal_node_from_table_blobs(kid, vid, table_blobs, store_id)
10591096
keys_values_hashed[key_hash(node.key)] = leaf_hash(node.key, node.value)
10601097
key_hash_to_length[key_hash(node.key)] = len(node.key)
10611098
leaf_hash_to_length[leaf_hash(node.key, node.value)] = len(node.key) + len(node.value)
@@ -1073,11 +1110,12 @@ async def get_keys_paginated(
10731110
pagination_data = get_hashes_for_page(page, keys_values_compressed.key_hash_to_length, max_page_size)
10741111

10751112
keys: list[bytes] = []
1113+
leaf_hashes: list[bytes32] = []
10761114
for hash in pagination_data.hashes:
10771115
leaf_hash = keys_values_compressed.keys_values_hashed[hash]
1078-
node = await self.get_terminal_node_by_hash(leaf_hash, store_id, root_hash)
1079-
assert isinstance(node, TerminalNode)
1080-
keys.append(node.key)
1116+
leaf_hashes.append(leaf_hash)
1117+
nodes = await self.get_terminal_nodes_by_hashes(leaf_hashes, store_id, root_hash)
1118+
keys = [node.key for node in nodes]
10811119

10821120
return KeysPaginationData(
10831121
pagination_data.total_pages,
@@ -1096,12 +1134,7 @@ async def get_keys_values_paginated(
10961134
keys_values_compressed = await self.get_keys_values_compressed(store_id, root_hash)
10971135
pagination_data = get_hashes_for_page(page, keys_values_compressed.leaf_hash_to_length, max_page_size)
10981136

1099-
keys_values: list[TerminalNode] = []
1100-
for hash in pagination_data.hashes:
1101-
node = await self.get_terminal_node_by_hash(hash, store_id, root_hash)
1102-
assert isinstance(node, TerminalNode)
1103-
keys_values.append(node)
1104-
1137+
keys_values = await self.get_terminal_nodes_by_hashes(pagination_data.hashes, store_id, root_hash)
11051138
return KeysValuesPaginationData(
11061139
pagination_data.total_pages,
11071140
pagination_data.total_bytes,
@@ -1138,15 +1171,25 @@ async def get_kv_diff_paginated(
11381171

11391172
pagination_data = get_hashes_for_page(page, lengths, max_page_size)
11401173
kv_diff: list[DiffData] = []
1141-
1174+
insertion_hashes: list[bytes32] = []
1175+
deletion_hashes: list[bytes32] = []
11421176
for hash in pagination_data.hashes:
1143-
root_hash = hash2 if hash in insertions else hash1
1144-
node = await self.get_terminal_node_by_hash(hash, store_id, root_hash)
1145-
assert isinstance(node, TerminalNode)
11461177
if hash in insertions:
1147-
kv_diff.append(DiffData(OperationType.INSERT, node.key, node.value))
1178+
insertion_hashes.append(hash)
11481179
else:
1149-
kv_diff.append(DiffData(OperationType.DELETE, node.key, node.value))
1180+
deletion_hashes.append(hash)
1181+
if hash2 != bytes32.zeros:
1182+
insertion_nodes = await self.get_terminal_nodes_by_hashes(insertion_hashes, store_id, hash2)
1183+
else:
1184+
insertion_nodes = []
1185+
if hash1 != bytes32.zeros:
1186+
deletion_nodes = await self.get_terminal_nodes_by_hashes(deletion_hashes, store_id, hash1)
1187+
else:
1188+
deletion_nodes = []
1189+
for node in insertion_nodes:
1190+
kv_diff.append(DiffData(OperationType.INSERT, node.key, node.value))
1191+
for node in deletion_nodes:
1192+
kv_diff.append(DiffData(OperationType.DELETE, node.key, node.value))
11501193

11511194
return KVDiffPaginationData(
11521195
pagination_data.total_pages,
@@ -1211,11 +1254,11 @@ async def get_keys(
12111254
return []
12121255

12131256
kv_ids = merkle_blob.get_keys_values()
1257+
raw_key_ids = (KeyOrValueId(id.raw) for id in kv_ids.keys())
1258+
table_blobs = await self.get_table_blobs(raw_key_ids, store_id)
12141259
keys: list[bytes] = []
12151260
for kid in kv_ids.keys():
1216-
key = await self.get_blob_from_kvid(kid.raw, store_id)
1217-
if key is None:
1218-
raise Exception(f"Unknown key corresponding to KeyId: {kid}")
1261+
key = self.get_blob_from_table_blob(table_blobs[KeyOrValueId(kid.raw)], store_id)
12191262
keys.append(key)
12201263

12211264
return keys
@@ -1589,10 +1632,12 @@ async def get_nodes_for_file(
15891632
else:
15901633
raise Exception(f"Node is neither InternalNode nor TerminalNode: {raw_node}")
15911634

1592-
async def get_table_blobs(self, kv_ids: list[KeyOrValueId], store_id: bytes32) -> dict[KeyOrValueId, bytes]:
1635+
async def get_table_blobs(
1636+
self, kv_ids_iter: Iterable[KeyOrValueId], store_id: bytes32
1637+
) -> dict[KeyOrValueId, bytes]:
15931638
result: dict[KeyOrValueId, bytes] = {}
15941639
batch_size = min(500, SQLITE_MAX_VARIABLE_NUMBER - 10)
1595-
kv_ids = list(set(kv_ids))
1640+
kv_ids = list(dict.fromkeys(kv_ids_iter))
15961641

15971642
async with self.db_wrapper.reader() as reader:
15981643
for i in range(0, len(kv_ids), batch_size):
@@ -1636,12 +1681,12 @@ async def write_tree_to_file(
16361681
await self.get_nodes_for_file(
16371682
root, node_hash, store_id, deltas_only, merkle_blob, hash_to_index, existing_hashes, tree_nodes
16381683
)
1639-
kv_ids = [
1684+
kv_ids = (
16401685
KeyOrValueId.from_bytes(raw_id)
16411686
for node in tree_nodes
16421687
if node.is_terminal
16431688
for raw_id in (node.value1, node.value2)
1644-
]
1689+
)
16451690
table_blobs = await self.get_table_blobs(kv_ids, store_id)
16461691

16471692
for node in tree_nodes:

0 commit comments

Comments (0)