Skip to content

Commit 6136227

Browse files
holimanjorgemmsilva
authored andcommitted
core/state: better randomized testing (postcheck) on journalling (ethereum#29627)
This PR fixes some flaws with the existing tests. The randomized testing (TestSnapshotRandom) executes a series of steps which modify the state and create journal-events. Later on, we compare the forward-going-states against the backwards-unrolling-journal-states, and check that they are identical. The "identical" check is performed using various accessors. It turned out that we failed to check some things: - the accesslist contents - the transient storage contents - the 'newContract' flag - the dirty storage map This change adds these new checks
1 parent f1a2580 commit 6136227

File tree

4 files changed

+153
-17
lines changed

4 files changed

+153
-17
lines changed

core/state/access_list.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
package state
1818

1919
import (
20+
"fmt"
2021
"maps"
22+
"slices"
23+
"strings"
2124

2225
"github.com/ethereum/go-ethereum/common"
2326
)
@@ -130,3 +133,35 @@ func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) {
130133
func (al *accessList) DeleteAddress(address common.Address) {
131134
delete(al.addresses, address)
132135
}
136+
137+
// Equal returns true if the two access lists are identical
138+
func (al *accessList) Equal(other *accessList) bool {
139+
if !maps.Equal(al.addresses, other.addresses) {
140+
return false
141+
}
142+
return slices.EqualFunc(al.slots, other.slots,
143+
func(m map[common.Hash]struct{}, m2 map[common.Hash]struct{}) bool {
144+
return maps.Equal(m, m2)
145+
})
146+
}
147+
148+
// PrettyPrint prints the contents of the access list in a human-readable form
149+
func (al *accessList) PrettyPrint() string {
150+
out := new(strings.Builder)
151+
var sortedAddrs []common.Address
152+
for addr := range al.addresses {
153+
sortedAddrs = append(sortedAddrs, addr)
154+
}
155+
slices.SortFunc(sortedAddrs, common.Address.Cmp)
156+
for _, addr := range sortedAddrs {
157+
idx := al.addresses[addr]
158+
fmt.Fprintf(out, "%#x : (idx %d)\n", addr, idx)
159+
if idx >= 0 {
160+
slotmap := al.slots[idx]
161+
for h := range slotmap {
162+
fmt.Fprintf(out, " %#x\n", h)
163+
}
164+
}
165+
}
166+
return out.String()
167+
}

core/state/state_object.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -459,22 +459,22 @@ func (s *stateObject) setBalance(amount *uint256.Int) {
459459

460460
func (s *stateObject) deepCopy(db *StateDB) *stateObject {
461461
obj := &stateObject{
462-
db: db,
463-
address: s.address,
464-
addrHash: s.addrHash,
465-
origin: s.origin,
466-
data: s.data,
462+
db: db,
463+
address: s.address,
464+
addrHash: s.addrHash,
465+
origin: s.origin,
466+
data: s.data,
467+
code: s.code,
468+
originStorage: s.originStorage.Copy(),
469+
pendingStorage: s.pendingStorage.Copy(),
470+
dirtyStorage: s.dirtyStorage.Copy(),
471+
dirtyCode: s.dirtyCode,
472+
selfDestructed: s.selfDestructed,
473+
newContract: s.newContract,
467474
}
468475
if s.trie != nil {
469476
obj.trie = db.db.CopyTrie(s.trie)
470477
}
471-
obj.code = s.code
472-
obj.originStorage = s.originStorage.Copy()
473-
obj.pendingStorage = s.pendingStorage.Copy()
474-
obj.dirtyStorage = s.dirtyStorage.Copy()
475-
obj.dirtyCode = s.dirtyCode
476-
obj.selfDestructed = s.selfDestructed
477-
obj.newContract = s.newContract
478478
return obj
479479
}
480480

core/state/statedb_test.go

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ import (
2121
"encoding/binary"
2222
"errors"
2323
"fmt"
24+
"maps"
2425
"math"
2526
"math/rand"
2627
"reflect"
28+
"slices"
2729
"strings"
2830
"sync"
2931
"testing"
@@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H
557559
if err != nil {
558560
return err
559561
}
560-
it := trie.NewIterator(trieIt)
562+
var (
563+
it = trie.NewIterator(trieIt)
564+
visited = make(map[common.Hash]bool)
565+
)
561566

562567
for it.Next() {
563568
key := common.BytesToHash(s.trie.GetKey(it.Key))
569+
visited[key] = true
564570
if value, dirty := so.dirtyStorage[key]; dirty {
565571
if !cb(key, value) {
566572
return nil
@@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
600606
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
601607
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
602608
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
609+
// Check newContract-flag
610+
if obj := state.getStateObject(addr); obj != nil {
611+
checkeq("IsNewContract", obj.newContract, checkstate.getStateObject(addr).newContract)
612+
}
603613
// Check storage.
604614
if obj := state.getStateObject(addr); obj != nil {
605615
forEachStorage(state, addr, func(key, value common.Hash) bool {
@@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
608618
forEachStorage(checkstate, addr, func(key, value common.Hash) bool {
609619
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
610620
})
621+
other := checkstate.getStateObject(addr)
622+
// Check dirty storage which is not in trie
623+
if !maps.Equal(obj.dirtyStorage, other.dirtyStorage) {
624+
print := func(dirty map[common.Hash]common.Hash) string {
625+
var keys []common.Hash
626+
out := new(strings.Builder)
627+
for key := range dirty {
628+
keys = append(keys, key)
629+
}
630+
slices.SortFunc(keys, common.Hash.Cmp)
631+
for i, key := range keys {
632+
fmt.Fprintf(out, " %d. %v %v\n", i, key, dirty[key])
633+
}
634+
return out.String()
635+
}
636+
return fmt.Errorf("dirty storage err, have\n%v\nwant\n%v",
637+
print(obj.dirtyStorage),
638+
print(other.dirtyStorage))
639+
}
640+
}
641+
// Check transient storage.
642+
{
643+
have := state.transientStorage
644+
want := checkstate.transientStorage
645+
eq := maps.EqualFunc(have, want,
646+
func(a Storage, b Storage) bool {
647+
return maps.Equal(a, b)
648+
})
649+
if !eq {
650+
return fmt.Errorf("transient storage differs ,have\n%v\nwant\n%v",
651+
have.PrettyPrint(),
652+
want.PrettyPrint())
653+
}
611654
}
612655
if err != nil {
613656
return err
614657
}
615658
}
616-
659+
if !checkstate.accessList.Equal(state.accessList) { // Check access lists
660+
return fmt.Errorf("AccessLists are wrong, have \n%v\nwant\n%v",
661+
checkstate.accessList.PrettyPrint(),
662+
state.accessList.PrettyPrint())
663+
}
617664
if state.GetRefund() != checkstate.GetRefund() {
618665
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
619666
state.GetRefund(), checkstate.GetRefund())
@@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
622669
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
623670
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
624671
}
672+
if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) {
673+
getKeys := func(dirty map[common.Address]int) string {
674+
var keys []common.Address
675+
out := new(strings.Builder)
676+
for key := range dirty {
677+
keys = append(keys, key)
678+
}
679+
slices.SortFunc(keys, common.Address.Cmp)
680+
for i, key := range keys {
681+
fmt.Fprintf(out, " %d. %v\n", i, key)
682+
}
683+
return out.String()
684+
}
685+
have := getKeys(state.journal.dirties)
686+
want := getKeys(checkstate.journal.dirties)
687+
return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want)
688+
}
625689
return nil
626690
}
627691

core/state/transient_storage.go

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
package state
1818

1919
import (
20+
"fmt"
21+
"slices"
22+
"strings"
23+
2024
"github.com/ethereum/go-ethereum/common"
2125
)
2226

@@ -30,10 +34,19 @@ func newTransientStorage() transientStorage {
3034

3135
// Set sets the transient-storage `value` for `key` at the given `addr`.
3236
func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
33-
if _, ok := t[addr]; !ok {
34-
t[addr] = make(Storage)
37+
if value == (common.Hash{}) { // this is a 'delete'
38+
if _, ok := t[addr]; ok {
39+
delete(t[addr], key)
40+
if len(t[addr]) == 0 {
41+
delete(t, addr)
42+
}
43+
}
44+
} else {
45+
if _, ok := t[addr]; !ok {
46+
t[addr] = make(Storage)
47+
}
48+
t[addr][key] = value
3549
}
36-
t[addr][key] = value
3750
}
3851

3952
// Get gets the transient storage for `key` at the given `addr`.
@@ -53,3 +66,27 @@ func (t transientStorage) Copy() transientStorage {
5366
}
5467
return storage
5568
}
69+
70+
// PrettyPrint prints the contents of the access list in a human-readable form
71+
func (t transientStorage) PrettyPrint() string {
72+
out := new(strings.Builder)
73+
var sortedAddrs []common.Address
74+
for addr := range t {
75+
sortedAddrs = append(sortedAddrs, addr)
76+
slices.SortFunc(sortedAddrs, common.Address.Cmp)
77+
}
78+
79+
for _, addr := range sortedAddrs {
80+
fmt.Fprintf(out, "%#x:", addr)
81+
var sortedKeys []common.Hash
82+
storage := t[addr]
83+
for key := range storage {
84+
sortedKeys = append(sortedKeys, key)
85+
}
86+
slices.SortFunc(sortedKeys, common.Hash.Cmp)
87+
for _, key := range sortedKeys {
88+
fmt.Fprintf(out, " %X : %X\n", key, storage[key])
89+
}
90+
}
91+
return out.String()
92+
}

0 commit comments

Comments
 (0)