Skip to content

Commit ddb410d

Browse files
committed
mssmt: add tree copy functionality for full and compacted trees
This commit introduces a new `Copy` method to both the `FullTree` and `CompactedTree` implementations of the MS-SMT. This method allows copying all key-value pairs from a source tree to a target tree, assuming the target tree is initially empty. The `Copy` method is implemented differently for each tree type: - For `FullTree`, the method recursively traverses the tree, collecting all non-empty leaf nodes along with their keys. It then inserts these leaves into the target tree. - For `CompactedTree`, the method similarly traverses the tree, collecting all non-empty compacted leaf nodes along with their keys. It then inserts these leaves into the target tree. A new test case, `TestTreeCopy`, is added to verify the correctness of the `Copy` method for both tree types, including copying between different tree types (FullTree to CompactedTree and vice versa). The test case generates a set of random leaves, inserts them into a source tree, copies the source tree to a target tree, and then verifies that the target tree contains the same leaves as the source tree.
1 parent 581c881 commit ddb410d

File tree

4 files changed

+342
-0
lines changed

4 files changed

+342
-0
lines changed

mssmt/compacted_tree.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,133 @@ func (t *CompactedTree) MerkleProof(ctx context.Context, key [hashSize]byte) (
392392

393393
return NewProof(proof), nil
394394
}
395+
396+
// collectLeavesRecursive is a recursive helper function that's used to traverse
397+
// down an MS-SMT tree and collect all leaf nodes. It returns a map of leaf
398+
// nodes indexed by their hash.
399+
func collectLeavesRecursive(ctx context.Context, tx TreeStoreViewTx, node Node,
400+
depth int) (map[[hashSize]byte]*LeafNode, error) {
401+
402+
// Base case: If it's a compacted leaf node.
403+
if compactedLeaf, ok := node.(*CompactedLeafNode); ok {
404+
if compactedLeaf.LeafNode.IsEmpty() {
405+
return make(map[[hashSize]byte]*LeafNode), nil
406+
}
407+
return map[[hashSize]byte]*LeafNode{
408+
compactedLeaf.Key(): compactedLeaf.LeafNode,
409+
}, nil
410+
}
411+
412+
// Recursive step: If it's a branch node.
413+
if branchNode, ok := node.(*BranchNode); ok {
414+
// Optimization: if the branch is empty, return early.
415+
if depth < MaxTreeLevels &&
416+
IsEqualNode(branchNode, EmptyTree[depth]) {
417+
418+
return make(map[[hashSize]byte]*LeafNode), nil
419+
}
420+
421+
// Handle case where depth might exceed EmptyTree bounds if
422+
// logic error exists
423+
if depth >= MaxTreeLevels {
424+
// This shouldn't happen if called correctly, implies a
425+
// leaf.
426+
return nil, fmt.Errorf("invalid depth %d for branch "+
427+
"node", depth)
428+
}
429+
430+
left, right, err := tx.GetChildren(depth, branchNode.NodeHash())
431+
if err != nil {
432+
// If children not found, it might be an empty branch
433+
// implicitly Check if the error indicates "not found"
434+
// or similar Depending on store impl, this might be how
435+
// empty is signaled For now, treat error as fatal.
436+
return nil, fmt.Errorf("error getting children for "+
437+
"branch %s at depth %d: %w",
438+
branchNode.NodeHash(), depth, err)
439+
}
440+
441+
leftLeaves, err := collectLeavesRecursive(
442+
ctx, tx, left, depth+1,
443+
)
444+
if err != nil {
445+
return nil, err
446+
}
447+
448+
rightLeaves, err := collectLeavesRecursive(
449+
ctx, tx, right, depth+1,
450+
)
451+
if err != nil {
452+
return nil, err
453+
}
454+
455+
// Merge the results.
456+
for k, v := range rightLeaves {
457+
// Check for duplicate keys, although this shouldn't
458+
// happen in a valid SMT.
459+
if _, exists := leftLeaves[k]; exists {
460+
return nil, fmt.Errorf("duplicate key %x "+
461+
"found during leaf collection", k)
462+
}
463+
leftLeaves[k] = v
464+
}
465+
466+
return leftLeaves, nil
467+
}
468+
469+
// Handle unexpected node types or implicit empty nodes. If node is nil
470+
// or explicitly an EmptyLeafNode representation
471+
if node == nil || IsEqualNode(node, EmptyLeafNode) {
472+
return make(map[[hashSize]byte]*LeafNode), nil
473+
}
474+
475+
// Check against EmptyTree branches if possible (requires depth)
476+
if depth < MaxTreeLevels && IsEqualNode(node, EmptyTree[depth]) {
477+
return make(map[[hashSize]byte]*LeafNode), nil
478+
}
479+
480+
return nil, fmt.Errorf("unexpected node type %T encountered "+
481+
"during leaf collection at depth %d", node, depth)
482+
}
483+
484+
// Copy copies all the key-value pairs from the source tree into the target
485+
// tree.
486+
func (t *CompactedTree) Copy(ctx context.Context, targetTree Tree) error {
487+
var leaves map[[hashSize]byte]*LeafNode
488+
err := t.store.View(ctx, func(tx TreeStoreViewTx) error {
489+
root, err := tx.RootNode()
490+
if err != nil {
491+
return fmt.Errorf("error getting root node: %w", err)
492+
}
493+
494+
// Optimization: If the source tree is empty, there's nothing to
495+
// copy.
496+
if IsEqualNode(root, EmptyTree[0]) {
497+
leaves = make(map[[hashSize]byte]*LeafNode)
498+
return nil
499+
}
500+
501+
// Start recursive collection from the root at depth 0.
502+
leaves, err = collectLeavesRecursive(ctx, tx, root, 0)
503+
if err != nil {
504+
return fmt.Errorf("error collecting leaves: %w", err)
505+
}
506+
507+
return nil
508+
})
509+
if err != nil {
510+
return err
511+
}
512+
513+
// Insert all found leaves into the target tree.
514+
for key, leaf := range leaves {
515+
// Use the target tree's Insert method.
516+
_, err := targetTree.Insert(ctx, key, leaf)
517+
if err != nil {
518+
return fmt.Errorf("error inserting leaf with key %x "+
519+
"into target tree: %w", key, err)
520+
}
521+
}
522+
523+
return nil
524+
}

mssmt/interface.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,8 @@ type Tree interface {
3030
// proof. This is noted by the returned `Proof` containing an empty
3131
// leaf.
3232
MerkleProof(ctx context.Context, key [hashSize]byte) (*Proof, error)
33+
34+
// Copy copies all the key-value pairs from the source tree into the
35+
// target tree.
36+
Copy(ctx context.Context, targetTree Tree) error
3337
}

mssmt/tree.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ func bitIndex(idx uint8, key *[hashSize]byte) byte {
9797
return (byteVal >> (idx % 8)) & 1
9898
}
9999

100+
// setBit returns a copy of the key with the bit at the given depth set to 1.
101+
func setBit(key [hashSize]byte, depth int) [hashSize]byte {
102+
byteIndex := depth / 8
103+
bitIndex := depth % 8
104+
key[byteIndex] |= (1 << bitIndex)
105+
return key
106+
}
107+
100108
// iterFunc is a type alias for closures to be invoked at every iteration of
101109
// walking through a tree.
102110
type iterFunc = func(height int, current, sibling, parent Node) error
@@ -333,6 +341,109 @@ func (t *FullTree) MerkleProof(ctx context.Context, key [hashSize]byte) (
333341
return NewProof(proof), nil
334342
}
335343

344+
// findLeaves recursively traverses the tree represented by the given node and
345+
// collects all non-empty leaf nodes along with their reconstructed keys.
346+
func findLeaves(ctx context.Context, tx TreeStoreViewTx, node Node,
347+
keyPrefix [hashSize]byte, depth int) (map[[hashSize]byte]*LeafNode, error) {
348+
349+
// Base case: If it's a leaf node.
350+
if leafNode, ok := node.(*LeafNode); ok {
351+
if leafNode.IsEmpty() {
352+
return make(map[[hashSize]byte]*LeafNode), nil
353+
}
354+
return map[[hashSize]byte]*LeafNode{keyPrefix: leafNode}, nil
355+
}
356+
357+
// Recursive step: If it's a branch node.
358+
if branchNode, ok := node.(*BranchNode); ok {
359+
// Optimization: if the branch is empty, return early.
360+
if IsEqualNode(branchNode, EmptyTree[depth]) {
361+
return make(map[[hashSize]byte]*LeafNode), nil
362+
}
363+
364+
left, right, err := tx.GetChildren(depth, branchNode.NodeHash())
365+
if err != nil {
366+
return nil, fmt.Errorf("error getting children for "+
367+
"branch %s at depth %d: %w",
368+
branchNode.NodeHash(), depth, err)
369+
}
370+
371+
// Recursively find leaves in the left subtree. The key prefix
372+
// remains the same as the 0 bit is implicitly handled by the
373+
// initial keyPrefix state.
374+
leftLeaves, err := findLeaves(
375+
ctx, tx, left, keyPrefix, depth+1,
376+
)
377+
if err != nil {
378+
return nil, err
379+
}
380+
381+
// Recursively find leaves in the right subtree. Set the bit
382+
// corresponding to the current depth in the key prefix.
383+
rightKeyPrefix := setBit(keyPrefix, depth)
384+
385+
rightLeaves, err := findLeaves(
386+
ctx, tx, right, rightKeyPrefix, depth+1,
387+
)
388+
if err != nil {
389+
return nil, err
390+
}
391+
392+
// Merge the results.
393+
for k, v := range rightLeaves {
394+
leftLeaves[k] = v
395+
}
396+
return leftLeaves, nil
397+
}
398+
399+
// Handle unexpected node types.
400+
return nil, fmt.Errorf("unexpected node type %T encountered "+
401+
"during leaf collection", node)
402+
}
403+
404+
// Copy copies all the key-value pairs from the source tree into the target
405+
// tree.
406+
func (t *FullTree) Copy(ctx context.Context, targetTree Tree) error {
407+
var leaves map[[hashSize]byte]*LeafNode
408+
err := t.store.View(ctx, func(tx TreeStoreViewTx) error {
409+
root, err := tx.RootNode()
410+
if err != nil {
411+
return fmt.Errorf("error getting root node: %w", err)
412+
}
413+
414+
// Optimization: If the source tree is empty, there's nothing
415+
// to copy.
416+
if IsEqualNode(root, EmptyTree[0]) {
417+
leaves = make(map[[hashSize]byte]*LeafNode)
418+
return nil
419+
}
420+
421+
leaves, err = findLeaves(ctx, tx, root, [hashSize]byte{}, 0)
422+
if err != nil {
423+
return fmt.Errorf("error finding leaves: %w", err)
424+
}
425+
return nil
426+
})
427+
if err != nil {
428+
return err
429+
}
430+
431+
// Insert all found leaves into the target tree. We assume the target
432+
// tree handles batching or individual inserts efficiently.
433+
for key, leaf := range leaves {
434+
// Use the target tree's Insert method. We ignore the returned
435+
// tree as we are modifying the targetTree in place via its
436+
// store.
437+
_, err := targetTree.Insert(ctx, key, leaf)
438+
if err != nil {
439+
return fmt.Errorf("error inserting leaf with key %x "+
440+
"into target tree: %w", key, err)
441+
}
442+
}
443+
444+
return nil
445+
}
446+
336447
// VerifyMerkleProof determines whether a merkle proof for the leaf found at the
337448
// given key is valid.
338449
func VerifyMerkleProof(key [hashSize]byte, leaf *LeafNode, proof *Proof,

mssmt/tree_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,103 @@ func TestBIPTestVectors(t *testing.T) {
822822
}
823823
}
824824

825+
// TestTreeCopy tests the Copy method for both FullTree and CompactedTree,
826+
// including copying between different tree types.
827+
func TestTreeCopy(t *testing.T) {
828+
t.Parallel()
829+
830+
leaves := randTree(50) // Use a smaller number for faster testing
831+
832+
// Prepare source trees (Full and Compacted)
833+
ctx := context.Background()
834+
sourceFullStore := mssmt.NewDefaultStore()
835+
sourceFullTree := mssmt.NewFullTree(sourceFullStore)
836+
sourceCompactedStore := mssmt.NewDefaultStore()
837+
sourceCompactedTree := mssmt.NewCompactedTree(sourceCompactedStore)
838+
839+
for _, item := range leaves {
840+
_, err := sourceFullTree.Insert(ctx, item.key, item.leaf)
841+
require.NoError(t, err)
842+
_, err = sourceCompactedTree.Insert(ctx, item.key, item.leaf)
843+
require.NoError(t, err)
844+
}
845+
846+
sourceFullRoot, err := sourceFullTree.Root(ctx)
847+
require.NoError(t, err)
848+
sourceCompactedRoot, err := sourceCompactedTree.Root(ctx)
849+
require.NoError(t, err)
850+
require.True(t, mssmt.IsEqualNode(sourceFullRoot, sourceCompactedRoot))
851+
852+
// Define test cases
853+
testCases := []struct {
854+
name string
855+
sourceTree mssmt.Tree
856+
makeTarget func() mssmt.Tree
857+
}{
858+
{
859+
name: "Full -> Full",
860+
sourceTree: sourceFullTree,
861+
makeTarget: func() mssmt.Tree {
862+
return mssmt.NewFullTree(mssmt.NewDefaultStore())
863+
},
864+
},
865+
{
866+
name: "Full -> Compacted",
867+
sourceTree: sourceFullTree,
868+
makeTarget: func() mssmt.Tree {
869+
return mssmt.NewCompactedTree(mssmt.NewDefaultStore())
870+
},
871+
},
872+
{
873+
name: "Compacted -> Full",
874+
sourceTree: sourceCompactedTree,
875+
makeTarget: func() mssmt.Tree {
876+
return mssmt.NewFullTree(mssmt.NewDefaultStore())
877+
},
878+
},
879+
{
880+
name: "Compacted -> Compacted",
881+
sourceTree: sourceCompactedTree,
882+
makeTarget: func() mssmt.Tree {
883+
return mssmt.NewCompactedTree(mssmt.NewDefaultStore())
884+
},
885+
},
886+
}
887+
888+
for _, tc := range testCases {
889+
tc := tc
890+
t.Run(tc.name, func(t *testing.T) {
891+
t.Parallel()
892+
893+
targetTree := tc.makeTarget()
894+
895+
// Perform the copy
896+
err := tc.sourceTree.Copy(ctx, targetTree)
897+
require.NoError(t, err)
898+
899+
// Verify the target tree root
900+
targetRoot, err := targetTree.Root(ctx)
901+
require.NoError(t, err)
902+
require.True(t, mssmt.IsEqualNode(sourceFullRoot, targetRoot),
903+
"Root mismatch after copy")
904+
905+
// Verify individual leaves in the target tree
906+
for _, item := range leaves {
907+
targetLeaf, err := targetTree.Get(ctx, item.key)
908+
require.NoError(t, err)
909+
require.Equal(t, item.leaf, targetLeaf,
910+
"Leaf mismatch for key %x", item.key)
911+
}
912+
913+
// Verify a non-existent key is still empty
914+
emptyLeaf, err := targetTree.Get(ctx, test.RandHash())
915+
require.NoError(t, err)
916+
require.True(t, emptyLeaf.IsEmpty(), "Non-existent key found")
917+
})
918+
}
919+
}
920+
921+
825922
// runBIPTestVector runs the tests in a single BIP test vector file.
826923
func runBIPTestVector(t *testing.T, testVectors *mssmt.TestVectors) {
827924
for _, validCase := range testVectors.ValidTestCases {

0 commit comments

Comments
 (0)