@@ -392,3 +392,191 @@ func (t *CompactedTree) MerkleProof(ctx context.Context, key [hashSize]byte) (
392392
393393 return NewProof (proof ), nil
394394}
395+
396+ // collectLeavesRecursive is a recursive helper function that's used to traverse
397+ // down an MS-SMT tree and collect all leaf nodes. It returns a map of leaf
398+ // nodes indexed by their hash.
399+ func collectLeavesRecursive (ctx context.Context , tx TreeStoreViewTx , node Node ,
400+ depth int ) (map [[hashSize ]byte ]* LeafNode , error ) {
401+
402+ // Base case: If it's a compacted leaf node.
403+ if compactedLeaf , ok := node .(* CompactedLeafNode ); ok {
404+ if compactedLeaf .LeafNode .IsEmpty () {
405+ return make (map [[hashSize ]byte ]* LeafNode ), nil
406+ }
407+ return map [[hashSize ]byte ]* LeafNode {
408+ compactedLeaf .Key (): compactedLeaf .LeafNode ,
409+ }, nil
410+ }
411+
412+ // Recursive step: If it's a branch node.
413+ if branchNode , ok := node .(* BranchNode ); ok {
414+ // Optimization: if the branch is empty, return early.
415+ if depth < MaxTreeLevels &&
416+ IsEqualNode (branchNode , EmptyTree [depth ]) {
417+
418+ return make (map [[hashSize ]byte ]* LeafNode ), nil
419+ }
420+
421+ // Handle case where depth might exceed EmptyTree bounds if
422+ // logic error exists
423+ if depth >= MaxTreeLevels {
424+ // This shouldn't happen if called correctly, implies a
425+ // leaf.
426+ return nil , fmt .Errorf ("invalid depth %d for branch " +
427+ "node" , depth )
428+ }
429+
430+ left , right , err := tx .GetChildren (depth , branchNode .NodeHash ())
431+ if err != nil {
432+ // If children not found, it might be an empty branch
433+ // implicitly Check if the error indicates "not found"
434+ // or similar Depending on store impl, this might be how
435+ // empty is signaled For now, treat error as fatal.
436+ return nil , fmt .Errorf ("error getting children for " +
437+ "branch %s at depth %d: %w" ,
438+ branchNode .NodeHash (), depth , err )
439+ }
440+
441+ leftLeaves , err := collectLeavesRecursive (
442+ ctx , tx , left , depth + 1 ,
443+ )
444+ if err != nil {
445+ return nil , err
446+ }
447+
448+ rightLeaves , err := collectLeavesRecursive (
449+ ctx , tx , right , depth + 1 ,
450+ )
451+ if err != nil {
452+ return nil , err
453+ }
454+
455+ // Merge the results.
456+ for k , v := range rightLeaves {
457+ // Check for duplicate keys, although this shouldn't
458+ // happen in a valid SMT.
459+ if _ , exists := leftLeaves [k ]; exists {
460+ return nil , fmt .Errorf ("duplicate key %x " +
461+ "found during leaf collection" , k )
462+ }
463+ leftLeaves [k ] = v
464+ }
465+
466+ return leftLeaves , nil
467+ }
468+
469+ // Handle unexpected node types or implicit empty nodes. If node is nil
470+ // or explicitly an EmptyLeafNode representation
471+ if node == nil || IsEqualNode (node , EmptyLeafNode ) {
472+ return make (map [[hashSize ]byte ]* LeafNode ), nil
473+ }
474+
475+ // Check against EmptyTree branches if possible (requires depth)
476+ if depth < MaxTreeLevels && IsEqualNode (node , EmptyTree [depth ]) {
477+ return make (map [[hashSize ]byte ]* LeafNode ), nil
478+ }
479+
480+ return nil , fmt .Errorf ("unexpected node type %T encountered " +
481+ "during leaf collection at depth %d" , node , depth )
482+ }
483+
484+ // Copy copies all the key-value pairs from the source tree into the target
485+ // tree.
486+ func (t * CompactedTree ) Copy (ctx context.Context , targetTree Tree ) error {
487+ var leaves map [[hashSize ]byte ]* LeafNode
488+ err := t .store .View (ctx , func (tx TreeStoreViewTx ) error {
489+ root , err := tx .RootNode ()
490+ if err != nil {
491+ return fmt .Errorf ("error getting root node: %w" , err )
492+ }
493+
494+ // Optimization: If the source tree is empty, there's nothing to
495+ // copy.
496+ if IsEqualNode (root , EmptyTree [0 ]) {
497+ leaves = make (map [[hashSize ]byte ]* LeafNode )
498+ return nil
499+ }
500+
501+ // Start recursive collection from the root at depth 0.
502+ leaves , err = collectLeavesRecursive (ctx , tx , root , 0 )
503+ if err != nil {
504+ return fmt .Errorf ("error collecting leaves: %w" , err )
505+ }
506+
507+ return nil
508+ })
509+ if err != nil {
510+ return err
511+ }
512+
513+ // Insert all found leaves into the target tree using InsertMany for
514+ // efficiency.
515+ _ , err = targetTree .InsertMany (ctx , leaves )
516+ if err != nil {
517+ return fmt .Errorf ("error inserting leaves into " +
518+ "target tree: %w" , err )
519+ }
520+
521+ return nil
522+ }
523+
524+ // InsertMany inserts multiple leaf nodes provided in the leaves map within a
525+ // single database transaction.
526+ func (t * CompactedTree ) InsertMany (ctx context.Context ,
527+ leaves map [[hashSize ]byte ]* LeafNode ) (Tree , error ) {
528+
529+ if len (leaves ) == 0 {
530+ return t , nil
531+ }
532+
533+ dbErr := t .store .Update (ctx , func (tx TreeStoreUpdateTx ) error {
534+ currentRoot , err := tx .RootNode ()
535+ if err != nil {
536+ return err
537+ }
538+ rootBranch := currentRoot .(* BranchNode )
539+
540+ for key , leaf := range leaves {
541+ // Check for potential sum overflow before each
542+ // insertion.
543+ sumRoot := rootBranch .NodeSum ()
544+ sumLeaf := leaf .NodeSum ()
545+ err = CheckSumOverflowUint64 (sumRoot , sumLeaf )
546+ if err != nil {
547+ return fmt .Errorf ("compact tree leaf insert " +
548+ "sum overflow, root: %d, leaf: %d; %w" ,
549+ sumRoot , sumLeaf , err )
550+ }
551+
552+ // Insert the leaf using the internal helper.
553+ newRoot , err := t .insert (
554+ tx , & key , 0 , rootBranch , leaf ,
555+ )
556+ if err != nil {
557+ return fmt .Errorf ("error inserting leaf " +
558+ "with key %x: %w" , key , err )
559+ }
560+ rootBranch = newRoot
561+
562+ // Update the root within the transaction for
563+ // consistency, even though the insert logic passes the
564+ // root explicitly.
565+ err = tx .UpdateRoot (rootBranch )
566+ if err != nil {
567+ return fmt .Errorf ("error updating root " +
568+ "during InsertMany: %w" , err )
569+ }
570+ }
571+
572+ // The root is already updated by the last iteration of the
573+ // loop. No final update needed here, but returning nil error
574+ // signals success.
575+ return nil
576+ })
577+ if dbErr != nil {
578+ return nil , dbErr
579+ }
580+
581+ return t , nil
582+ }
0 commit comments