Skip to content

Commit 926ded3

Browse files
committed
trie/bintrie: skip resolver for zero-hash children
1 parent 6452b7a commit 926ded3

File tree

2 files changed

+203
-42
lines changed

2 files changed

+203
-42
lines changed

trie/bintrie/internal_node.go

Lines changed: 58 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,31 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
5151
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
5252
if bit == 0 {
5353
if hn, ok := bt.left.(HashedNode); ok {
54+
if common.Hash(hn) == (common.Hash{}) {
55+
bt.left = Empty{}
56+
} else {
57+
path, err := keyToPath(bt.depth, stem)
58+
if err != nil {
59+
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
60+
}
61+
data, err := resolver(path, common.Hash(hn))
62+
if err != nil {
63+
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
64+
}
65+
node, err := DeserializeNode(data, bt.depth+1)
66+
if err != nil {
67+
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
68+
}
69+
bt.left = node
70+
}
71+
}
72+
return bt.left.GetValuesAtStem(stem, resolver)
73+
}
74+
75+
if hn, ok := bt.right.(HashedNode); ok {
76+
if common.Hash(hn) == (common.Hash{}) {
77+
bt.right = Empty{}
78+
} else {
5479
path, err := keyToPath(bt.depth, stem)
5580
if err != nil {
5681
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
@@ -63,25 +88,8 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
6388
if err != nil {
6489
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
6590
}
66-
bt.left = node
91+
bt.right = node
6792
}
68-
return bt.left.GetValuesAtStem(stem, resolver)
69-
}
70-
71-
if hn, ok := bt.right.(HashedNode); ok {
72-
path, err := keyToPath(bt.depth, stem)
73-
if err != nil {
74-
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
75-
}
76-
data, err := resolver(path, common.Hash(hn))
77-
if err != nil {
78-
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
79-
}
80-
node, err := DeserializeNode(data, bt.depth+1)
81-
if err != nil {
82-
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
83-
}
84-
bt.right = node
8593
}
8694
return bt.right.GetValuesAtStem(stem, resolver)
8795
}
@@ -141,19 +149,23 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
141149
}
142150

143151
if hn, ok := bt.left.(HashedNode); ok {
144-
path, err := keyToPath(bt.depth, stem)
145-
if err != nil {
146-
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
147-
}
148-
data, err := resolver(path, common.Hash(hn))
149-
if err != nil {
150-
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
151-
}
152-
node, err := DeserializeNode(data, bt.depth+1)
153-
if err != nil {
154-
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
152+
if common.Hash(hn) == (common.Hash{}) {
153+
bt.left = Empty{}
154+
} else {
155+
path, err := keyToPath(bt.depth, stem)
156+
if err != nil {
157+
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
158+
}
159+
data, err := resolver(path, common.Hash(hn))
160+
if err != nil {
161+
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
162+
}
163+
node, err := DeserializeNode(data, bt.depth+1)
164+
if err != nil {
165+
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
166+
}
167+
bt.left = node
155168
}
156-
bt.left = node
157169
}
158170

159171
bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
@@ -165,19 +177,23 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
165177
}
166178

167179
if hn, ok := bt.right.(HashedNode); ok {
168-
path, err := keyToPath(bt.depth, stem)
169-
if err != nil {
170-
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
171-
}
172-
data, err := resolver(path, common.Hash(hn))
173-
if err != nil {
174-
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
175-
}
176-
node, err := DeserializeNode(data, bt.depth+1)
177-
if err != nil {
178-
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
180+
if common.Hash(hn) == (common.Hash{}) {
181+
bt.right = Empty{}
182+
} else {
183+
path, err := keyToPath(bt.depth, stem)
184+
if err != nil {
185+
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
186+
}
187+
data, err := resolver(path, common.Hash(hn))
188+
if err != nil {
189+
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
190+
}
191+
node, err := DeserializeNode(data, bt.depth+1)
192+
if err != nil {
193+
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
194+
}
195+
bt.right = node
179196
}
180-
bt.right = node
181197
}
182198

183199
bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)

trie/bintrie/zero_hash_fix_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
// Copyright 2025 go-ethereum Authors
2+
// This file is part of the go-ethereum library.
3+
//
4+
// The go-ethereum library is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Lesser General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// The go-ethereum library is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Lesser General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Lesser General Public License
15+
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16+
17+
package bintrie
18+
19+
import (
20+
"errors"
21+
"testing"
22+
23+
"github.com/ethereum/go-ethereum/common"
24+
)
25+
26+
// TestZeroHashSkipsResolver tests that zero-hash HashedNodes don't trigger resolver calls
27+
func TestZeroHashSkipsResolver(t *testing.T) {
28+
// Create an InternalNode with one real child and one Empty child
29+
realHash := common.HexToHash("0x1234")
30+
31+
node := &InternalNode{
32+
depth: 0,
33+
left: HashedNode(realHash),
34+
right: Empty{},
35+
}
36+
37+
// Serialize and deserialize to create zero-hash HashedNode
38+
serialized := SerializeNode(node)
39+
deserialized, err := DeserializeNode(serialized, 0)
40+
if err != nil {
41+
t.Fatalf("Failed to deserialize: %v", err)
42+
}
43+
44+
deserializedInternal := deserialized.(*InternalNode)
45+
46+
// Verify that right child is a zero-hash HashedNode after deserialization
47+
if hn, ok := deserializedInternal.right.(HashedNode); ok {
48+
if common.Hash(hn) != (common.Hash{}) {
49+
t.Fatal("Expected right child to be zero-hash HashedNode")
50+
}
51+
} else {
52+
t.Fatalf("Expected right child to be HashedNode, got %T", deserializedInternal.right)
53+
}
54+
55+
// Track resolver calls
56+
resolverCalls := 0
57+
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
58+
resolverCalls++
59+
60+
// Zero-hash should never reach resolver
61+
if hash == (common.Hash{}) {
62+
t.Error("BUG: Resolver called for zero hash")
63+
return nil, errors.New("zero hash should not be resolved")
64+
}
65+
66+
// Return valid data for real hash
67+
if hash == realHash {
68+
stem := make([]byte, 31)
69+
var values [256][]byte
70+
values[5] = common.HexToHash("0xabcd").Bytes()
71+
return SerializeNode(&StemNode{Stem: stem, Values: values[:], depth: 1}), nil
72+
}
73+
74+
return nil, errors.New("not found")
75+
}
76+
77+
// Access right child (zero-hash) - should not call resolver
78+
rightStem := make([]byte, 31)
79+
rightStem[0] = 0x80 // First bit is 1, routes to right child
80+
81+
values, err := deserializedInternal.GetValuesAtStem(rightStem, resolver)
82+
if err != nil {
83+
t.Fatalf("GetValuesAtStem failed: %v", err)
84+
}
85+
86+
// All values should be nil for empty node
87+
for i, v := range values {
88+
if v != nil {
89+
t.Errorf("Expected nil value at index %d, got %x", i, v)
90+
}
91+
}
92+
93+
// Verify resolver was not called for zero-hash
94+
if resolverCalls > 0 {
95+
t.Errorf("Resolver should not have been called for zero-hash child, but was called %d times", resolverCalls)
96+
}
97+
98+
// Now test left child (real hash) - should call resolver
99+
leftStem := make([]byte, 31)
100+
_, err = deserializedInternal.GetValuesAtStem(leftStem, resolver)
101+
if err != nil {
102+
t.Fatalf("GetValuesAtStem failed for left child: %v", err)
103+
}
104+
105+
if resolverCalls != 1 {
106+
t.Errorf("Expected resolver to be called once for real hash, called %d times", resolverCalls)
107+
}
108+
}
109+
110+
// TestZeroHashSkipsResolverOnInsert tests that InsertValuesAtStem also skips zero-hash resolver calls
111+
func TestZeroHashSkipsResolverOnInsert(t *testing.T) {
112+
// Create node after deserialization with zero-hash children
113+
node := &InternalNode{
114+
depth: 0,
115+
left: HashedNode(common.Hash{}), // Zero-hash
116+
right: HashedNode(common.Hash{}), // Zero-hash
117+
}
118+
119+
resolverCalls := 0
120+
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
121+
resolverCalls++
122+
123+
if hash == (common.Hash{}) {
124+
t.Error("BUG: Resolver called for zero hash in InsertValuesAtStem")
125+
return nil, errors.New("zero hash should not be resolved")
126+
}
127+
128+
return nil, errors.New("not found")
129+
}
130+
131+
// Insert values into left subtree (zero-hash child)
132+
leftStem := make([]byte, 31)
133+
var values [256][]byte
134+
values[10] = common.HexToHash("0x5678").Bytes()
135+
136+
_, err := node.InsertValuesAtStem(leftStem, values[:], resolver, 0)
137+
if err != nil {
138+
t.Fatalf("InsertValuesAtStem failed: %v", err)
139+
}
140+
141+
// Verify resolver was not called
142+
if resolverCalls > 0 {
143+
t.Errorf("Resolver should not have been called for zero-hash child, but was called %d times", resolverCalls)
144+
}
145+
}

0 commit comments

Comments
 (0)