Skip to content

Commit 46a527d

Browse files
committed
[release/1.4.16] core/state: implement reverts by journaling all changes
This commit replaces the deep-copy based state revert mechanism with a linear complexity journal. This commit also hides several internal StateDB methods to limit the number of ways in which calling code can use the journal incorrectly. As usual consultation and bug fixes to the initial implementation were provided by @karalabe, @obscuren and @Arachnid. Thank you! (cherry picked from commit 1f1ea18)
1 parent e97b301 commit 46a527d

22 files changed

+658
-239
lines changed

accounts/abi/bind/backends/simulated.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ func (b *SimulatedBackend) ContractCall(contract common.Address, data []byte, pe
9797
statedb *state.StateDB
9898
)
9999
if pending {
100-
block, statedb = b.pendingBlock, b.pendingState.Copy()
100+
block, statedb = b.pendingBlock, b.pendingState
101+
defer statedb.RevertToSnapshot(statedb.Snapshot())
101102
} else {
102103
block = b.blockchain.CurrentBlock()
103104
statedb, _ = b.blockchain.State()
@@ -119,6 +120,7 @@ func (b *SimulatedBackend) ContractCall(contract common.Address, data []byte, pe
119120
value: new(big.Int),
120121
data: data,
121122
}
123+
122124
// Execute the call and return
123125
vmenv := core.NewEnv(statedb, chainConfig, b.blockchain, msg, block.Header(), vm.Config{})
124126
gaspool := new(core.GasPool).AddGas(common.MaxBig)
@@ -146,8 +148,10 @@ func (b *SimulatedBackend) EstimateGasLimit(sender common.Address, contract *com
146148
// Create a copy of the currently pending state db to screw around with
147149
var (
148150
block = b.pendingBlock
149-
statedb = b.pendingState.Copy()
151+
statedb = b.pendingState
150152
)
153+
defer statedb.RevertToSnapshot(statedb.Snapshot())
154+
151155
// If there's no code to interact with, respond with an appropriate error
152156
if contract != nil {
153157
if code := statedb.GetCode(*contract); len(code) == 0 {

cmd/evm/main.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -223,22 +223,22 @@ type ruleSet struct{}
223223

224224
func (ruleSet) IsHomestead(*big.Int) bool { return true }
225225

226-
func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} }
227-
func (self *VMEnv) Vm() vm.Vm { return self.evm }
228-
func (self *VMEnv) Db() vm.Database { return self.state }
229-
func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() }
230-
func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) }
231-
func (self *VMEnv) Origin() common.Address { return *self.transactor }
232-
func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 }
233-
func (self *VMEnv) Coinbase() common.Address { return *self.transactor }
234-
func (self *VMEnv) Time() *big.Int { return self.time }
235-
func (self *VMEnv) Difficulty() *big.Int { return common.Big1 }
236-
func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) }
237-
func (self *VMEnv) Value() *big.Int { return self.value }
238-
func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) }
239-
func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy }
240-
func (self *VMEnv) Depth() int { return 0 }
241-
func (self *VMEnv) SetDepth(i int) { self.depth = i }
226+
func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} }
227+
func (self *VMEnv) Vm() vm.Vm { return self.evm }
228+
func (self *VMEnv) Db() vm.Database { return self.state }
229+
func (self *VMEnv) SnapshotDatabase() int { return self.state.Snapshot() }
230+
func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) }
231+
func (self *VMEnv) Origin() common.Address { return *self.transactor }
232+
func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 }
233+
func (self *VMEnv) Coinbase() common.Address { return *self.transactor }
234+
func (self *VMEnv) Time() *big.Int { return self.time }
235+
func (self *VMEnv) Difficulty() *big.Int { return common.Big1 }
236+
func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) }
237+
func (self *VMEnv) Value() *big.Int { return self.value }
238+
func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) }
239+
func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy }
240+
func (self *VMEnv) Depth() int { return 0 }
241+
func (self *VMEnv) SetDepth(i int) { self.depth = i }
242242
func (self *VMEnv) GetHash(n uint64) common.Hash {
243243
if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 {
244244
return self.block.Hash()

core/chain_makers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) {
131131
// TxNonce returns the next valid transaction nonce for the
132132
// account at addr. It panics if the account does not exist.
133133
func (b *BlockGen) TxNonce(addr common.Address) uint64 {
134-
if !b.statedb.HasAccount(addr) {
134+
if !b.statedb.Exist(addr) {
135135
panic("account does not exist")
136136
}
137137
return b.statedb.GetNonce(addr)

core/execution.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
8585
createAccount = true
8686
}
8787

88-
snapshotPreTransfer := env.MakeSnapshot()
88+
snapshotPreTransfer := env.SnapshotDatabase()
8989
var (
9090
from = env.Db().GetAccount(caller.Address())
9191
to vm.Account
@@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
129129
if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) {
130130
contract.UseGas(contract.Gas)
131131

132-
env.SetSnapshot(snapshotPreTransfer)
132+
env.RevertToSnapshot(snapshotPreTransfer)
133133
}
134134

135135
return ret, addr, err
@@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
144144
return nil, common.Address{}, vm.DepthError
145145
}
146146

147-
snapshot := env.MakeSnapshot()
147+
snapshot := env.SnapshotDatabase()
148148

149149
var to vm.Account
150150
if !env.Db().Exist(*toAddr) {
@@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
162162
if err != nil {
163163
contract.UseGas(contract.Gas)
164164

165-
env.SetSnapshot(snapshot)
165+
env.RevertToSnapshot(snapshot)
166166
}
167167

168168
return ret, addr, err

core/state/dump.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump {
5252
panic(err)
5353
}
5454

55-
obj := NewObject(common.BytesToAddress(addr), data, nil)
55+
obj := newObject(nil, common.BytesToAddress(addr), data, nil)
5656
account := DumpAccount{
5757
Balance: data.Balance.String(),
5858
Nonce: data.Nonce,

core/state/journal.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright 2016 The go-ethereum Authors
2+
// This file is part of the go-ethereum library.
3+
//
4+
// The go-ethereum library is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Lesser General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// The go-ethereum library is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Lesser General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Lesser General Public License
15+
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16+
17+
package state
18+
19+
import (
20+
"math/big"
21+
22+
"github.com/ethereum/go-ethereum/common"
23+
)
24+
25+
type journalEntry interface {
26+
undo(*StateDB)
27+
}
28+
29+
type journal []journalEntry
30+
31+
type (
32+
// Changes to the account trie.
33+
createObjectChange struct {
34+
account *common.Address
35+
}
36+
resetObjectChange struct {
37+
prev *StateObject
38+
}
39+
deleteAccountChange struct {
40+
account *common.Address
41+
prev bool // whether account had already suicided
42+
prevbalance *big.Int
43+
}
44+
45+
// Changes to individual accounts.
46+
balanceChange struct {
47+
account *common.Address
48+
prev *big.Int
49+
}
50+
nonceChange struct {
51+
account *common.Address
52+
prev uint64
53+
}
54+
storageChange struct {
55+
account *common.Address
56+
key, prevalue common.Hash
57+
}
58+
codeChange struct {
59+
account *common.Address
60+
prevcode, prevhash []byte
61+
}
62+
63+
// Changes to other state values.
64+
refundChange struct {
65+
prev *big.Int
66+
}
67+
addLogChange struct {
68+
txhash common.Hash
69+
}
70+
)
71+
72+
func (ch createObjectChange) undo(s *StateDB) {
73+
s.GetStateObject(*ch.account).deleted = true
74+
delete(s.stateObjects, *ch.account)
75+
delete(s.stateObjectsDirty, *ch.account)
76+
}
77+
78+
func (ch resetObjectChange) undo(s *StateDB) {
79+
s.setStateObject(ch.prev)
80+
}
81+
82+
func (ch deleteAccountChange) undo(s *StateDB) {
83+
obj := s.GetStateObject(*ch.account)
84+
if obj != nil {
85+
obj.remove = ch.prev
86+
obj.setBalance(ch.prevbalance)
87+
}
88+
}
89+
90+
func (ch balanceChange) undo(s *StateDB) {
91+
s.GetStateObject(*ch.account).setBalance(ch.prev)
92+
}
93+
94+
func (ch nonceChange) undo(s *StateDB) {
95+
s.GetStateObject(*ch.account).setNonce(ch.prev)
96+
}
97+
98+
func (ch codeChange) undo(s *StateDB) {
99+
s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
100+
}
101+
102+
func (ch storageChange) undo(s *StateDB) {
103+
s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue)
104+
}
105+
106+
func (ch refundChange) undo(s *StateDB) {
107+
s.refund = ch.prev
108+
}
109+
110+
func (ch addLogChange) undo(s *StateDB) {
111+
logs := s.logs[ch.txhash]
112+
if len(logs) == 1 {
113+
delete(s.logs, ch.txhash)
114+
} else {
115+
s.logs[ch.txhash] = logs[:len(logs)-1]
116+
}
117+
}

core/state/managed_state_test.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,8 @@ func create() (*ManagedState, *account) {
2929
db, _ := ethdb.NewMemDatabase()
3030
statedb, _ := New(common.Hash{}, db)
3131
ms := ManageState(statedb)
32-
so := &StateObject{address: addr}
33-
so.SetNonce(100)
34-
ms.StateDB.stateObjects[addr] = so
35-
ms.accounts[addr] = newAccount(so)
36-
32+
ms.StateDB.SetNonce(addr, 100)
33+
ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr))
3734
return ms, ms.accounts[addr]
3835
}
3936

core/state/state_object.go

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ func (self Storage) Copy() Storage {
6666
type StateObject struct {
6767
address common.Address // Ethereum address of this account
6868
data Account
69+
db *StateDB
6970

7071
// DB error.
7172
// State objects are used by the consensus core and VM which are
@@ -99,15 +100,15 @@ type Account struct {
99100
CodeHash []byte
100101
}
101102

102-
// NewObject creates a state object.
103-
func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
103+
// newObject creates a state object.
104+
func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
104105
if data.Balance == nil {
105106
data.Balance = new(big.Int)
106107
}
107108
if data.CodeHash == nil {
108109
data.CodeHash = emptyCodeHash
109110
}
110-
return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
111+
return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
111112
}
112113

113114
// EncodeRLP implements rlp.Encoder.
@@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) {
122123
}
123124
}
124125

125-
func (self *StateObject) MarkForDeletion() {
126+
func (self *StateObject) markForDeletion() {
126127
self.remove = true
127128
if self.onDirty != nil {
128129
self.onDirty(self.Address())
@@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash
163164
}
164165

165166
// SetState updates a value in account storage.
166-
func (self *StateObject) SetState(key, value common.Hash) {
167+
func (self *StateObject) SetState(db trie.Database, key, value common.Hash) {
168+
self.db.journal = append(self.db.journal, storageChange{
169+
account: &self.address,
170+
key: key,
171+
prevalue: self.GetState(db, key),
172+
})
173+
self.setState(key, value)
174+
}
175+
176+
func (self *StateObject) setState(key, value common.Hash) {
167177
self.cachedStorage[key] = value
168178
self.dirtyStorage[key] = value
169179

@@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) {
189199
}
190200

191201
// UpdateRoot sets the trie root to the current root hash of
192-
func (self *StateObject) UpdateRoot(db trie.Database) {
202+
func (self *StateObject) updateRoot(db trie.Database) {
193203
self.updateTrie(db)
194204
self.data.Root = self.trie.Hash()
195205
}
@@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) {
232242
}
233243

234244
func (self *StateObject) SetBalance(amount *big.Int) {
245+
self.db.journal = append(self.db.journal, balanceChange{
246+
account: &self.address,
247+
prev: new(big.Int).Set(self.data.Balance),
248+
})
249+
self.setBalance(amount)
250+
}
251+
252+
func (self *StateObject) setBalance(amount *big.Int) {
235253
self.data.Balance = amount
236254
if self.onDirty != nil {
237255
self.onDirty(self.Address())
@@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) {
242260
// Return the gas back to the origin. Used by the Virtual machine or Closures
243261
func (c *StateObject) ReturnGas(gas, price *big.Int) {}
244262

245-
func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject {
246-
stateObject := NewObject(self.address, self.data, onDirty)
263+
func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject {
264+
stateObject := newObject(db, self.address, self.data, onDirty)
247265
stateObject.trie = self.trie
248266
stateObject.code = self.code
249267
stateObject.dirtyStorage = self.dirtyStorage.Copy()
@@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte {
280298
}
281299

282300
func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
301+
prevcode := self.Code(self.db.db)
302+
self.db.journal = append(self.db.journal, codeChange{
303+
account: &self.address,
304+
prevhash: self.CodeHash(),
305+
prevcode: prevcode,
306+
})
307+
self.setCode(codeHash, code)
308+
}
309+
310+
func (self *StateObject) setCode(codeHash common.Hash, code []byte) {
283311
self.code = code
284312
self.data.CodeHash = codeHash[:]
285313
self.dirtyCode = true
@@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
290318
}
291319

292320
func (self *StateObject) SetNonce(nonce uint64) {
321+
self.db.journal = append(self.db.journal, nonceChange{
322+
account: &self.address,
323+
prev: self.data.Nonce,
324+
})
325+
self.setNonce(nonce)
326+
}
327+
328+
func (self *StateObject) setNonce(nonce uint64) {
293329
self.data.Nonce = nonce
294330
if self.onDirty != nil {
295331
self.onDirty(self.Address())
@@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
322358
cb(h, value)
323359
}
324360

325-
it := self.trie.Iterator()
361+
it := self.getTrie(self.db.db).Iterator()
326362
for it.Next() {
327363
// ignore cached values
328364
key := common.BytesToHash(self.trie.GetKey(it.Key))

0 commit comments

Comments
 (0)