Skip to content

Commit f915f68

Browse files
authored
core/state/snapshot: fix data race in layer flattening (#23628)
* core/state/snapshot: fix data race in layer flattening * core/state/snapshot: fix typo
1 parent 08e782c commit f915f68

File tree

2 files changed

+75
-4
lines changed

2 files changed

+75
-4
lines changed

core/state/snapshot/snapshot.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ type Tree struct {
163163
cache int // Megabytes permitted to use for read caches
164164
layers map[common.Hash]snapshot // Collection of all known layers
165165
lock sync.RWMutex
166+
167+
// Test hooks
168+
onFlatten func() // Hook invoked when the bottom most diff layers are flattened
166169
}
167170

168171
// New attempts to load an already existing snapshot from a persistent key-value
@@ -463,14 +466,21 @@ func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer {
463466
return nil
464467

465468
case *diffLayer:
469+
// Hold the write lock until the flattened parent is linked correctly.
470+
// Otherwise, the stale layer may be accessed by external reads in the
471+
// meantime.
472+
diff.lock.Lock()
473+
defer diff.lock.Unlock()
474+
466475
// Flatten the parent into the grandparent. The flattening internally obtains a
467476
// write lock on grandparent.
468477
flattened := parent.flatten().(*diffLayer)
469478
t.layers[flattened.root] = flattened
470479

471-
diff.lock.Lock()
472-
defer diff.lock.Unlock()
473-
480+
// Invoke the hook if it's registered. Ugly hack.
481+
if t.onFlatten != nil {
482+
t.onFlatten()
483+
}
474484
diff.parent = flattened
475485
if flattened.memory < aggregatorMemoryLimit {
476486
// Accumulator layer is smaller than the limit, so we can abort, unless

core/state/snapshot/snapshot_test.go

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"math/big"
2323
"math/rand"
2424
"testing"
25+
"time"
2526

2627
"github.com/VictoriaMetrics/fastcache"
2728
"github.com/ethereum/go-ethereum/common"
@@ -324,7 +325,7 @@ func TestPostCapBasicDataAccess(t *testing.T) {
324325
}
325326
}
326327

327-
// TestSnaphots tests the functionality for retrieveing the snapshot
328+
// TestSnaphots tests the functionality for retrieving the snapshot
328329
// with given head root and the desired depth.
329330
func TestSnaphots(t *testing.T) {
330331
// setAccount is a helper to construct a random account entry and assign it to
@@ -423,3 +424,63 @@ func TestSnaphots(t *testing.T) {
423424
}
424425
}
425426
}
427+
428+
// TestReadStateDuringFlattening tests the scenario where, while the
429+
// bottom diff layers are being merged (which marks them as stale), a read
430+
// happens via a pre-created top snapshot layer that tries to access
431+
// the state in these stale layers. Ensure this read retrieves the
432+
// right state (blocking until the flattening is finished) instead of
433+
// an unexpected error (snapshot layer is stale).
434+
func TestReadStateDuringFlattening(t *testing.T) {
435+
// setAccount is a helper to construct a random account entry and assign it to
436+
// an account slot in a snapshot
437+
setAccount := func(accKey string) map[common.Hash][]byte {
438+
return map[common.Hash][]byte{
439+
common.HexToHash(accKey): randomAccount(),
440+
}
441+
}
442+
// Create a starting base layer and a snapshot tree out of it
443+
base := &diskLayer{
444+
diskdb: rawdb.NewMemoryDatabase(),
445+
root: common.HexToHash("0x01"),
446+
cache: fastcache.New(1024 * 500),
447+
}
448+
snaps := &Tree{
449+
layers: map[common.Hash]snapshot{
450+
base.root: base,
451+
},
452+
}
453+
// 4 layers in total, 3 diff layers and 1 disk layer
454+
snaps.Update(common.HexToHash("0xa1"), common.HexToHash("0x01"), nil, setAccount("0xa1"), nil)
455+
snaps.Update(common.HexToHash("0xa2"), common.HexToHash("0xa1"), nil, setAccount("0xa2"), nil)
456+
snaps.Update(common.HexToHash("0xa3"), common.HexToHash("0xa2"), nil, setAccount("0xa3"), nil)
457+
458+
// Obtain the topmost snapshot handler for accessing the state
459+
snap := snaps.Snapshot(common.HexToHash("0xa3"))
460+
461+
// Register the testing hook to access the state after flattening
462+
var result = make(chan *Account)
463+
snaps.onFlatten = func() {
464+
// Spin up a goroutine to read the account from the pre-created
465+
// snapshot handler. The read is expected to block until the flattening is finished.
466+
go func() {
467+
account, _ := snap.Account(common.HexToHash("0xa1"))
468+
result <- account
469+
}()
470+
select {
471+
case res := <-result:
472+
t.Fatalf("Unexpected return %v", res)
473+
case <-time.NewTimer(time.Millisecond * 300).C:
474+
}
475+
}
476+
// Cap the snap tree, which will mark the bottom-most layer as stale.
477+
snaps.Cap(common.HexToHash("0xa3"), 1)
478+
select {
479+
case account := <-result:
480+
if account == nil {
481+
t.Fatal("Failed to retrieve account")
482+
}
483+
case <-time.NewTimer(time.Millisecond * 300).C:
484+
t.Fatal("Unexpected blocker")
485+
}
486+
}

0 commit comments

Comments
 (0)