Skip to content

Commit c7a2bdb

Browse files
committed
mist: Don't collect all leaves to add above root when not needed
1 parent b0fa75d commit c7a2bdb

File tree

2 files changed

+110
-36
lines changed

2 files changed

+110
-36
lines changed

mist/lib/mst.ml

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,10 @@ module type Intf = sig
235235

236236
val leaves_of_root : t -> (string * Cid.t) list Lwt.t
237237

238+
(** [get_min_key t cid] resolves to the smallest key found in the subtree
    rooted at [cid] by following left links, or [None] when no key is found. *)
val get_min_key : t -> Cid.t -> string option Lwt.t
239+
240+
(** [get_max_key t cid] resolves to the largest key found in the subtree
    rooted at [cid] by following the rightmost links, or [None] when no key
    is found. *)
val get_max_key : t -> Cid.t -> string option Lwt.t
241+
238242
val equal : t -> t -> bool Lwt.t
239243
end
240244

@@ -424,11 +428,7 @@ struct
424428
(r, full_key) :: acc
425429
| None ->
426430
acc )
427-
( match raw.l with
428-
| Some l ->
429-
[(l, prefix)]
430-
| None ->
431-
[] )
431+
(match raw.l with Some l -> [(l, prefix)] | None -> [])
432432
raw.e
433433
in
434434
(Cid.Set.add cid visited, List.rev_append next_pairs queue) )
@@ -496,8 +496,7 @@ struct
496496
(List.rev acc, seen)
497497
| `Node cid :: rest ->
498498
if
499-
Cid.Set.mem cid missing
500-
|| Block_map.has cid cache
499+
Cid.Set.mem cid missing || Block_map.has cid cache
501500
|| Cid.Set.mem cid seen
502501
then collect acc seen remaining rest
503502
else
@@ -512,12 +511,11 @@ struct
512511
let cache' =
513512
List.fold_left
514513
(fun acc (cid, bytes) -> Block_map.set cid bytes acc)
515-
cache (Block_map.entries bm.blocks)
514+
cache
515+
(Block_map.entries bm.blocks)
516516
in
517517
let missing' =
518-
List.fold_left
519-
(fun acc cid -> Cid.Set.add cid acc)
520-
missing bm.missing
518+
List.fold_left (fun acc cid -> Cid.Set.add cid acc) missing bm.missing
521519
in
522520
Lwt.return (cache', missing')
523521
in
@@ -527,10 +525,10 @@ struct
527525
Lwt.return_none
528526
| `Leaf cid :: rest ->
529527
Lwt.return_some ((Leaf cid : ordered_item), (rest, cache, missing))
530-
| `Node cid :: rest ->
528+
| `Node cid :: rest -> (
531529
if Cid.Set.mem cid missing then step (rest, cache, missing)
532530
else
533-
( match Block_map.get cid cache with
531+
match Block_map.get cid cache with
534532
| None ->
535533
let%lwt cache', missing' = prefetch queue cache missing in
536534
if cache' == cache && Cid.Set.mem cid missing' then
@@ -554,8 +552,8 @@ struct
554552
let new_queue = left_queue @ entries_queue @ rest in
555553
let cache' = Block_map.remove cid cache in
556554
Lwt.return_some
557-
((Node (cid, bytes) : ordered_item), (new_queue, cache', missing))
558-
)
555+
( (Node (cid, bytes) : ordered_item)
556+
, (new_queue, cache', missing) ) )
559557
in
560558
Lwt_seq.unfold_lwt step ([`Node t.root], Block_map.empty, Cid.Set.empty)
561559

@@ -565,7 +563,8 @@ struct
565563
let entries =
566564
if entries_are_sorted node.entries then node.entries
567565
else
568-
List.sort (fun (a : entry) b -> String.compare a.key b.key)
566+
List.sort
567+
(fun (a : entry) b -> String.compare a.key b.key)
569568
node.entries
570569
in
571570
let%lwt left =
@@ -612,8 +611,7 @@ struct
612611
| Error e ->
613612
raise e
614613
in
615-
try%lwt Lwt.map Result.ok (aux node)
616-
with e -> Lwt.return_error e
614+
try%lwt Lwt.map Result.ok (aux node) with e -> Lwt.return_error e
617615

618616
(* raw-node helpers for covering proofs: operate on stored bytes, not re-serialization *)
619617
type interleaved_entry =
@@ -769,8 +767,7 @@ struct
769767
let missing = ref Cid.Set.empty in
770768
let acc = ref Block_map.empty in
771769
let add_block cid bytes =
772-
if not (Block_map.has cid !acc) then
773-
acc := Block_map.set cid bytes !acc
770+
if not (Block_map.has cid !acc) then acc := Block_map.set cid bytes !acc
774771
in
775772
let get_bytes_cached cid =
776773
match Block_map.get cid !cache with
@@ -792,24 +789,23 @@ struct
792789
| None ->
793790
Lwt.return_unit
794791
| Some leaf_cid -> (
795-
match%lwt get_bytes_cached leaf_cid with
796-
| Some bytes ->
797-
add_block leaf_cid bytes ;
798-
Lwt.return_unit
799-
| None ->
800-
Lwt.return_unit )
792+
match%lwt get_bytes_cached leaf_cid with
793+
| Some bytes ->
794+
add_block leaf_cid bytes ; Lwt.return_unit
795+
| None ->
796+
Lwt.return_unit )
801797
in
802798
let rec proof_for_key_cached cid key =
803799
match%lwt get_bytes_cached cid with
804800
| None ->
805801
Lwt.return_unit
806-
| Some bytes ->
802+
| Some bytes -> (
807803
add_block cid bytes ;
808804
let raw = decode_block_raw bytes in
809805
let keys = node_entry_keys raw in
810806
let seq = interleave_raw raw keys in
811807
let index = find_gte_leaf_index key seq in
812-
( match List.nth_opt seq index with
808+
match List.nth_opt seq index with
813809
| Some (Leaf (k, _, _)) when k = key ->
814810
Lwt.return_unit
815811
| Some (Leaf (_k, v_right, _)) -> (
@@ -896,7 +892,8 @@ struct
896892
, (cid, bytes) :: nodes
897893
, leaves'
898894
, List.rev_append next_cids queue ) )
899-
(visited, nodes, leaves, rest) batch
895+
(visited, nodes, leaves, rest)
896+
batch
900897
in
901898
loop next_queue visited' nodes' leaves'
902899
in
@@ -1212,6 +1209,43 @@ struct
12121209
| None ->
12131210
Lwt.return []
12141211

1212+
(* returns the minimum key in a subtree by following the leftmost path.

   Resolves to [None] when [retrieve_node_raw] yields no node for [cid], or
   when neither the left subtree nor the node's own entries provide a key.

   If the node has a left link but that subtree yields no key, the node's
   first entry is used instead of reporting the whole subtree as empty;
   callers treat [None] as "no keys at all". *)
let rec get_min_key (t : t) (cid : Cid.t) : string option Lwt.t =
  match%lwt retrieve_node_raw t cid with
  | None ->
      Lwt.return_none
  | Some raw -> (
      (* smallest key stored directly in this node, if any.
         NOTE(review): [first.k] is used verbatim, which presumes the first
         entry stores its key without a shared prefix; confirm against
         compress_entries. *)
      let own_min =
        match raw.e with
        | [] ->
            None
        | first :: _ ->
            Some (Bytes.to_string first.k)
      in
      match raw.l with
      | None ->
          Lwt.return own_min
      | Some left_cid -> (
          match%lwt get_min_key t left_cid with
          | Some _ as found ->
              Lwt.return found
          | None ->
              (* left subtree contributed nothing: fall back to own entries *)
              Lwt.return own_min ) )
1227+
1228+
(* returns the maximum key in a subtree by following the rightmost path.

   Resolves to [None] when [retrieve_node_raw] yields no node for [cid], or
   when no key can be found. A node without entries is searched through its
   left link, since that is the only child it can have.

   If the last entry has a right link but that subtree yields no key, the
   last entry's own key is returned instead of reporting the whole subtree
   as empty; callers treat [None] as "no keys at all".

   [List.combine] raises [Invalid_argument] if [decompress_keys] returns a
   list whose length differs from [raw.e]. *)
let rec get_max_key (t : t) (cid : Cid.t) : string option Lwt.t =
  match%lwt retrieve_node_raw t cid with
  | None ->
      Lwt.return_none
  | Some raw -> (
      (* all keys are expanded because the last key is needed in full;
         presumably later entries share a prefix with earlier ones (verify
         against decompress_keys) *)
      let keys = decompress_keys raw in
      match List.rev (List.combine keys raw.e) with
      | [] -> (
          match raw.l with
          | Some left_cid ->
              get_max_key t left_cid
          | None ->
              Lwt.return_none )
      | (last_key, last_entry) :: _ -> (
          match last_entry.t with
          | None ->
              Lwt.return_some last_key
          | Some right_cid -> (
              match%lwt get_max_key t right_cid with
              | Some _ as found ->
                  Lwt.return found
              | None ->
                  (* right subtree contributed nothing: the last entry's own
                     key is the largest one known *)
                  Lwt.return_some last_key ) ) )
1248+
12151249
(* rebuild a subtree from leaves
12161250
returns (root_cid option, actual_layer) *)
12171251
let rebuild_subtree (blockstore : bs) (leaves : (string * Cid.t) list) :
@@ -1265,11 +1299,11 @@ struct
12651299
let%lwt wrapped_old =
12661300
wrap_to_layer t.blockstore old_root_cid old_root_layer (key_layer - 1)
12671301
in
1268-
(* get all keys from old tree to determine position *)
1269-
let%lwt old_leaves = collect_subtree_leaves t old_root_cid in
1270-
let old_keys = List.map fst old_leaves in
1271-
let all_less = List.for_all (fun k -> k < key) old_keys in
1272-
let all_greater = List.for_all (fun k -> k > key) old_keys in
1302+
(* check boundary keys to determine position *)
1303+
let%lwt min_key = get_min_key t old_root_cid in
1304+
let%lwt max_key = get_max_key t old_root_cid in
1305+
let all_less = match max_key with Some mx -> mx < key | None -> true in
1306+
let all_greater = match min_key with Some mn -> mn > key | None -> true in
12731307
if all_less then
12741308
(* all old keys < new key: old tree is left, new entry has no right *)
12751309
let entries = compress_entries [(key, value, None)] in
@@ -1279,7 +1313,8 @@ struct
12791313
let entries = compress_entries [(key, value, Some wrapped_old)] in
12801314
persist_node_raw t.blockstore {l= None; e= entries}
12811315
else
1282-
(* key is in the middle: need to split *)
1316+
(* key is in the middle: need to split; collect all leaves *)
1317+
let%lwt old_leaves = collect_subtree_leaves t old_root_cid in
12831318
let left_leaves = List.filter (fun (k, _) -> k < key) old_leaves in
12841319
let right_leaves = List.filter (fun (k, _) -> k > key) old_leaves in
12851320
let%lwt left_cid, left_layer = rebuild_subtree t.blockstore left_leaves in

mist/test/test_mst.ml

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,42 @@ let test_incremental_edge_cases () =
860860
Alcotest.fail "key should exist after update" ) ;
861861
Lwt.return_ok ()
862862

863+
let test_get_min_max_keys () =
  (* Checks [get_min_key]/[get_max_key] from the root in four stages: empty
     tree, one entry, several entries, and after adding two record-style
     keys. *)
  let store = Storage.Memory_blockstore.create () in
  let cid1 =
    cid_of_string_exn
      "bafyreie5cvv4h45feadgeuwhbcutmh6t2ceseocckahdoe6uat64zmz454"
  in
  let check_key = Alcotest.(check (option string)) in
  (* looks up the minimum, then the maximum key from the tree's root *)
  let bounds (tree : Mem_mst.t) =
    let%lwt lo = Mem_mst.get_min_key tree tree.root in
    let%lwt hi = Mem_mst.get_max_key tree tree.root in
    Lwt.return (lo, hi)
  in
  (* inserts the keys one after another, left to right, all mapped to [cid1] *)
  let rec add_all tree = function
    | [] ->
        Lwt.return tree
    | key :: rest ->
        let%lwt tree = Mem_mst.add tree key cid1 in
        add_all tree rest
  in
  let* mst = Mem_mst.create_empty store in
  (* empty tree *)
  let%lwt lo, hi = bounds mst in
  check_key "empty min" None lo ;
  check_key "empty max" None hi ;
  (* single entry *)
  let%lwt mst = add_all mst ["com.example/mmm"] in
  let%lwt lo, hi = bounds mst in
  check_key "single min" (Some "com.example/mmm") lo ;
  check_key "single max" (Some "com.example/mmm") hi ;
  (* multiple entries at different layers *)
  let%lwt mst =
    add_all mst
      [ "com.example/aaa"
      ; "com.example/zzz"
      ; "com.example/bbb"
      ; "com.example/yyy" ]
  in
  let%lwt lo, hi = bounds mst in
  check_key "multi min" (Some "com.example/aaa") lo ;
  check_key "multi max" (Some "com.example/zzz") hi ;
  (* add keys with high layer values to exercise deeper tree structure *)
  let%lwt mst =
    add_all mst
      [ "com.example.record/3jqfcqzm3fs2j"
      ; "com.example.record/3jqfcqzm3fn2j" ]
  in
  let%lwt lo, hi = bounds mst in
  check_key "deep min" (Some "com.example.record/3jqfcqzm3fn2j") lo ;
  check_key "deep max" (Some "com.example/zzz") hi ;
  Lwt.return_ok ()
898+
863899
let () =
864900
let open Alcotest in
865901
let run_test test =
@@ -909,4 +945,7 @@ let () =
909945
; test_case "mixed incremental ops" `Quick (fun () ->
910946
run_test test_incremental_mixed_ops_canonicity )
911947
; test_case "incremental edge cases" `Quick (fun () ->
912-
run_test test_incremental_edge_cases ) ] ) ]
948+
run_test test_incremental_edge_cases ) ] )
949+
; ( "boundary functions"
950+
, [ test_case "get_min_key and get_max_key" `Quick (fun () ->
951+
run_test test_get_min_max_keys ) ] ) ]

0 commit comments

Comments
 (0)