@@ -20,7 +20,7 @@ type Txn[T any] struct {
2020 // that we can keep mutating without cloning them again.
2121 // It is cleared if the transaction is cloned or iterated
2222 // upon.
23- mutated * nodeMutated
23+ mutated * nodeMutated [ T ]
2424
2525 // watches contains the channels of cloned nodes that should be closed
2626 // when transaction is committed.
@@ -188,14 +188,14 @@ func (txn *Txn[T]) PrintTree() {
188188}
189189
190190func (txn * Txn [T ]) cloneNode (n * header [T ]) * header [T ] {
191- if nodeMutatedExists ( txn .mutated , n ) {
191+ if txn .mutated . exists ( n ) {
192192 return n
193193 }
194194 if n .watch != nil {
195195 txn .watches [n .watch ] = struct {}{}
196196 }
197197 n = n .clone (! txn .opts .rootOnlyWatch () || n == txn .root )
198- nodeMutatedSet ( txn .mutated , n )
198+ txn .mutated . set ( n )
199199 return n
200200}
201201
@@ -232,7 +232,7 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, mod func(T) T) (oldValue
232232 txn .watches [this .watch ] = struct {}{}
233233 }
234234 this = this .promote (! txn .opts .rootOnlyWatch () || this == root )
235- nodeMutatedSet ( txn .mutated , this )
235+ txn .mutated . set ( this )
236236 } else {
237237 // Node is big enough, clone it so we can mutate it
238238 this = txn .cloneNode (this )
@@ -490,41 +490,89 @@ func validateRemovedWatches[T any](oldRoot *header[T], newRoot *header[T]) {
490490 if ! runValidation {
491491 return
492492 }
493- var collectWatches func (depth int , watches map [<- chan struct {}]int , node * header [T ])
494- collectWatches = func (depth int , watches map [<- chan struct {}]int , node * header [T ]) {
493+
494+ // nodeStruct memoizes the structure for a node.
495+ nodeStruct := map [* header [T ]]string {}
496+
497+ // summarizeNodeStructure "lazily" summarizes the internal structure of a node,
498+ // e.g. it's watch channel, size, leaf and children. The lazy construction speeds
499+ // things up a lot as we only look at the structure in certain specific cases.
500+ var summarizeNodeStructure func (node * header [T ]) func () string
501+ summarizeNodeStructure = func (node * header [T ]) func () string {
502+ if node == nil {
503+ return func () string { return "" }
504+ }
505+ if s , found := nodeStruct [node ]; found {
506+ return func () string { return s }
507+ }
508+ return func () string {
509+ var childS string
510+ for _ , child := range node .children () {
511+ if child != nil {
512+ childS += summarizeNodeStructure (child )()
513+ }
514+ }
515+ var leafS string
516+ if leaf := node .getLeaf (); leaf != nil && ! node .isLeaf () {
517+ leafS = summarizeNodeStructure (leaf .self ())()
518+ }
519+ s := fmt .Sprintf ("K:%d S:%d W:%p L:[%s] C:[%s]" , node .kind (), node .size (), node .watch , leafS , childS )
520+ nodeStruct [node ] = s
521+ return s
522+ }
523+ }
524+
525+ var collectWatches func (depth int , watches map [<- chan struct {}]func () string , node * header [T ])
526+ collectWatches = func (depth int , watches map [<- chan struct {}]func () string , node * header [T ]) {
495527 if node == nil {
496528 return
497529 }
498530 if node .watch == nil {
499531 panic ("nil watch channel" )
500532 }
501- watches [node .watch ] = depth
533+ watches [node .watch ] = summarizeNodeStructure ( node )
502534 if leaf := node .getLeaf (); leaf != nil && ! node .isLeaf () {
503- watches [leaf .watch ] = depth
535+ watches [leaf .watch ] = summarizeNodeStructure ( leaf . self ())
504536 }
505537 for _ , child := range node .children () {
506538 if child != nil {
507539 collectWatches (depth + 1 , watches , child )
508540 }
509541 }
510542 }
511- oldWatches := map [<- chan struct {}]int {}
543+
544+ oldWatches := map [<- chan struct {}]func () string {}
512545 collectWatches (0 , oldWatches , oldRoot )
513- newWatches := map [<- chan struct {}]int {}
546+ newWatches := map [<- chan struct {}]func () string {}
514547 collectWatches (0 , newWatches , newRoot )
515548
549+ // Check that any nodes that kept the old watch channel have exactly
550+ // the same leaf and children structure.
551+ for watch , oldDescFn := range oldWatches {
552+ newDescFn , found := newWatches [watch ]
553+ if found {
554+ oldDesc := oldDescFn ()
555+ newDesc := newDescFn ()
556+ if oldDesc != newDesc {
557+ panic (fmt .Sprintf ("node with retained watch channel has different structure:\n expected: %s\n got: %s" , oldDesc , newDesc ))
558+ }
559+ }
560+
561+ }
562+
516563 // Any nodes that are not part of the new tree must have their watch channels closed.
517564 for watch := range newWatches {
518565 delete (oldWatches , watch )
519566 }
520- for watch , depth := range oldWatches {
567+
568+ for watch , desc := range oldWatches {
521569 select {
522570 case <- watch :
523571 default :
524572 oldRoot .printTree (0 )
525573 fmt .Println ("---" )
526574 newRoot .printTree (0 )
527- panic (fmt .Sprintf ("dropped watch channel %p at depth %d not closed" , watch , depth ))
575+ panic (fmt .Sprintf ("dropped watch channel %p not closed %s " , watch , desc () ))
528576 }
529577 }
530578}
0 commit comments