Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 58 additions & 42 deletions trie/bintrie/internal_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,31 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
if bit == 0 {
if hn, ok := bt.left.(HashedNode); ok {
if common.Hash(hn) == (common.Hash{}) {
bt.left = Empty{}
} else {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
bt.left = node
}
}
return bt.left.GetValuesAtStem(stem, resolver)
}

if hn, ok := bt.right.(HashedNode); ok {
if common.Hash(hn) == (common.Hash{}) {
bt.right = Empty{}
} else {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
Expand All @@ -63,25 +88,8 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
bt.left = node
bt.right = node
}
return bt.left.GetValuesAtStem(stem, resolver)
}

if hn, ok := bt.right.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
bt.right = node
}
return bt.right.GetValuesAtStem(stem, resolver)
}
Expand Down Expand Up @@ -141,19 +149,23 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
}

if hn, ok := bt.left.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
if common.Hash(hn) == (common.Hash{}) {
bt.left = Empty{}
} else {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
bt.left = node
}
bt.left = node
}

bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
Expand All @@ -165,19 +177,23 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
}

if hn, ok := bt.right.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
if common.Hash(hn) == (common.Hash{}) {
bt.right = Empty{}
} else {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
bt.right = node
}
bt.right = node
}

bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
Expand Down
145 changes: 145 additions & 0 deletions trie/bintrie/zero_hash_fix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Copyright 2025 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package bintrie

import (
"errors"
"testing"

"github.com/ethereum/go-ethereum/common"
)

// TestZeroHashSkipsResolver tests that zero-hash HashedNodes don't trigger resolver calls
func TestZeroHashSkipsResolver(t *testing.T) {
// Create an InternalNode with one real child and one Empty child
realHash := common.HexToHash("0x1234")

node := &InternalNode{
depth: 0,
left: HashedNode(realHash),
right: Empty{},
}

// Serialize and deserialize to create zero-hash HashedNode
serialized := SerializeNode(node)
deserialized, err := DeserializeNode(serialized, 0)
if err != nil {
t.Fatalf("Failed to deserialize: %v", err)
}

deserializedInternal := deserialized.(*InternalNode)

// Verify that right child is a zero-hash HashedNode after deserialization
if hn, ok := deserializedInternal.right.(HashedNode); ok {
if common.Hash(hn) != (common.Hash{}) {
t.Fatal("Expected right child to be zero-hash HashedNode")
}
} else {
t.Fatalf("Expected right child to be HashedNode, got %T", deserializedInternal.right)
}

// Track resolver calls
resolverCalls := 0
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
resolverCalls++

// Zero-hash should never reach resolver
if hash == (common.Hash{}) {
t.Error("BUG: Resolver called for zero hash")
return nil, errors.New("zero hash should not be resolved")
}

// Return valid data for real hash
if hash == realHash {
stem := make([]byte, 31)
var values [256][]byte
values[5] = common.HexToHash("0xabcd").Bytes()
return SerializeNode(&StemNode{Stem: stem, Values: values[:], depth: 1}), nil
}

return nil, errors.New("not found")
}

// Access right child (zero-hash) - should not call resolver
rightStem := make([]byte, 31)
rightStem[0] = 0x80 // First bit is 1, routes to right child

values, err := deserializedInternal.GetValuesAtStem(rightStem, resolver)
if err != nil {
t.Fatalf("GetValuesAtStem failed: %v", err)
}

// All values should be nil for empty node
for i, v := range values {
if v != nil {
t.Errorf("Expected nil value at index %d, got %x", i, v)
}
}

// Verify resolver was not called for zero-hash
if resolverCalls > 0 {
t.Errorf("Resolver should not have been called for zero-hash child, but was called %d times", resolverCalls)
}

// Now test left child (real hash) - should call resolver
leftStem := make([]byte, 31)
_, err = deserializedInternal.GetValuesAtStem(leftStem, resolver)
if err != nil {
t.Fatalf("GetValuesAtStem failed for left child: %v", err)
}

if resolverCalls != 1 {
t.Errorf("Expected resolver to be called once for real hash, called %d times", resolverCalls)
}
}

// TestZeroHashSkipsResolverOnInsert tests that InsertValuesAtStem also skips zero-hash resolver calls
func TestZeroHashSkipsResolverOnInsert(t *testing.T) {
// Create node after deserialization with zero-hash children
node := &InternalNode{
depth: 0,
left: HashedNode(common.Hash{}), // Zero-hash
right: HashedNode(common.Hash{}), // Zero-hash
}

resolverCalls := 0
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
resolverCalls++

if hash == (common.Hash{}) {
t.Error("BUG: Resolver called for zero hash in InsertValuesAtStem")
return nil, errors.New("zero hash should not be resolved")
}

return nil, errors.New("not found")
}

// Insert values into left subtree (zero-hash child)
leftStem := make([]byte, 31)
var values [256][]byte
values[10] = common.HexToHash("0x5678").Bytes()

_, err := node.InsertValuesAtStem(leftStem, values[:], resolver, 0)
if err != nil {
t.Fatalf("InsertValuesAtStem failed: %v", err)
}

// Verify resolver was not called
if resolverCalls > 0 {
t.Errorf("Resolver should not have been called for zero-hash child, but was called %d times", resolverCalls)
}
}
Loading