Skip to content

Commit 0acc0a1

Browse files
authored
core/state: simplify storage trie update and commit (#28030)
This change improves function description and simplifies logic in statedb update and commit operations.
1 parent 53f3c2a commit 0acc0a1

File tree

5 files changed

+88
-129
lines changed

5 files changed

+88
-129
lines changed

core/state/state_object.go

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,17 @@ func (s *stateObject) finalise(prefetch bool) {
264264
}
265265
}
266266

267-
// updateTrie writes cached storage modifications into the object's storage trie.
268-
// It will return nil if the trie has not been loaded and no changes have been
269-
// made. An error will be returned if the trie can't be loaded/updated correctly.
267+
// updateTrie is responsible for persisting cached storage changes into the
268+
// object's storage trie. In case the storage trie is not yet loaded, this
269+
// function will load the trie automatically. If any issues arise during the
270+
// loading or updating of the trie, an error will be returned. Furthermore,
271+
// this function will return the mutated storage trie, or nil if there is no
272+
// storage change at all.
270273
func (s *stateObject) updateTrie() (Trie, error) {
271274
// Make sure all dirty slots are finalized into the pending storage area
272-
s.finalise(false) // Don't prefetch anymore, pull directly if need be
275+
s.finalise(false)
276+
277+
// Short circuit if nothing changed, don't bother with hashing anything
273278
if len(s.pendingStorage) == 0 {
274279
return s.trie, nil
275280
}
@@ -281,14 +286,13 @@ func (s *stateObject) updateTrie() (Trie, error) {
281286
var (
282287
storage map[common.Hash][]byte
283288
origin map[common.Hash][]byte
284-
hasher = s.db.hasher
285289
)
286290
tr, err := s.getTrie()
287291
if err != nil {
288292
s.db.setError(err)
289293
return nil, err
290294
}
291-
// Insert all the pending updates into the trie
295+
// Insert all the pending storage updates into the trie
292296
usedStorage := make([][]byte, 0, len(s.pendingStorage))
293297
for key, value := range s.pendingStorage {
294298
// Skip noop changes, persist actual changes
@@ -298,19 +302,18 @@ func (s *stateObject) updateTrie() (Trie, error) {
298302
prev := s.originStorage[key]
299303
s.originStorage[key] = value
300304

301-
// rlp-encoded value to be used by the snapshot
302-
var snapshotVal []byte
305+
var encoded []byte // rlp-encoded value to be used by the snapshot
303306
if (value == common.Hash{}) {
304307
if err := tr.DeleteStorage(s.address, key[:]); err != nil {
305308
s.db.setError(err)
306309
return nil, err
307310
}
308311
s.db.StorageDeleted += 1
309312
} else {
310-
trimmedVal := common.TrimLeftZeroes(value[:])
311313
// Encoding []byte cannot fail, ok to ignore the error.
312-
snapshotVal, _ = rlp.EncodeToBytes(trimmedVal)
313-
if err := tr.UpdateStorage(s.address, key[:], trimmedVal); err != nil {
314+
trimmed := common.TrimLeftZeroes(value[:])
315+
encoded, _ = rlp.EncodeToBytes(trimmed)
316+
if err := tr.UpdateStorage(s.address, key[:], trimmed); err != nil {
314317
s.db.setError(err)
315318
return nil, err
316319
}
@@ -323,8 +326,8 @@ func (s *stateObject) updateTrie() (Trie, error) {
323326
s.db.storages[s.addrHash] = storage
324327
}
325328
}
326-
khash := crypto.HashData(hasher, key[:])
327-
storage[khash] = snapshotVal // snapshotVal will be nil if it's deleted
329+
khash := crypto.HashData(s.db.hasher, key[:])
330+
storage[khash] = encoded // encoded will be nil if it's deleted
328331

329332
// Cache the original value of mutated storage slots
330333
if origin == nil {
@@ -349,21 +352,17 @@ func (s *stateObject) updateTrie() (Trie, error) {
349352
if s.db.prefetcher != nil {
350353
s.db.prefetcher.used(s.addrHash, s.data.Root, usedStorage)
351354
}
352-
if len(s.pendingStorage) > 0 {
353-
s.pendingStorage = make(Storage)
354-
}
355+
s.pendingStorage = make(Storage) // reset pending map
355356
return tr, nil
356357
}
357358

358-
// UpdateRoot sets the trie root to the current root hash of. An error
359-
// will be returned if trie root hash is not computed correctly.
359+
// updateRoot flushes all cached storage mutations to trie, recalculating the
360+
// new storage trie root.
360361
func (s *stateObject) updateRoot() {
362+
// Flush cached storage mutations into trie, short circuit if any error
363+
// is occurred or there is not change in the trie.
361364
tr, err := s.updateTrie()
362-
if err != nil {
363-
return
364-
}
365-
// If nothing changed, don't bother with hashing anything
366-
if tr == nil {
365+
if err != nil || tr == nil {
367366
return
368367
}
369368
// Track the amount of time wasted on hashing the storage trie
@@ -373,22 +372,23 @@ func (s *stateObject) updateRoot() {
373372
s.data.Root = tr.Hash()
374373
}
375374

376-
// commit returns the changes made in storage trie and updates the account data.
375+
// commit obtains a set of dirty storage trie nodes and updates the account data.
376+
// The returned set can be nil if nothing to commit. This function assumes all
377+
// storage mutations have already been flushed into trie by updateRoot.
377378
func (s *stateObject) commit() (*trienode.NodeSet, error) {
378-
tr, err := s.updateTrie()
379-
if err != nil {
380-
return nil, err
381-
}
382-
// If nothing changed, don't bother with committing anything
383-
if tr == nil {
379+
// Short circuit if trie is not even loaded, don't bother with committing anything
380+
if s.trie == nil {
384381
s.origin = s.data.Copy()
385382
return nil, nil
386383
}
387384
// Track the amount of time wasted on committing the storage trie
388385
if metrics.EnabledExpensive {
389386
defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now())
390387
}
391-
root, nodes, err := tr.Commit(false)
388+
// The trie is currently in an open state and could potentially contain
389+
// cached mutations. Call commit to acquire a set of nodes that have been
390+
// modified, the set can be nil if nothing to commit.
391+
root, nodes, err := s.trie.Commit(false)
392392
if err != nil {
393393
return nil, err
394394
}
@@ -536,3 +536,7 @@ func (s *stateObject) Balance() *big.Int {
536536
func (s *stateObject) Nonce() uint64 {
537537
return s.data.Nonce
538538
}
539+
540+
func (s *stateObject) Root() common.Hash {
541+
return s.data.Root
542+
}

core/state/statedb.go

Lines changed: 11 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package state
1919

2020
import (
21-
"errors"
2221
"fmt"
2322
"math/big"
2423
"sort"
@@ -48,17 +47,6 @@ type revision struct {
4847
journalIndex int
4948
}
5049

51-
type proofList [][]byte
52-
53-
func (n *proofList) Put(key []byte, value []byte) error {
54-
*n = append(*n, value)
55-
return nil
56-
}
57-
58-
func (n *proofList) Delete(key []byte) error {
59-
panic("not supported")
60-
}
61-
6250
// StateDB structs within the ethereum protocol are used to store anything
6351
// within the merkle trie. StateDBs take care of caching and storing
6452
// nested states. It's the general query interface to retrieve:
@@ -297,6 +285,7 @@ func (s *StateDB) GetBalance(addr common.Address) *big.Int {
297285
return common.Big0
298286
}
299287

288+
// GetNonce retrieves the nonce from the given address or 0 if object not found
300289
func (s *StateDB) GetNonce(addr common.Address) uint64 {
301290
stateObject := s.getStateObject(addr)
302291
if stateObject != nil {
@@ -306,6 +295,16 @@ func (s *StateDB) GetNonce(addr common.Address) uint64 {
306295
return 0
307296
}
308297

298+
// GetStorageRoot retrieves the storage root from the given address or empty
299+
// if object not found.
300+
func (s *StateDB) GetStorageRoot(addr common.Address) common.Hash {
301+
stateObject := s.getStateObject(addr)
302+
if stateObject != nil {
303+
return stateObject.Root()
304+
}
305+
return common.Hash{}
306+
}
307+
309308
// TxIndex returns the current transaction index set by Prepare.
310309
func (s *StateDB) TxIndex() int {
311310
return s.txIndex
@@ -344,35 +343,6 @@ func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash {
344343
return common.Hash{}
345344
}
346345

347-
// GetProof returns the Merkle proof for a given account.
348-
func (s *StateDB) GetProof(addr common.Address) ([][]byte, error) {
349-
return s.GetProofByHash(crypto.Keccak256Hash(addr.Bytes()))
350-
}
351-
352-
// GetProofByHash returns the Merkle proof for a given account.
353-
func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) {
354-
var proof proofList
355-
err := s.trie.Prove(addrHash[:], &proof)
356-
return proof, err
357-
}
358-
359-
// GetStorageProof returns the Merkle proof for given storage slot.
360-
func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) {
361-
trie, err := s.StorageTrie(a)
362-
if err != nil {
363-
return nil, err
364-
}
365-
if trie == nil {
366-
return nil, errors.New("storage trie for requested address does not exist")
367-
}
368-
var proof proofList
369-
err = trie.Prove(crypto.Keccak256(key.Bytes()), &proof)
370-
if err != nil {
371-
return nil, err
372-
}
373-
return proof, nil
374-
}
375-
376346
// GetCommittedState retrieves a value from the given account's committed storage trie.
377347
func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
378348
stateObject := s.getStateObject(addr)
@@ -387,21 +357,6 @@ func (s *StateDB) Database() Database {
387357
return s.db
388358
}
389359

390-
// StorageTrie returns the storage trie of an account. The return value is a copy
391-
// and is nil for non-existent accounts. An error will be returned if storage trie
392-
// is existent but can't be loaded correctly.
393-
func (s *StateDB) StorageTrie(addr common.Address) (Trie, error) {
394-
stateObject := s.getStateObject(addr)
395-
if stateObject == nil {
396-
return nil, nil
397-
}
398-
cpy := stateObject.deepCopy(s)
399-
if _, err := cpy.updateTrie(); err != nil {
400-
return nil, err
401-
}
402-
return cpy.getTrie()
403-
}
404-
405360
func (s *StateDB) HasSelfDestructed(addr common.Address) bool {
406361
stateObject := s.getStateObject(addr)
407362
if stateObject != nil {

eth/api_debug.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"github.com/ethereum/go-ethereum/core/rawdb"
2828
"github.com/ethereum/go-ethereum/core/state"
2929
"github.com/ethereum/go-ethereum/core/types"
30+
"github.com/ethereum/go-ethereum/crypto"
3031
"github.com/ethereum/go-ethereum/internal/ethapi"
3132
"github.com/ethereum/go-ethereum/log"
3233
"github.com/ethereum/go-ethereum/rlp"
@@ -216,7 +217,6 @@ func (api *DebugAPI) StorageRangeAt(ctx context.Context, blockNrOrHash rpc.Block
216217
if err != nil {
217218
return StorageRangeResult{}, err
218219
}
219-
220220
if block == nil {
221221
return StorageRangeResult{}, fmt.Errorf("block %v not found", blockNrOrHash)
222222
}
@@ -226,18 +226,20 @@ func (api *DebugAPI) StorageRangeAt(ctx context.Context, blockNrOrHash rpc.Block
226226
}
227227
defer release()
228228

229-
st, err := statedb.StorageTrie(contractAddress)
229+
return storageRangeAt(statedb, block.Root(), contractAddress, keyStart, maxResult)
230+
}
231+
232+
func storageRangeAt(statedb *state.StateDB, root common.Hash, address common.Address, start []byte, maxResult int) (StorageRangeResult, error) {
233+
storageRoot := statedb.GetStorageRoot(address)
234+
if storageRoot == types.EmptyRootHash || storageRoot == (common.Hash{}) {
235+
return StorageRangeResult{}, nil // empty storage
236+
}
237+
id := trie.StorageTrieID(root, crypto.Keccak256Hash(address.Bytes()), storageRoot)
238+
tr, err := trie.NewStateTrie(id, statedb.Database().TrieDB())
230239
if err != nil {
231240
return StorageRangeResult{}, err
232241
}
233-
if st == nil {
234-
return StorageRangeResult{}, fmt.Errorf("account %x doesn't exist", contractAddress)
235-
}
236-
return storageRangeAt(st, keyStart, maxResult)
237-
}
238-
239-
func storageRangeAt(st state.Trie, start []byte, maxResult int) (StorageRangeResult, error) {
240-
trieIt, err := st.NodeIterator(start)
242+
trieIt, err := tr.NodeIterator(start)
241243
if err != nil {
242244
return StorageRangeResult{}, err
243245
}
@@ -249,7 +251,7 @@ func storageRangeAt(st state.Trie, start []byte, maxResult int) (StorageRangeRes
249251
return StorageRangeResult{}, err
250252
}
251253
e := storageEntry{Value: common.BytesToHash(content)}
252-
if preimage := st.GetKey(it.Key); preimage != nil {
254+
if preimage := tr.GetKey(it.Key); preimage != nil {
253255
preimage := common.BytesToHash(preimage)
254256
e.Key = &preimage
255257
}

eth/api_debug_test.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,10 @@ func TestStorageRangeAt(t *testing.T) {
159159

160160
// Create a state where account 0x010000... has a few storage entries.
161161
var (
162-
state, _ = state.New(types.EmptyRootHash, state.NewDatabase(rawdb.NewMemoryDatabase()), nil)
163-
addr = common.Address{0x01}
164-
keys = []common.Hash{ // hashes of Keys of storage
162+
db = state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), &trie.Config{Preimages: true})
163+
sdb, _ = state.New(types.EmptyRootHash, db, nil)
164+
addr = common.Address{0x01}
165+
keys = []common.Hash{ // hashes of Keys of storage
165166
common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"),
166167
common.HexToHash("426fcb404ab2d5d8e61a3d918108006bbb0a9be65e92235bb10eefbdb6dcd053"),
167168
common.HexToHash("48078cfed56339ea54962e72c37c7f588fc4f8e5bc173827ba75cb10a63a96a5"),
@@ -175,8 +176,10 @@ func TestStorageRangeAt(t *testing.T) {
175176
}
176177
)
177178
for _, entry := range storage {
178-
state.SetState(addr, *entry.Key, entry.Value)
179+
sdb.SetState(addr, *entry.Key, entry.Value)
179180
}
181+
root, _ := sdb.Commit(0, false)
182+
sdb, _ = state.New(root, db, nil)
180183

181184
// Check a few combinations of limit and start/end.
182185
tests := []struct {
@@ -206,11 +209,7 @@ func TestStorageRangeAt(t *testing.T) {
206209
},
207210
}
208211
for _, test := range tests {
209-
tr, err := state.StorageTrie(addr)
210-
if err != nil {
211-
t.Error(err)
212-
}
213-
result, err := storageRangeAt(tr, test.start, test.limit)
212+
result, err := storageRangeAt(sdb, root, addr, test.start, test.limit)
214213
if err != nil {
215214
t.Error(err)
216215
}

0 commit comments

Comments
 (0)