7
7
import shutil
8
8
import sqlite3
9
9
from collections import defaultdict
10
- from collections .abc import AsyncIterator , Awaitable , Sequence
10
+ from collections .abc import AsyncIterator , Awaitable , Iterable , Sequence
11
11
from contextlib import asynccontextmanager
12
12
from dataclasses import dataclass , replace
13
13
from hashlib import sha256
@@ -659,6 +659,33 @@ async def get_terminal_node_by_hash(
659
659
kid , vid = merkle_blob .get_node_by_hash (node_hash )
660
660
return await self .get_terminal_node (kid , vid , store_id )
661
661
662
+ async def get_terminal_nodes_by_hashes (
663
+ self ,
664
+ node_hashes : list [bytes32 ],
665
+ store_id : bytes32 ,
666
+ root_hash : Union [bytes32 , Unspecified ] = unspecified ,
667
+ ) -> list [TerminalNode ]:
668
+ resolved_root_hash : Optional [bytes32 ]
669
+ if root_hash is unspecified :
670
+ root = await self .get_tree_root (store_id = store_id )
671
+ resolved_root_hash = root .node_hash
672
+ else :
673
+ resolved_root_hash = root_hash
674
+
675
+ merkle_blob = await self .get_merkle_blob (store_id = store_id , root_hash = resolved_root_hash )
676
+ kv_ids : list [tuple [KeyId , ValueId ]] = []
677
+ for node_hash in node_hashes :
678
+ kid , vid = merkle_blob .get_node_by_hash (node_hash )
679
+ kv_ids .append ((kid , vid ))
680
+ kv_ids_unpacked = (KeyOrValueId (id .raw ) for kv_id in kv_ids for id in kv_id )
681
+ table_blobs = await self .get_table_blobs (kv_ids_unpacked , store_id )
682
+
683
+ terminal_nodes : list [TerminalNode ] = []
684
+ for kid , vid in kv_ids :
685
+ terminal_nodes .append (self .get_terminal_node_from_table_blobs (kid , vid , table_blobs , store_id ))
686
+
687
+ return terminal_nodes
688
+
662
689
async def get_first_generation (self , node_hash : bytes32 , store_id : bytes32 ) -> Optional [int ]:
663
690
async with self .db_wrapper .reader () as reader :
664
691
cursor = await reader .execute (
@@ -1003,6 +1030,13 @@ async def get_internal_nodes(self, store_id: bytes32, root_hash: Optional[bytes3
1003
1030
1004
1031
return internal_nodes
1005
1032
1033
+ def get_terminal_node_from_table_blobs (
1034
+ self , kid : KeyId , vid : ValueId , table_blobs : dict [KeyOrValueId , bytes ], store_id : bytes32
1035
+ ) -> TerminalNode :
1036
+ key = self .get_blob_from_table_blob (table_blobs [KeyOrValueId (kid .raw )], store_id )
1037
+ value = self .get_blob_from_table_blob (table_blobs [KeyOrValueId (vid .raw )], store_id )
1038
+ return TerminalNode (hash = leaf_hash (key , value ), key = key , value = value )
1039
+
1006
1040
async def get_keys_values (
1007
1041
self ,
1008
1042
store_id : bytes32 ,
@@ -1022,11 +1056,12 @@ async def get_keys_values(
1022
1056
return []
1023
1057
1024
1058
kv_ids = merkle_blob .get_keys_values ()
1059
+ kv_ids_unpacked = (KeyOrValueId (id .raw ) for pair in kv_ids .items () for id in pair )
1060
+ table_blobs = await self .get_table_blobs (kv_ids_unpacked , store_id )
1025
1061
1026
1062
terminal_nodes : list [TerminalNode ] = []
1027
1063
for kid , vid in kv_ids .items ():
1028
- terminal_node = await self .get_terminal_node (kid , vid , store_id )
1029
- terminal_nodes .append (terminal_node )
1064
+ terminal_nodes .append (self .get_terminal_node_from_table_blobs (kid , vid , table_blobs , store_id ))
1030
1065
1031
1066
return terminal_nodes
1032
1067
@@ -1053,9 +1088,11 @@ async def get_keys_values_compressed(
1053
1088
return KeysValuesCompressed ({}, {}, {}, resolved_root_hash )
1054
1089
1055
1090
kv_ids = merkle_blob .get_keys_values ()
1056
- for kid , vid in kv_ids .items ():
1057
- node = await self .get_terminal_node ( kid , vid , store_id )
1091
+ kv_ids_unpacked = ( KeyOrValueId ( id . raw ) for pair in kv_ids .items () for id in pair )
1092
+ table_blobs = await self .get_table_blobs ( kv_ids_unpacked , store_id )
1058
1093
1094
+ for kid , vid in kv_ids .items ():
1095
+ node = self .get_terminal_node_from_table_blobs (kid , vid , table_blobs , store_id )
1059
1096
keys_values_hashed [key_hash (node .key )] = leaf_hash (node .key , node .value )
1060
1097
key_hash_to_length [key_hash (node .key )] = len (node .key )
1061
1098
leaf_hash_to_length [leaf_hash (node .key , node .value )] = len (node .key ) + len (node .value )
@@ -1073,11 +1110,12 @@ async def get_keys_paginated(
1073
1110
pagination_data = get_hashes_for_page (page , keys_values_compressed .key_hash_to_length , max_page_size )
1074
1111
1075
1112
keys : list [bytes ] = []
1113
+ leaf_hashes : list [bytes32 ] = []
1076
1114
for hash in pagination_data .hashes :
1077
1115
leaf_hash = keys_values_compressed .keys_values_hashed [hash ]
1078
- node = await self . get_terminal_node_by_hash (leaf_hash , store_id , root_hash )
1079
- assert isinstance ( node , TerminalNode )
1080
- keys . append ( node .key )
1116
+ leaf_hashes . append (leaf_hash )
1117
+ nodes = await self . get_terminal_nodes_by_hashes ( leaf_hashes , store_id , root_hash )
1118
+ keys = [ node .key for node in nodes ]
1081
1119
1082
1120
return KeysPaginationData (
1083
1121
pagination_data .total_pages ,
@@ -1096,12 +1134,7 @@ async def get_keys_values_paginated(
1096
1134
keys_values_compressed = await self .get_keys_values_compressed (store_id , root_hash )
1097
1135
pagination_data = get_hashes_for_page (page , keys_values_compressed .leaf_hash_to_length , max_page_size )
1098
1136
1099
- keys_values : list [TerminalNode ] = []
1100
- for hash in pagination_data .hashes :
1101
- node = await self .get_terminal_node_by_hash (hash , store_id , root_hash )
1102
- assert isinstance (node , TerminalNode )
1103
- keys_values .append (node )
1104
-
1137
+ keys_values = await self .get_terminal_nodes_by_hashes (pagination_data .hashes , store_id , root_hash )
1105
1138
return KeysValuesPaginationData (
1106
1139
pagination_data .total_pages ,
1107
1140
pagination_data .total_bytes ,
@@ -1138,15 +1171,25 @@ async def get_kv_diff_paginated(
1138
1171
1139
1172
pagination_data = get_hashes_for_page (page , lengths , max_page_size )
1140
1173
kv_diff : list [DiffData ] = []
1141
-
1174
+ insertion_hashes : list [bytes32 ] = []
1175
+ deletion_hashes : list [bytes32 ] = []
1142
1176
for hash in pagination_data .hashes :
1143
- root_hash = hash2 if hash in insertions else hash1
1144
- node = await self .get_terminal_node_by_hash (hash , store_id , root_hash )
1145
- assert isinstance (node , TerminalNode )
1146
1177
if hash in insertions :
1147
- kv_diff .append (DiffData ( OperationType . INSERT , node . key , node . value ) )
1178
+ insertion_hashes .append (hash )
1148
1179
else :
1149
- kv_diff .append (DiffData (OperationType .DELETE , node .key , node .value ))
1180
+ deletion_hashes .append (hash )
1181
+ if hash2 != bytes32 .zeros :
1182
+ insertion_nodes = await self .get_terminal_nodes_by_hashes (insertion_hashes , store_id , hash2 )
1183
+ else :
1184
+ insertion_nodes = []
1185
+ if hash1 != bytes32 .zeros :
1186
+ deletion_nodes = await self .get_terminal_nodes_by_hashes (deletion_hashes , store_id , hash1 )
1187
+ else :
1188
+ deletion_nodes = []
1189
+ for node in insertion_nodes :
1190
+ kv_diff .append (DiffData (OperationType .INSERT , node .key , node .value ))
1191
+ for node in deletion_nodes :
1192
+ kv_diff .append (DiffData (OperationType .DELETE , node .key , node .value ))
1150
1193
1151
1194
return KVDiffPaginationData (
1152
1195
pagination_data .total_pages ,
@@ -1211,11 +1254,11 @@ async def get_keys(
1211
1254
return []
1212
1255
1213
1256
kv_ids = merkle_blob .get_keys_values ()
1257
+ raw_key_ids = (KeyOrValueId (id .raw ) for id in kv_ids .keys ())
1258
+ table_blobs = await self .get_table_blobs (raw_key_ids , store_id )
1214
1259
keys : list [bytes ] = []
1215
1260
for kid in kv_ids .keys ():
1216
- key = await self .get_blob_from_kvid (kid .raw , store_id )
1217
- if key is None :
1218
- raise Exception (f"Unknown key corresponding to KeyId: { kid } " )
1261
+ key = self .get_blob_from_table_blob (table_blobs [KeyOrValueId (kid .raw )], store_id )
1219
1262
keys .append (key )
1220
1263
1221
1264
return keys
@@ -1589,10 +1632,12 @@ async def get_nodes_for_file(
1589
1632
else :
1590
1633
raise Exception (f"Node is neither InternalNode nor TerminalNode: { raw_node } " )
1591
1634
1592
- async def get_table_blobs (self , kv_ids : list [KeyOrValueId ], store_id : bytes32 ) -> dict [KeyOrValueId , bytes ]:
1635
+ async def get_table_blobs (
1636
+ self , kv_ids_iter : Iterable [KeyOrValueId ], store_id : bytes32
1637
+ ) -> dict [KeyOrValueId , bytes ]:
1593
1638
result : dict [KeyOrValueId , bytes ] = {}
1594
1639
batch_size = min (500 , SQLITE_MAX_VARIABLE_NUMBER - 10 )
1595
- kv_ids = list (set ( kv_ids ))
1640
+ kv_ids = list (dict . fromkeys ( kv_ids_iter ))
1596
1641
1597
1642
async with self .db_wrapper .reader () as reader :
1598
1643
for i in range (0 , len (kv_ids ), batch_size ):
@@ -1636,12 +1681,12 @@ async def write_tree_to_file(
1636
1681
await self .get_nodes_for_file (
1637
1682
root , node_hash , store_id , deltas_only , merkle_blob , hash_to_index , existing_hashes , tree_nodes
1638
1683
)
1639
- kv_ids = [
1684
+ kv_ids = (
1640
1685
KeyOrValueId .from_bytes (raw_id )
1641
1686
for node in tree_nodes
1642
1687
if node .is_terminal
1643
1688
for raw_id in (node .value1 , node .value2 )
1644
- ]
1689
+ )
1645
1690
table_blobs = await self .get_table_blobs (kv_ids , store_id )
1646
1691
1647
1692
for node in tree_nodes :
0 commit comments