From 926ded3385bf2dbc69b9604f33de54275dd45a81 Mon Sep 17 00:00:00 2001 From: rizkyikiw42 Date: Thu, 27 Nov 2025 11:10:45 +0800 Subject: [PATCH] trie/bintrie: skip resolver for zero-hash children --- trie/bintrie/internal_node.go | 100 +++++++++++--------- trie/bintrie/zero_hash_fix_test.go | 145 +++++++++++++++++++++++++++++ 2 files changed, 203 insertions(+), 42 deletions(-) create mode 100644 trie/bintrie/zero_hash_fix_test.go diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go index 0a7bece521fd..d4334f743a18 100644 --- a/trie/bintrie/internal_node.go +++ b/trie/bintrie/internal_node.go @@ -51,6 +51,31 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([ bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 if bit == 0 { if hn, ok := bt.left.(HashedNode); ok { + if common.Hash(hn) == (common.Hash{}) { + bt.left = Empty{} + } else { + path, err := keyToPath(bt.depth, stem) + if err != nil { + return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) + } + data, err := resolver(path, common.Hash(hn)) + if err != nil { + return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) + } + node, err := DeserializeNode(data, bt.depth+1) + if err != nil { + return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) + } + bt.left = node + } + } + return bt.left.GetValuesAtStem(stem, resolver) + } + + if hn, ok := bt.right.(HashedNode); ok { + if common.Hash(hn) == (common.Hash{}) { + bt.right = Empty{} + } else { path, err := keyToPath(bt.depth, stem) if err != nil { return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) @@ -63,25 +88,8 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([ if err != nil { return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) } - bt.left = node + bt.right = node } - return bt.left.GetValuesAtStem(stem, resolver) - } - - if hn, ok := bt.right.(HashedNode); ok { - path, err := keyToPath(bt.depth, stem) - if err != nil { - return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) - } - data, err := resolver(path, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) - } - node, err := DeserializeNode(data, bt.depth+1) - if err != nil { - return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) - } - bt.right = node } return bt.right.GetValuesAtStem(stem, resolver) } @@ -141,19 +149,23 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve } if hn, ok := bt.left.(HashedNode); ok { - path, err := keyToPath(bt.depth, stem) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) - } - data, err := resolver(path, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) - } - node, err := DeserializeNode(data, bt.depth+1) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) + if common.Hash(hn) == (common.Hash{}) { + bt.left = Empty{} + } else { + path, err := keyToPath(bt.depth, stem) + if err != nil { + return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) + } + data, err := resolver(path, common.Hash(hn)) + if err != nil { + return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) + } + node, err := DeserializeNode(data, bt.depth+1) + if err != nil { + return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) + } + bt.left = node } - bt.left = node } bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1) @@ -165,19 +177,23 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve } if hn, ok := bt.right.(HashedNode); ok { - path, err := keyToPath(bt.depth, stem) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) - } - data, err := resolver(path, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) - } - node, err := DeserializeNode(data, bt.depth+1) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) + if common.Hash(hn) == (common.Hash{}) { + bt.right = Empty{} + } else { + path, err := keyToPath(bt.depth, stem) + if err != nil { + return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) + } + data, err := resolver(path, common.Hash(hn)) + if err != nil { + return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) + } + node, err := DeserializeNode(data, bt.depth+1) + if err != nil { + return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) + } + bt.right = node } - bt.right = node } bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1) diff --git a/trie/bintrie/zero_hash_fix_test.go b/trie/bintrie/zero_hash_fix_test.go new file mode 100644 index 000000000000..a90ba9a5b822 --- /dev/null +++ b/trie/bintrie/zero_hash_fix_test.go @@ -0,0 +1,145 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "errors" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestZeroHashSkipsResolver tests that zero-hash HashedNodes don't trigger resolver calls +func TestZeroHashSkipsResolver(t *testing.T) { + // Create an InternalNode with one real child and one Empty child + realHash := common.HexToHash("0x1234") + + node := &InternalNode{ + depth: 0, + left: HashedNode(realHash), + right: Empty{}, + } + + // Serialize and deserialize to create zero-hash HashedNode + serialized := SerializeNode(node) + deserialized, err := DeserializeNode(serialized, 0) + if err != nil { + t.Fatalf("Failed to deserialize: %v", err) + } + + deserializedInternal := deserialized.(*InternalNode) + + // Verify that right child is a zero-hash HashedNode after deserialization + if hn, ok := deserializedInternal.right.(HashedNode); ok { + if common.Hash(hn) != (common.Hash{}) { + t.Fatal("Expected right child to be zero-hash HashedNode") + } + } else { + t.Fatalf("Expected right child to be HashedNode, got %T", deserializedInternal.right) + } + + // Track resolver calls + resolverCalls := 0 + resolver := func(path []byte, hash common.Hash) ([]byte, error) { + resolverCalls++ + + // Zero-hash should never reach resolver + if hash == (common.Hash{}) { + t.Error("BUG: Resolver called for zero hash") + return nil, errors.New("zero hash should not be resolved") + } + + // Return valid data for real hash + if hash == realHash { + stem := make([]byte, 31) + var values [256][]byte + values[5] = common.HexToHash("0xabcd").Bytes() + return SerializeNode(&StemNode{Stem: stem, Values: values[:], depth: 1}), nil + } + + return nil, errors.New("not found") + } + + // Access right child (zero-hash) - should not call resolver + rightStem := make([]byte, 31) + rightStem[0] = 0x80 // First bit is 1, routes to right child + + values, err := deserializedInternal.GetValuesAtStem(rightStem, resolver) + if err != nil { + t.Fatalf("GetValuesAtStem failed: %v", err) + } + + // All values should be nil for empty node + for i, v := range values { + if v != nil { + t.Errorf("Expected nil value at index %d, got %x", i, v) + } + } + + // Verify resolver was not called for zero-hash + if resolverCalls > 0 { + t.Errorf("Resolver should not have been called for zero-hash child, but was called %d times", resolverCalls) + } + + // Now test left child (real hash) - should call resolver + leftStem := make([]byte, 31) + _, err = deserializedInternal.GetValuesAtStem(leftStem, resolver) + if err != nil { + t.Fatalf("GetValuesAtStem failed for left child: %v", err) + } + + if resolverCalls != 1 { + t.Errorf("Expected resolver to be called once for real hash, called %d times", resolverCalls) + } +} + +// TestZeroHashSkipsResolverOnInsert tests that InsertValuesAtStem also skips zero-hash resolver calls +func TestZeroHashSkipsResolverOnInsert(t *testing.T) { + // Create node after deserialization with zero-hash children + node := &InternalNode{ + depth: 0, + left: HashedNode(common.Hash{}), // Zero-hash + right: HashedNode(common.Hash{}), // Zero-hash + } + + resolverCalls := 0 + resolver := func(path []byte, hash common.Hash) ([]byte, error) { + resolverCalls++ + + if hash == (common.Hash{}) { + t.Error("BUG: Resolver called for zero hash in InsertValuesAtStem") + return nil, errors.New("zero hash should not be resolved") + } + + return nil, errors.New("not found") + } + + // Insert values into left subtree (zero-hash child) + leftStem := make([]byte, 31) + var values [256][]byte + values[10] = common.HexToHash("0x5678").Bytes() + + _, err := node.InsertValuesAtStem(leftStem, values[:], resolver, 0) + if err != nil { + t.Fatalf("InsertValuesAtStem failed: %v", err) + } + + // Verify resolver was not called + if resolverCalls > 0 { + t.Errorf("Resolver should not have been called for zero-hash child, but was called %d times", resolverCalls) + } +}