diff --git a/core/state/journal.go b/core/state/journal.go index 137ec76395e..113a66c2a96 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -30,6 +30,9 @@ type journalEntry interface { // dirtied returns the Ethereum address modified by this journal entry. dirtied() *common.Address + + // copy returns a deep-copied journal entry. + copy() journalEntry } // journal contains the list of state modifications applied since the last state @@ -84,6 +87,22 @@ func (j *journal) length() int { return len(j.entries) } +// copy returns a deep-copied journal. +func (j *journal) copy() *journal { + entries := make([]journalEntry, 0, j.length()) + for i := 0; i < j.length(); i++ { + entries = append(entries, j.entries[i].copy()) + } + dirties := make(map[common.Address]int) + for addr, count := range j.dirties { + dirties[addr] = count + } + return &journal{ + entries: entries, + dirties: dirties, + } +} + type ( // Changes to the account trie. createObjectChange struct { @@ -137,6 +156,7 @@ type ( touchChange struct { account *common.Address } + // Changes to the access list accessListAddAccountChange struct { address *common.Address @@ -146,6 +166,7 @@ type ( slot *common.Hash } + // Changes to transient storage transientStorageChange struct { account *common.Address key, prevalue common.Hash @@ -154,13 +175,18 @@ type ( func (ch createObjectChange) revert(s *StateDB) { delete(s.stateObjects, *ch.account) - delete(s.stateObjectsDirty, *ch.account) } func (ch createObjectChange) dirtied() *common.Address { return ch.account } +func (ch createObjectChange) copy() journalEntry { + return createObjectChange{ + account: ch.account, + } +} + func (ch resetObjectChange) revert(s *StateDB) { s.setStateObject(ch.prev) if !ch.prevdestruct { @@ -184,6 +210,27 @@ func (ch resetObjectChange) dirtied() *common.Address { return ch.account } +func (ch resetObjectChange) copy() journalEntry { + prevStorage := make(map[common.Hash][]byte) + for key, slot := range ch.prevStorage { + prevStorage[key] = common.CopyBytes(slot) + } + prevStorageOrigin := make(map[common.Hash][]byte) + for key, slot := range ch.prevStorageOrigin { + prevStorageOrigin[key] = common.CopyBytes(slot) + } + return resetObjectChange{ + account: ch.account, + prev: ch.prev.deepCopy(), + prevdestruct: ch.prevdestruct, + prevAccount: common.CopyBytes(ch.prevAccount), + prevStorage: prevStorage, + prevAccountOriginExist: ch.prevAccountOriginExist, + prevAccountOrigin: common.CopyBytes(ch.prevAccountOrigin), + prevStorageOrigin: prevStorageOrigin, + } +} + func (ch selfDestructChange) revert(s *StateDB) { obj := s.getStateObject(*ch.account) if obj != nil { @@ -196,6 +243,14 @@ func (ch selfDestructChange) dirtied() *common.Address { return ch.account } +func (ch selfDestructChange) copy() journalEntry { + return selfDestructChange{ + account: ch.account, + prev: ch.prev, + prevbalance: new(big.Int).Set(ch.prevbalance), + } +} + var ripemd = common.HexToAddress("0000000000000000000000000000000000000003") func (ch touchChange) revert(s *StateDB) { @@ -205,6 +260,12 @@ func (ch touchChange) dirtied() *common.Address { return ch.account } +func (ch touchChange) copy() journalEntry { + return touchChange{ + account: ch.account, + } +} + func (ch balanceChange) revert(s *StateDB) { s.getStateObject(*ch.account).setBalance(ch.prev) } @@ -213,6 +274,13 @@ func (ch balanceChange) dirtied() *common.Address { return ch.account } +func (ch balanceChange) copy() journalEntry { + return balanceChange{ + account: ch.account, + prev: new(big.Int).Set(ch.prev), + } +} + func (ch nonceChange) revert(s *StateDB) { s.getStateObject(*ch.account).setNonce(ch.prev) } @@ -221,6 +289,13 @@ func (ch nonceChange) dirtied() *common.Address { return ch.account } +func (ch nonceChange) copy() journalEntry { + return nonceChange{ + account: ch.account, + prev: ch.prev, + } +} + func (ch codeChange) revert(s *StateDB) { s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) } @@ -229,6 +304,14 @@ func (ch codeChange) dirtied() *common.Address { return ch.account } +func (ch codeChange) copy() journalEntry { + return codeChange{ + account: ch.account, + prevhash: common.CopyBytes(ch.prevhash), + prevcode: common.CopyBytes(ch.prevcode), + } +} + func (ch storageChange) revert(s *StateDB) { s.getStateObject(*ch.account).setState(ch.key, ch.prevalue) } @@ -237,6 +320,14 @@ func (ch storageChange) dirtied() *common.Address { return ch.account } +func (ch storageChange) copy() journalEntry { + return storageChange{ + account: ch.account, + key: ch.key, + prevalue: ch.prevalue, + } +} + func (ch transientStorageChange) revert(s *StateDB) { s.setTransientState(*ch.account, ch.key, ch.prevalue) } @@ -245,6 +336,14 @@ func (ch transientStorageChange) dirtied() *common.Address { return nil } +func (ch transientStorageChange) copy() journalEntry { + return transientStorageChange{ + account: ch.account, + key: ch.key, + prevalue: ch.prevalue, + } +} + func (ch refundChange) revert(s *StateDB) { s.refund = ch.prev } @@ -253,6 +352,12 @@ func (ch refundChange) dirtied() *common.Address { return nil } +func (ch refundChange) copy() journalEntry { + return refundChange{ + prev: ch.prev, + } +} + func (ch addLogChange) revert(s *StateDB) { logs := s.logs[ch.txhash] if len(logs) == 1 { @@ -267,6 +372,12 @@ func (ch addLogChange) dirtied() *common.Address { return nil } +func (ch addLogChange) copy() journalEntry { + return addLogChange{ + txhash: ch.txhash, + } +} + func (ch addPreimageChange) revert(s *StateDB) { delete(s.preimages, ch.hash) } @@ -275,6 +386,12 @@ func (ch addPreimageChange) dirtied() *common.Address { return nil } +func (ch addPreimageChange) copy() journalEntry { + return addPreimageChange{ + hash: ch.hash, + } +} + func (ch accessListAddAccountChange) revert(s *StateDB) { /* One important invariant here, is that whenever a (addr, slot) is added, if the @@ -292,6 +409,12 @@ func (ch accessListAddAccountChange) dirtied() *common.Address { return nil } +func (ch accessListAddAccountChange) copy() journalEntry { + return accessListAddAccountChange{ + address: ch.address, + } +} + func (ch accessListAddSlotChange) revert(s *StateDB) { s.accessList.DeleteSlot(*ch.address, *ch.slot) } @@ -299,3 +422,10 @@ func (ch accessListAddSlotChange) revert(s *StateDB) { func (ch accessListAddSlotChange) dirtied() *common.Address { return nil } + +func (ch accessListAddSlotChange) copy() journalEntry { + return accessListAddSlotChange{ + address: ch.address, + slot: ch.slot, + } +} diff --git a/core/state/state_object.go b/core/state/state_object.go index 9383b98e449..0d67204299a 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -438,16 +438,16 @@ func (s *stateObject) setBalance(amount *big.Int) { s.data.Balance = amount } -func (s *stateObject) deepCopy(db *StateDB) *stateObject { +func (s *stateObject) deepCopy() *stateObject { obj := &stateObject{ - db: db, + db: s.db, address: s.address, addrHash: s.addrHash, origin: s.origin, data: s.data, } if s.trie != nil { - obj.trie = db.db.CopyTrie(s.trie) + obj.trie = s.db.db.CopyTrie(s.trie) } obj.code = s.code obj.dirtyStorage = s.dirtyStorage.Copy() diff --git a/core/state/state_test.go b/core/state/state_test.go index 2f45ba44b4e..daeb35b1f4d 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -296,3 +296,21 @@ func compareStateObjects(so0, so1 *stateObject, t *testing.T) { } } } + +func TestCreateObjectRevert(t *testing.T) { + state, _ := New(types.EmptyRootHash, NewDatabase(rawdb.NewMemoryDatabase()), nil) + addr := common.BytesToAddress([]byte("so0")) + snap := state.Snapshot() + + state.CreateAccount(addr) + so0 := state.getStateObject(addr) + so0.SetBalance(big.NewInt(42)) + so0.SetNonce(43) + so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) + state.setStateObject(so0) + + state.RevertToSnapshot(snap) + if state.Exist(addr) { + t.Error("Unexpected account after revert") + } +} diff --git a/core/state/statedb.go b/core/state/statedb.go index 905944cbb5b..396c4895507 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -461,7 +461,6 @@ func (s *StateDB) Selfdestruct6780(addr common.Address) { if stateObject == nil { return } - if stateObject.created { s.SelfDestruct(addr) } @@ -702,7 +701,7 @@ func (s *StateDB) Copy() *StateDB { logs: make(map[common.Hash][]*types.Log, len(s.logs)), logSize: s.logSize, preimages: make(map[common.Hash][]byte, len(s.preimages)), - journal: newJournal(), + journal: s.journal.copy(), hasher: crypto.NewKeccakState(), // In order for the block producer to be able to use and make additions @@ -712,36 +711,14 @@ func (s *StateDB) Copy() *StateDB { snaps: s.snaps, snap: s.snap, } - // Copy the dirty states, logs, and preimages - for addr := range s.journal.dirties { - // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527), - // and in the Finalise-method, there is a case where an object is in the journal but not - // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for - // nil - if object, exist := s.stateObjects[addr]; exist { - // Even though the original object is dirty, we are not copying the journal, - // so we need to make sure that any side-effect the journal would have caused - // during a commit (or similar op) is already applied to the copy. - state.stateObjects[addr] = object.deepCopy(state) - - state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits - state.stateObjectsPending[addr] = struct{}{} // Mark the copy pending to force external (account) commits - } + // Deep copy cached state objects along with the pending and dirty markers. + for addr, obj := range s.stateObjects { + state.stateObjects[addr] = obj.deepCopy() } - // Above, we don't copy the actual journal. This means that if the copy - // is copied, the loop above will be a no-op, since the copy's journal - // is empty. Thus, here we iterate over stateObjects, to enable copies - // of copies. for addr := range s.stateObjectsPending { - if _, exist := state.stateObjects[addr]; !exist { - state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) - } state.stateObjectsPending[addr] = struct{}{} } for addr := range s.stateObjectsDirty { - if _, exist := state.stateObjects[addr]; !exist { - state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) - } state.stateObjectsDirty[addr] = struct{}{} } // Deep copy the destruction markers. diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index df1cd5547d3..c943ececdc8 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -224,6 +224,49 @@ func TestCopy(t *testing.T) { } } +// TestCopyWithDirtyJournal tests if Copy can correct create a equal copied +// stateDB with dirty journal present. +func TestCopyWithDirtyJournal(t *testing.T) { + db := NewDatabase(rawdb.NewMemoryDatabase()) + orig, _ := New(types.EmptyRootHash, db, nil) + + // Fill up the initial states + for i := byte(0); i < 255; i++ { + obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + obj.AddBalance(big.NewInt(int64(i))) + obj.data.Root = common.HexToHash("0xdeadbeef") + orig.updateStateObject(obj) + } + root, _ := orig.Commit(0, true) + orig, _ = New(root, db, nil) + + // modify all in memory without finalizing + for i := byte(0); i < 255; i++ { + obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + obj.SubBalance(big.NewInt(int64(i))) + orig.updateStateObject(obj) + } + cpy := orig.Copy() + + orig.Finalise(true) + for i := byte(0); i < 255; i++ { + root := orig.GetStorageRoot(common.BytesToAddress([]byte{i})) + if root != (common.Hash{}) { + t.Errorf("Unexpected storage root %x", root) + } + } + cpy.Finalise(true) + for i := byte(0); i < 255; i++ { + root := cpy.GetStorageRoot(common.BytesToAddress([]byte{i})) + if root != (common.Hash{}) { + t.Errorf("Unexpected storage root %x", root) + } + } + if cpy.IntermediateRoot(true) != orig.IntermediateRoot(true) { + t.Error("State is not equal after copy") + } +} + func TestSnapshotRandom(t *testing.T) { config := &quick.Config{MaxCount: 1000} err := quick.Check((*snapshotTest).run, config) @@ -708,18 +751,19 @@ func TestCopyCopyCommitCopy(t *testing.T) { } } -// TestCommitCopy tests the copy from a committed state is not functional. +// TestCommitCopy tests the copy from a committed state is not fully functional. func TestCommitCopy(t *testing.T) { - state, _ := New(types.EmptyRootHash, NewDatabase(rawdb.NewMemoryDatabase()), nil) + db := NewDatabase(rawdb.NewMemoryDatabase()) + state, _ := New(types.EmptyRootHash, db, nil) // Create an account and check if the retrieved balance is correct addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe") - skey := common.HexToHash("aaa") - sval := common.HexToHash("bbb") + skey1, skey2 := common.HexToHash("a1"), common.HexToHash("a2") + sval1, sval2 := common.HexToHash("b1"), common.HexToHash("b2") state.SetBalance(addr, big.NewInt(42)) // Change the account trie state.SetCode(addr, []byte("hello")) // Change an external metadata - state.SetState(addr, skey, sval) // Change the storage trie + state.SetState(addr, skey1, sval1) // Change the storage trie if balance := state.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { t.Fatalf("initial balance mismatch: have %v, want %v", balance, 42) @@ -727,25 +771,37 @@ func TestCommitCopy(t *testing.T) { if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello")) } - if val := state.GetState(addr, skey); val != sval { - t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval) + if val := state.GetState(addr, skey1); val != sval1 { + t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval1) } - if val := state.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val := state.GetCommittedState(addr, skey1); val != (common.Hash{}) { t.Fatalf("initial committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } - // Copy the committed state database, the copied one is not functional. - state.Commit(0, true) + root, _ := state.Commit(0, true) + state, _ = New(root, db, nil) + state.SetState(addr, skey2, sval2) + state.Commit(1, true) + + // Copy the committed state database, the copied one is not fully functional. copied := state.Copy() - if balance := copied.GetBalance(addr); balance.Cmp(big.NewInt(0)) != 0 { + if balance := copied.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 { t.Fatalf("unexpected balance: have %v", balance) } - if code := copied.GetCode(addr); code != nil { + if code := copied.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("unexpected code: have %x", code) } - if val := copied.GetState(addr, skey); val != (common.Hash{}) { + // Miss slots because of non-functional trie after commit + if val := copied.GetState(addr, skey1); val != (common.Hash{}) { + t.Fatalf("unexpected storage slot: have %x", sval1) + } + if val := copied.GetCommittedState(addr, skey1); val != (common.Hash{}) { t.Fatalf("unexpected storage slot: have %x", val) } - if val := copied.GetCommittedState(addr, skey); val != (common.Hash{}) { + // Slots cached in the stateDB, available after commit + if val := copied.GetState(addr, skey2); val != sval2 { + t.Fatalf("unexpected storage slot: have %x", sval1) + } + if val := copied.GetCommittedState(addr, skey2); val != sval2 { t.Fatalf("unexpected storage slot: have %x", val) } if !errors.Is(copied.Error(), trie.ErrCommitted) {