Skip to content

Commit 4accd18

Browse files
committed
part: Store header instead of uintptr in node mutated
There's no need to store a uintptr. In StateDB v0.5.x this actually caused issues due to the compression change, in which watch channels from the old tree were retained. The extended validation shows no such issue for v0.4, since old watches are never retained when a node is cloned, but let's be on the safe side and change this here as well. Signed-off-by: Jussi Maki <jussi.maki@isovalent.com>
1 parent 4181bc4 commit 4accd18

File tree

3 files changed

+90
-22
lines changed

3 files changed

+90
-22
lines changed

part/cache.go

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,56 @@ import (
99

1010
const nodeMutatedSize = 256 // must be power-of-two
1111

12-
type nodeMutated struct {
13-
ptrs [nodeMutatedSize]uintptr
12+
// nodeMutated is a probabilistic check for seeing if a node has
13+
// been cloned within a transaction and thus can be modified in-place
14+
// since it has not been seen outside. This significantly speeds up
15+
// writes within a single write transaction as inner nodes no longer
16+
// need to be cloned on every change, effectively making the immutable
17+
// radix tree perform as if it's a mutable one.
18+
//
19+
// Earlier versions of StateDB just used a map[*header[T]]struct{}, but
20+
// that was fairly costly and experiments showed that avoiding the clone
21+
// most of the time is enough to perform well.
22+
//
23+
// The value for [nodeMutatedSize] is a balance between making Txn()
24+
// not too costly (due to e.g. clear()) and giving a high likelihood
25+
// that we mutate nodes in-place.
26+
type nodeMutated[T any] struct {
27+
ptrs [nodeMutatedSize]*header[T]
1428
used bool
1529
}
1630

17-
func nodeMutatedSet[T any](nm *nodeMutated, ptr *header[T]) {
31+
func (nm *nodeMutated[T]) set(n *header[T]) {
1832
if nm == nil {
1933
return
2034
}
21-
ptrInt := uintptr(unsafe.Pointer(ptr))
22-
nm.ptrs[slot(ptrInt)] = ptrInt
35+
ptrInt := uintptr(unsafe.Pointer(n))
36+
nm.ptrs[slot(ptrInt)] = n
2337
nm.used = true
2438
}
2539

26-
func nodeMutatedExists[T any](nm *nodeMutated, ptr *header[T]) bool {
40+
func (nm *nodeMutated[T]) exists(n *header[T]) bool {
2741
if nm == nil {
2842
return false
2943
}
30-
ptrInt := uintptr(unsafe.Pointer(ptr))
31-
return nm.ptrs[slot(ptrInt)] == ptrInt
44+
ptrInt := uintptr(unsafe.Pointer(n))
45+
return nm.ptrs[slot(ptrInt)] == n
3246
}
3347

48+
// slot returns the index in the [ptrs] array for a given pointer.
49+
// The Go spec allows objects to be moved so it may be that the same
50+
// instance of an object is assigned to a different memory location in
51+
// which case we'd no longer report that node as being in the cache.
52+
// This is fine though as we do compare the actual *header[T] pointers
53+
// and this is probabilistic anyway as this is a fixed size cache.
3454
func slot(p uintptr) int {
3555
p >>= 4 // ignore low order bits
3656
// use some relevant bits from the pointer
3757
slot := uint8(p) ^ uint8(p>>8) ^ uint8(p>>16)
3858
return int(slot & (nodeMutatedSize - 1))
3959
}
4060

41-
func (nm *nodeMutated) clear() {
61+
func (nm *nodeMutated[T]) clear() {
4262
if nm == nil {
4363
return
4464
}

part/tree.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func newTxn[T any](o options) *Txn[T] {
4444
watches: make(map[chan struct{}]struct{}),
4545
}
4646
if !o.noCache() {
47-
txn.mutated = &nodeMutated{}
47+
txn.mutated = &nodeMutated[T]{}
4848
txn.deleteParentsCache = make([]deleteParent[T], 0, 32)
4949
}
5050
txn.opts = o

part/txn.go

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type Txn[T any] struct {
2020
// that we can keep mutating without cloning them again.
2121
// It is cleared if the transaction is cloned or iterated
2222
// upon.
23-
mutated *nodeMutated
23+
mutated *nodeMutated[T]
2424

2525
// watches contains the channels of cloned nodes that should be closed
2626
// when transaction is committed.
@@ -188,14 +188,14 @@ func (txn *Txn[T]) PrintTree() {
188188
}
189189

190190
func (txn *Txn[T]) cloneNode(n *header[T]) *header[T] {
191-
if nodeMutatedExists(txn.mutated, n) {
191+
if txn.mutated.exists(n) {
192192
return n
193193
}
194194
if n.watch != nil {
195195
txn.watches[n.watch] = struct{}{}
196196
}
197197
n = n.clone(!txn.opts.rootOnlyWatch() || n == txn.root)
198-
nodeMutatedSet(txn.mutated, n)
198+
txn.mutated.set(n)
199199
return n
200200
}
201201

@@ -232,7 +232,7 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, mod func(T) T) (oldValue
232232
txn.watches[this.watch] = struct{}{}
233233
}
234234
this = this.promote(!txn.opts.rootOnlyWatch() || this == root)
235-
nodeMutatedSet(txn.mutated, this)
235+
txn.mutated.set(this)
236236
} else {
237237
// Node is big enough, clone it so we can mutate it
238238
this = txn.cloneNode(this)
@@ -490,41 +490,89 @@ func validateRemovedWatches[T any](oldRoot *header[T], newRoot *header[T]) {
490490
if !runValidation {
491491
return
492492
}
493-
var collectWatches func(depth int, watches map[<-chan struct{}]int, node *header[T])
494-
collectWatches = func(depth int, watches map[<-chan struct{}]int, node *header[T]) {
493+
494+
// nodeStruct memoizes the structure for a node.
495+
nodeStruct := map[*header[T]]string{}
496+
497+
// summarizeNodeStructure "lazily" summarizes the internal structure of a node,
498+
// e.g. it's watch channel, size, leaf and children. The lazy construction speeds
499+
// things up a lot as we only look at the structure in certain specific cases.
500+
var summarizeNodeStructure func(node *header[T]) func() string
501+
summarizeNodeStructure = func(node *header[T]) func() string {
502+
if node == nil {
503+
return func() string { return "" }
504+
}
505+
if s, found := nodeStruct[node]; found {
506+
return func() string { return s }
507+
}
508+
return func() string {
509+
var childS string
510+
for _, child := range node.children() {
511+
if child != nil {
512+
childS += summarizeNodeStructure(child)()
513+
}
514+
}
515+
var leafS string
516+
if leaf := node.getLeaf(); leaf != nil && !node.isLeaf() {
517+
leafS = summarizeNodeStructure(leaf.self())()
518+
}
519+
s := fmt.Sprintf("K:%d S:%d W:%p L:[%s] C:[%s]", node.kind(), node.size(), node.watch, leafS, childS)
520+
nodeStruct[node] = s
521+
return s
522+
}
523+
}
524+
525+
var collectWatches func(depth int, watches map[<-chan struct{}]func() string, node *header[T])
526+
collectWatches = func(depth int, watches map[<-chan struct{}]func() string, node *header[T]) {
495527
if node == nil {
496528
return
497529
}
498530
if node.watch == nil {
499531
panic("nil watch channel")
500532
}
501-
watches[node.watch] = depth
533+
watches[node.watch] = summarizeNodeStructure(node)
502534
if leaf := node.getLeaf(); leaf != nil && !node.isLeaf() {
503-
watches[leaf.watch] = depth
535+
watches[leaf.watch] = summarizeNodeStructure(leaf.self())
504536
}
505537
for _, child := range node.children() {
506538
if child != nil {
507539
collectWatches(depth+1, watches, child)
508540
}
509541
}
510542
}
511-
oldWatches := map[<-chan struct{}]int{}
543+
544+
oldWatches := map[<-chan struct{}]func() string{}
512545
collectWatches(0, oldWatches, oldRoot)
513-
newWatches := map[<-chan struct{}]int{}
546+
newWatches := map[<-chan struct{}]func() string{}
514547
collectWatches(0, newWatches, newRoot)
515548

549+
// Check that any nodes that kept the old watch channel have exactly
550+
// the same leaf and children structure.
551+
for watch, oldDescFn := range oldWatches {
552+
newDescFn, found := newWatches[watch]
553+
if found {
554+
oldDesc := oldDescFn()
555+
newDesc := newDescFn()
556+
if oldDesc != newDesc {
557+
panic(fmt.Sprintf("node with retained watch channel has different structure:\nexpected: %s\n got: %s", oldDesc, newDesc))
558+
}
559+
}
560+
561+
}
562+
516563
// Any nodes that are not part of the new tree must have their watch channels closed.
517564
for watch := range newWatches {
518565
delete(oldWatches, watch)
519566
}
520-
for watch, depth := range oldWatches {
567+
568+
for watch, desc := range oldWatches {
521569
select {
522570
case <-watch:
523571
default:
524572
oldRoot.printTree(0)
525573
fmt.Println("---")
526574
newRoot.printTree(0)
527-
panic(fmt.Sprintf("dropped watch channel %p at depth %d not closed", watch, depth))
575+
panic(fmt.Sprintf("dropped watch channel %p not closed %s", watch, desc()))
528576
}
529577
}
530578
}

0 commit comments

Comments
 (0)