Skip to content

Commit 68c0ec0

Browse files
roysc and rjl493456442 authored
trie: iterate values pre-order and fix seek behavior (#27838)
This pull request fixes the pre-order trie traversal by defining a more accurate iterator order and path comparison rule. Co-authored-by: Gary Rong <[email protected]>
1 parent adbbd8c commit 68c0ec0

File tree

2 files changed

+76
-23
lines changed

2 files changed

+76
-23
lines changed

trie/iterator.go

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ type nodeIteratorState struct {
135135
node node // Trie node being iterated
136136
parent common.Hash // Hash of the first full ancestor node (nil if current is the root)
137137
index int // Child to be processed next
138-
pathlen int // Length of the path to this node
138+
pathlen int // Length of the path to the parent node
139139
}
140140

141141
type nodeIterator struct {
@@ -145,7 +145,7 @@ type nodeIterator struct {
145145
err error // Failure set in case of an internal error in the iterator
146146

147147
resolver NodeResolver // optional node resolver for avoiding disk hits
148-
pool []*nodeIteratorState // local pool for iteratorstates
148+
pool []*nodeIteratorState // local pool for iterator states
149149
}
150150

151151
// errIteratorEnd is stored in nodeIterator.err when iteration is done.
@@ -304,14 +304,15 @@ func (it *nodeIterator) seek(prefix []byte) error {
304304
// The path we're looking for is the hex encoded key without terminator.
305305
key := keybytesToHex(prefix)
306306
key = key[:len(key)-1]
307+
307308
// Move forward until we're just before the closest match to key.
308309
for {
309310
state, parentIndex, path, err := it.peekSeek(key)
310311
if err == errIteratorEnd {
311312
return errIteratorEnd
312313
} else if err != nil {
313314
return seekError{prefix, err}
314-
} else if bytes.Compare(path, key) >= 0 {
315+
} else if reachedPath(path, key) {
315316
return nil
316317
}
317318
it.push(state, parentIndex, path)
@@ -339,7 +340,6 @@ func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, er
339340
// If we're skipping children, pop the current node first
340341
it.pop()
341342
}
342-
343343
// Continue iteration to the next child
344344
for len(it.stack) > 0 {
345345
parent := it.stack[len(it.stack)-1]
@@ -372,7 +372,6 @@ func (it *nodeIterator) peekSeek(seekKey []byte) (*nodeIteratorState, *int, []by
372372
// If we're skipping children, pop the current node first
373373
it.pop()
374374
}
375-
376375
// Continue iteration to the next child
377376
for len(it.stack) > 0 {
378377
parent := it.stack[len(it.stack)-1]
@@ -449,16 +448,18 @@ func (it *nodeIterator) findChild(n *fullNode, index int, ancestor common.Hash)
449448
state *nodeIteratorState
450449
childPath []byte
451450
)
452-
for ; index < len(n.Children); index++ {
451+
for ; index < len(n.Children); index = nextChildIndex(index) {
453452
if n.Children[index] != nil {
454453
child = n.Children[index]
455454
hash, _ := child.cache()
455+
456456
state = it.getFromPool()
457457
state.hash = common.BytesToHash(hash)
458458
state.node = child
459459
state.parent = ancestor
460460
state.index = -1
461461
state.pathlen = len(path)
462+
462463
childPath = append(childPath, path...)
463464
childPath = append(childPath, byte(index))
464465
return child, state, childPath, index
@@ -471,8 +472,8 @@ func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Has
471472
switch node := parent.node.(type) {
472473
case *fullNode:
473474
// Full node, move to the first non-nil child.
474-
if child, state, path, index := it.findChild(node, parent.index+1, ancestor); child != nil {
475-
parent.index = index - 1
475+
if child, state, path, index := it.findChild(node, nextChildIndex(parent.index), ancestor); child != nil {
476+
parent.index = prevChildIndex(index)
476477
return state, path, true
477478
}
478479
case *shortNode:
@@ -498,23 +499,23 @@ func (it *nodeIterator) nextChildAt(parent *nodeIteratorState, ancestor common.H
498499
switch n := parent.node.(type) {
499500
case *fullNode:
500501
// Full node, move to the first non-nil child before the desired key position
501-
child, state, path, index := it.findChild(n, parent.index+1, ancestor)
502+
child, state, path, index := it.findChild(n, nextChildIndex(parent.index), ancestor)
502503
if child == nil {
503504
// No more children in this fullnode
504505
return parent, it.path, false
505506
}
506507
// If the child we found is already past the seek position, just return it.
507-
if bytes.Compare(path, key) >= 0 {
508-
parent.index = index - 1
508+
if reachedPath(path, key) {
509+
parent.index = prevChildIndex(index)
509510
return state, path, true
510511
}
511512
// The child is before the seek position. Try advancing
512513
for {
513-
nextChild, nextState, nextPath, nextIndex := it.findChild(n, index+1, ancestor)
514+
nextChild, nextState, nextPath, nextIndex := it.findChild(n, nextChildIndex(index), ancestor)
514515
// If we run out of children, or skipped past the target, return the
515516
// previous one
516-
if nextChild == nil || bytes.Compare(nextPath, key) >= 0 {
517-
parent.index = index - 1
517+
if nextChild == nil || reachedPath(nextPath, key) {
518+
parent.index = prevChildIndex(index)
518519
return state, path, true
519520
}
520521
// We found a better child closer to the target
@@ -541,7 +542,7 @@ func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []
541542
it.path = path
542543
it.stack = append(it.stack, state)
543544
if parentIndex != nil {
544-
*parentIndex++
545+
*parentIndex = nextChildIndex(*parentIndex)
545546
}
546547
}
547548

@@ -550,8 +551,54 @@ func (it *nodeIterator) pop() {
550551
it.path = it.path[:last.pathlen]
551552
it.stack[len(it.stack)-1] = nil
552553
it.stack = it.stack[:len(it.stack)-1]
553-
// last is now unused
554-
it.putInPool(last)
554+
555+
it.putInPool(last) // last is now unused
556+
}
557+
558+
// reachedPath normalizes a path by truncating a terminator if present, and
559+
// returns true if it is greater than or equal to the target. Using this,
560+
// the path of a value node embedded a full node will compare less than the
561+
// full node's children.
562+
func reachedPath(path, target []byte) bool {
563+
if hasTerm(path) {
564+
path = path[:len(path)-1]
565+
}
566+
return bytes.Compare(path, target) >= 0
567+
}
568+
569+
// A value embedded in a full node occupies the last slot (16) of the array of
570+
// children. In order to produce a pre-order traversal when iterating children,
571+
// we jump to this last slot first, then go back to iterate the child nodes (and
572+
// skip the last slot at the end):
573+
574+
// prevChildIndex returns the index of a child in a full node which precedes
// the given index when performing a pre-order traversal.
//
// The traversal order over the slots is -1 (placeholder), 16 (embedded value),
// 0, 1, ..., 15, and finally 17 (past the end, skipping the value slot that
// was already visited). This function steps one position backwards in that
// sequence.
func prevChildIndex(index int) int {
	switch index {
	case 0: // We jumped back to iterate the children, from the value slot
		return 16
	case 16: // We jumped to the embedded value slot at the end, from the placeholder index
		return -1
	case 17: // We skipped the value slot after iterating all the children
		return 15
	default: // We are iterating the children in sequence
		return index - 1
	}
}
588+
589+
// nextChildIndex returns the index of a child in a full node which follows
// the given index when performing a pre-order traversal.
//
// Starting from the placeholder index -1, repeated calls yield 16 (the
// embedded value slot), then 0 through 15 (the child slots), and finally 17,
// which lies past the end of the children array and therefore skips the value
// slot that was already visited first.
func nextChildIndex(index int) int {
	switch index {
	case -1: // Jump from the placeholder index to the embedded value slot
		return 16
	case 15: // Skip the value slot after iterating the children
		return 17
	case 16: // From the embedded value slot, jump back to iterate the children
		return 0
	default: // Iterate children in sequence
		return index + 1
	}
}
556603

557604
func compareNodes(a, b NodeIterator) int {

trie/iterator_test.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,14 @@ func testNodeIteratorCoverage(t *testing.T, scheme string) {
182182
// kvs is a key/value pair used as trie test data.
type kvs struct{ k, v string }

// testdata1 lists its entries in ascending key order, with a key placed
// before any longer key it prefixes ("bar" before "barb", "foo" before
// "food"). The seek tests slice this list by index to express the expected
// iteration results, so the order of the entries is significant.
var testdata1 = []kvs{
	{"bar", "b"},
	{"barb", "ba"},
	{"bard", "bc"},
	{"bars", "bb"},
	{"fab", "z"},
	{"foo", "a"},
	{"food", "ab"},
	{"foos", "aa"},
}
194194

195195
var testdata2 = []kvs{
@@ -218,7 +218,7 @@ func TestIteratorSeek(t *testing.T) {
218218

219219
// Seek to a non-existent key.
220220
it = NewIterator(trie.MustNodeIterator([]byte("barc")))
221-
if err := checkIteratorOrder(testdata1[1:], it); err != nil {
221+
if err := checkIteratorOrder(testdata1[2:], it); err != nil {
222222
t.Fatal(err)
223223
}
224224

@@ -227,6 +227,12 @@ func TestIteratorSeek(t *testing.T) {
227227
if err := checkIteratorOrder(nil, it); err != nil {
228228
t.Fatal(err)
229229
}
230+
231+
// Seek to a key for which a prefixing key exists.
232+
it = NewIterator(trie.MustNodeIterator([]byte("food")))
233+
if err := checkIteratorOrder(testdata1[6:], it); err != nil {
234+
t.Fatal(err)
235+
}
230236
}
231237

232238
func checkIteratorOrder(want []kvs, it *Iterator) error {
@@ -311,16 +317,16 @@ func TestUnionIterator(t *testing.T) {
311317

312318
all := []struct{ k, v string }{
313319
{"aardvark", "c"},
320+
{"bar", "b"},
314321
{"barb", "ba"},
315322
{"barb", "bd"},
316323
{"bard", "bc"},
317324
{"bars", "bb"},
318325
{"bars", "be"},
319-
{"bar", "b"},
320326
{"fab", "z"},
327+
{"foo", "a"},
321328
{"food", "ab"},
322329
{"foos", "aa"},
323-
{"foo", "a"},
324330
{"jars", "d"},
325331
}
326332

@@ -512,7 +518,7 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool, scheme strin
512518
rawdb.WriteTrieNode(diskdb, common.Hash{}, barNodePath, barNodeHash, barNodeBlob, triedb.Scheme())
513519
}
514520
// Check that iteration produces the right set of values.
515-
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
521+
if err := checkIteratorOrder(testdata1[3:], NewIterator(it)); err != nil {
516522
t.Fatal(err)
517523
}
518524
}

0 commit comments

Comments
 (0)