Skip to content
Merged
164 changes: 129 additions & 35 deletions core/state/journal.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,21 @@
package state

import (
"fmt"
"maps"
"slices"
"sort"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/holiman/uint256"
)

type revision struct {
id int
journalIndex int
}

// journalEntry is a modification entry in the state change journal that can be
// reverted on demand.
type journalEntry interface {
Expand All @@ -42,6 +51,9 @@ type journalEntry interface {
type journal struct {
entries []journalEntry // Current changes tracked by the journal
dirties map[common.Address]int // Dirty accounts and the number of changes

validRevisions []revision
nextRevisionId int
}

// newJournal creates a new initialized journal.
Expand All @@ -51,6 +63,40 @@ func newJournal() *journal {
}
}

// Reset clears the journal, after this operation the journal can be used
// anew. It is semantically similar to calling 'newJournal', but the underlying
// slices can be reused
func (j *journal) Reset() {
j.entries = j.entries[:0]
j.validRevisions = j.validRevisions[:0]
clear(j.dirties)
j.nextRevisionId = 0
}

// Snapshot returns an identifier for the current revision of the state.
func (j *journal) Snapshot() int {
id := j.nextRevisionId
j.nextRevisionId++
j.validRevisions = append(j.validRevisions, revision{id, j.length()})
return id
}

// RevertToSnapshot reverts all state changes made since the given revision.
func (j *journal) RevertToSnapshot(revid int, s *StateDB) {
// Find the snapshot in the stack of valid snapshots.
idx := sort.Search(len(j.validRevisions), func(i int) bool {
return j.validRevisions[i].id >= revid
})
if idx == len(j.validRevisions) || j.validRevisions[idx].id != revid {
panic(fmt.Errorf("revision id %v cannot be reverted", revid))
}
snapshot := j.validRevisions[idx].journalIndex

// Replay the journal to undo changes and remove invalidated snapshots
j.revert(s, snapshot)
j.validRevisions = j.validRevisions[:idx]
}

// append inserts a new modification entry to the end of the change journal.
func (j *journal) append(entry journalEntry) {
j.entries = append(j.entries, entry)
Expand Down Expand Up @@ -95,8 +141,83 @@ func (j *journal) copy() *journal {
entries = append(entries, j.entries[i].copy())
}
return &journal{
entries: entries,
dirties: maps.Clone(j.dirties),
entries: entries,
dirties: maps.Clone(j.dirties),
validRevisions: slices.Clone(j.validRevisions),
nextRevisionId: j.nextRevisionId,
}
}

func (j *journal) AccessListAddAccount(addr common.Address) {
j.append(accessListAddAccountChange{&addr})
}

func (j *journal) AccessListAddSlot(addr common.Address, slot common.Hash) {
j.append(accessListAddSlotChange{
address: &addr,
slot: &slot,
})
}

func (j *journal) Log(txHash common.Hash) {
j.append(addLogChange{txhash: txHash})
}

func (j *journal) Create(addr common.Address) {
j.append(createObjectChange{account: &addr})
}

func (j *journal) Destruct(addr common.Address) {
j.append(selfDestructChange{account: &addr})
}

func (j *journal) SetStorage(addr common.Address, key, prev, origin common.Hash) {
j.append(storageChange{
account: &addr,
key: key,
prevvalue: prev,
origvalue: origin,
})
}

func (j *journal) SetTransientState(addr common.Address, key, prev common.Hash) {
j.append(transientStorageChange{
account: &addr,
key: key,
prevalue: prev,
})
}

func (j *journal) RefundChange(previous uint64) {
j.append(refundChange{prev: previous})
}

func (j *journal) BalanceChange(addr common.Address, previous *uint256.Int) {
j.append(balanceChange{
account: &addr,
prev: previous.Clone(),
})
}

func (j *journal) SetCode(address common.Address) {
j.append(codeChange{account: &address})
}

func (j *journal) NonceChange(address common.Address, prev uint64) {
j.append(nonceChange{
account: &address,
prev: prev,
})
}

func (j *journal) Touch(address common.Address) {
j.append(touchChange{
account: &address,
})
if address == ripemd {
// Explicitly put it in the dirty-cache, which is otherwise generated from
// flattened journals.
j.dirty(address)
}
}

Expand All @@ -114,9 +235,7 @@ type (
}

selfDestructChange struct {
account *common.Address
prev bool // whether account had already self-destructed
prevbalance *uint256.Int
account *common.Address
}

// Changes to individual accounts.
Expand All @@ -135,8 +254,7 @@ type (
origvalue common.Hash
}
codeChange struct {
account *common.Address
prevcode, prevhash []byte
account *common.Address
}

// Changes to other state values.
Expand All @@ -146,9 +264,6 @@ type (
addLogChange struct {
txhash common.Hash
}
addPreimageChange struct {
hash common.Hash
}
touchChange struct {
account *common.Address
}
Expand Down Expand Up @@ -200,8 +315,7 @@ func (ch createContractChange) copy() journalEntry {
func (ch selfDestructChange) revert(s *StateDB) {
obj := s.getStateObject(*ch.account)
if obj != nil {
obj.selfDestructed = ch.prev
obj.setBalance(ch.prevbalance)
obj.selfDestructed = false
}
}

Expand All @@ -211,9 +325,7 @@ func (ch selfDestructChange) dirtied() *common.Address {

func (ch selfDestructChange) copy() journalEntry {
return selfDestructChange{
account: ch.account,
prev: ch.prev,
prevbalance: new(uint256.Int).Set(ch.prevbalance),
account: ch.account,
}
}

Expand Down Expand Up @@ -263,19 +375,15 @@ func (ch nonceChange) copy() journalEntry {
}

func (ch codeChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
s.getStateObject(*ch.account).setCode(types.EmptyCodeHash, nil)
}

func (ch codeChange) dirtied() *common.Address {
return ch.account
}

func (ch codeChange) copy() journalEntry {
return codeChange{
account: ch.account,
prevhash: common.CopyBytes(ch.prevhash),
prevcode: common.CopyBytes(ch.prevcode),
}
return codeChange{account: ch.account}
}

func (ch storageChange) revert(s *StateDB) {
Expand Down Expand Up @@ -344,20 +452,6 @@ func (ch addLogChange) copy() journalEntry {
}
}

func (ch addPreimageChange) revert(s *StateDB) {
delete(s.preimages, ch.hash)
}

func (ch addPreimageChange) dirtied() *common.Address {
return nil
}

func (ch addPreimageChange) copy() journalEntry {
return addPreimageChange{
hash: ch.hash,
}
}

func (ch accessListAddAccountChange) revert(s *StateDB) {
/*
One important invariant here, is that whenever a (addr, slot) is added, if the
Expand Down
38 changes: 8 additions & 30 deletions core/state/state_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,7 @@ func (s *stateObject) markSelfdestructed() {
}

func (s *stateObject) touch() {
s.db.journal.append(touchChange{
account: &s.address,
})
if s.address == ripemd {
// Explicitly put it in the dirty-cache, which is otherwise generated from
// flattened journals.
s.db.journal.dirty(s.address)
}
s.db.journal.Touch(s.address)
}

// getTrie returns the associated storage trie. The trie will be opened if it's
Expand Down Expand Up @@ -251,16 +244,11 @@ func (s *stateObject) SetState(key, value common.Hash) {
return
}
// New value is different, update and journal the change
s.db.journal.append(storageChange{
account: &s.address,
key: key,
prevvalue: prev,
origvalue: origin,
})
s.db.journal.SetStorage(s.address, key, prev, origin)
s.setState(key, value, origin)
if s.db.logger != nil && s.db.logger.OnStorageChange != nil {
s.db.logger.OnStorageChange(s.address, key, prev, value)
}
s.setState(key, value, origin)
}

// setState updates a value in account dirty storage. The dirtiness will be
Expand Down Expand Up @@ -510,10 +498,7 @@ func (s *stateObject) SubBalance(amount *uint256.Int, reason tracing.BalanceChan
}

func (s *stateObject) SetBalance(amount *uint256.Int, reason tracing.BalanceChangeReason) {
s.db.journal.append(balanceChange{
account: &s.address,
prev: new(uint256.Int).Set(s.data.Balance),
})
s.db.journal.BalanceChange(s.address, s.data.Balance)
if s.db.logger != nil && s.db.logger.OnBalanceChange != nil {
s.db.logger.OnBalanceChange(s.address, s.Balance().ToBig(), amount.ToBig(), reason)
}
Expand Down Expand Up @@ -589,14 +574,10 @@ func (s *stateObject) CodeSize() int {
}

func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := s.Code()
s.db.journal.append(codeChange{
account: &s.address,
prevhash: s.CodeHash(),
prevcode: prevcode,
})
s.db.journal.SetCode(s.address)
if s.db.logger != nil && s.db.logger.OnCodeChange != nil {
s.db.logger.OnCodeChange(s.address, common.BytesToHash(s.CodeHash()), prevcode, codeHash, code)
// TODO remove prevcode from this callback
s.db.logger.OnCodeChange(s.address, common.BytesToHash(s.CodeHash()), nil, codeHash, code)
}
s.setCode(codeHash, code)
}
Expand All @@ -608,10 +589,7 @@ func (s *stateObject) setCode(codeHash common.Hash, code []byte) {
}

func (s *stateObject) SetNonce(nonce uint64) {
s.db.journal.append(nonceChange{
account: &s.address,
prev: s.data.Nonce,
})
s.db.journal.NonceChange(s.address, s.data.Nonce)
if s.db.logger != nil && s.db.logger.OnNonceChange != nil {
s.db.logger.OnNonceChange(s.address, s.data.Nonce, nonce)
}
Expand Down
Loading