@@ -21,9 +21,11 @@ import (
21
21
"encoding/binary"
22
22
"errors"
23
23
"fmt"
24
+ "maps"
24
25
"math"
25
26
"math/rand"
26
27
"reflect"
28
+ "slices"
27
29
"strings"
28
30
"sync"
29
31
"testing"
@@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H
557
559
if err != nil {
558
560
return err
559
561
}
560
- it := trie .NewIterator (trieIt )
562
+ var (
563
+ it = trie .NewIterator (trieIt )
564
+ visited = make (map [common.Hash ]bool )
565
+ )
561
566
562
567
for it .Next () {
563
568
key := common .BytesToHash (s .trie .GetKey (it .Key ))
569
+ visited [key ] = true
564
570
if value , dirty := so .dirtyStorage [key ]; dirty {
565
571
if ! cb (key , value ) {
566
572
return nil
@@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
600
606
checkeq ("GetCode" , state .GetCode (addr ), checkstate .GetCode (addr ))
601
607
checkeq ("GetCodeHash" , state .GetCodeHash (addr ), checkstate .GetCodeHash (addr ))
602
608
checkeq ("GetCodeSize" , state .GetCodeSize (addr ), checkstate .GetCodeSize (addr ))
609
+ // Check newContract-flag
610
+ if obj := state .getStateObject (addr ); obj != nil {
611
+ checkeq ("IsNewContract" , obj .newContract , checkstate .getStateObject (addr ).newContract )
612
+ }
603
613
// Check storage.
604
614
if obj := state .getStateObject (addr ); obj != nil {
605
615
forEachStorage (state , addr , func (key , value common.Hash ) bool {
@@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
608
618
forEachStorage (checkstate , addr , func (key , value common.Hash ) bool {
609
619
return checkeq ("GetState(" + key .Hex ()+ ")" , checkstate .GetState (addr , key ), value )
610
620
})
621
+ other := checkstate .getStateObject (addr )
622
+ // Check dirty storage which is not in trie
623
+ if ! maps .Equal (obj .dirtyStorage , other .dirtyStorage ) {
624
+ print := func (dirty map [common.Hash ]common.Hash ) string {
625
+ var keys []common.Hash
626
+ out := new (strings.Builder )
627
+ for key := range dirty {
628
+ keys = append (keys , key )
629
+ }
630
+ slices .SortFunc (keys , common .Hash .Cmp )
631
+ for i , key := range keys {
632
+ fmt .Fprintf (out , " %d. %v %v\n " , i , key , dirty [key ])
633
+ }
634
+ return out .String ()
635
+ }
636
+ return fmt .Errorf ("dirty storage err, have\n %v\n want\n %v" ,
637
+ print (obj .dirtyStorage ),
638
+ print (other .dirtyStorage ))
639
+ }
640
+ }
641
+ // Check transient storage.
642
+ {
643
+ have := state .transientStorage
644
+ want := checkstate .transientStorage
645
+ eq := maps .EqualFunc (have , want ,
646
+ func (a Storage , b Storage ) bool {
647
+ return maps .Equal (a , b )
648
+ })
649
+ if ! eq {
650
+ return fmt .Errorf ("transient storage differs ,have\n %v\n want\n %v" ,
651
+ have .PrettyPrint (),
652
+ want .PrettyPrint ())
653
+ }
611
654
}
612
655
if err != nil {
613
656
return err
614
657
}
615
658
}
616
-
659
+ if ! checkstate .accessList .Equal (state .accessList ) { // Check access lists
660
+ return fmt .Errorf ("AccessLists are wrong, have \n %v\n want\n %v" ,
661
+ checkstate .accessList .PrettyPrint (),
662
+ state .accessList .PrettyPrint ())
663
+ }
617
664
if state .GetRefund () != checkstate .GetRefund () {
618
665
return fmt .Errorf ("got GetRefund() == %d, want GetRefund() == %d" ,
619
666
state .GetRefund (), checkstate .GetRefund ())
@@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
622
669
return fmt .Errorf ("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v" ,
623
670
state .GetLogs (common.Hash {}, 0 , common.Hash {}), checkstate .GetLogs (common.Hash {}, 0 , common.Hash {}))
624
671
}
672
+ if ! maps .Equal (state .journal .dirties , checkstate .journal .dirties ) {
673
+ getKeys := func (dirty map [common.Address ]int ) string {
674
+ var keys []common.Address
675
+ out := new (strings.Builder )
676
+ for key := range dirty {
677
+ keys = append (keys , key )
678
+ }
679
+ slices .SortFunc (keys , common .Address .Cmp )
680
+ for i , key := range keys {
681
+ fmt .Fprintf (out , " %d. %v\n " , i , key )
682
+ }
683
+ return out .String ()
684
+ }
685
+ have := getKeys (state .journal .dirties )
686
+ want := getKeys (checkstate .journal .dirties )
687
+ return fmt .Errorf ("dirty-journal set mismatch.\n have:\n %v\n want:\n %v\n " , have , want )
688
+ }
625
689
return nil
626
690
}
627
691
0 commit comments