Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions core/state/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ type StateDB struct {

// Transient storage
transientStorage transientStorage
// Overrides to apply after Prepare()
pendingTransientOverrides map[common.Address]map[common.Hash]common.Hash

// Journal of state modifications. This is the backbone of
// Snapshot and RevertToSnapshot.
Expand Down Expand Up @@ -1430,6 +1432,22 @@ func (s *StateDB) Prepare(rules params.Rules, sender, coinbase common.Address, d
}
// Reset transient storage at the beginning of transaction execution
s.transientStorage = newTransientStorage()

// Apply any pending transient storage overrides after reset
if s.pendingTransientOverrides != nil {
for addr, storage := range s.pendingTransientOverrides {
for key, value := range storage {
s.transientStorage.Set(addr, key, value)
}
}
s.pendingTransientOverrides = nil // Clear after applying
}
}

// SetPendingTransientOverrides stores transient storage overrides to be applied
// after the next Prepare() call. This ensures overrides are applied to a clean state.
func (s *StateDB) SetPendingTransientOverrides(overrides map[common.Address]map[common.Hash]common.Hash) {
s.pendingTransientOverrides = overrides
}

// AddAddressToAccessList adds the given address to the access list
Expand Down
29 changes: 29 additions & 0 deletions ethclient/gethclient/gethclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,31 @@ func (ec *Client) CallContractWithBlockOverrides(ctx context.Context, msg ethere
return hex, err
}

// CallContractWithTransientOverrides executes a message call transaction, which is directly executed
// in the VM of the node, but never mined into the blockchain.
//
// blockNumber selects the block height at which the call runs. It can be nil, in which
// case the code is taken from the latest known block. Note that state from very old
// blocks might not be available.
//
// overrides specifies a map of contract states that should be overwritten before executing
// the message call.
//
// blockOverrides specifies block fields exposed to the EVM that can be overridden for the call.
//
// transientOverrides specifies transient storage slots that should be overwritten before executing
// the message call. Transient storage is reset at the beginning of each transaction.
//
// Please use ethclient.CallContract instead if you don't need the override functionality.
func (ec *Client) CallContractWithTransientOverrides(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int, overrides *map[common.Address]OverrideAccount, blockOverrides *BlockOverrides, transientOverrides *TransientOverrides) ([]byte, error) {
var hex hexutil.Bytes
err := ec.c.CallContext(
ctx, &hex, "eth_call", toCallArg(msg),
toBlockNumArg(blockNumber), overrides, blockOverrides, transientOverrides,
)
return hex, err
}

// GCStats retrieves the current garbage collection stats from a geth node.
func (ec *Client) GCStats(ctx context.Context) (*debug.GCStats, error) {
var result debug.GCStats
Expand Down Expand Up @@ -348,6 +373,10 @@ type BlockOverrides struct {
BaseFee *big.Int
}

// TransientOverrides specifies transient storage slots to override for eth_call.
// The map key is the contract address, and the value is another map of slot to value.
type TransientOverrides map[common.Address]map[common.Hash]common.Hash

func (o BlockOverrides) MarshalJSON() ([]byte, error) {
type override struct {
Number *hexutil.Big `json:"number,omitempty"`
Expand Down
4 changes: 2 additions & 2 deletions graphql/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -1189,7 +1189,7 @@ func (c *CallResult) Status() hexutil.Uint64 {
func (b *Block) Call(ctx context.Context, args struct {
Data ethapi.TransactionArgs
}) (*CallResult, error) {
result, err := ethapi.DoCall(ctx, b.r.backend, args.Data, *b.numberOrHash, nil, nil, b.r.backend.RPCEVMTimeout(), b.r.backend.RPCGasCap())
result, err := ethapi.DoCall(ctx, b.r.backend, args.Data, *b.numberOrHash, nil, nil, nil, b.r.backend.RPCEVMTimeout(), b.r.backend.RPCGasCap())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1252,7 +1252,7 @@ func (p *Pending) Call(ctx context.Context, args struct {
Data ethapi.TransactionArgs
}) (*CallResult, error) {
pendingBlockNr := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber)
result, err := ethapi.DoCall(ctx, p.r.backend, args.Data, pendingBlockNr, nil, nil, p.r.backend.RPCEVMTimeout(), p.r.backend.RPCGasCap())
result, err := ethapi.DoCall(ctx, p.r.backend, args.Data, pendingBlockNr, nil, nil, nil, p.r.backend.RPCEVMTimeout(), p.r.backend.RPCGasCap())
if err != nil {
return nil, err
}
Expand Down
13 changes: 8 additions & 5 deletions internal/ethapi/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ func (context *ChainContext) Config() *params.ChainConfig {
return context.b.ChainConfig()
}

func doCall(ctx context.Context, b Backend, args TransactionArgs, state *state.StateDB, header *types.Header, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, error) {
func doCall(ctx context.Context, b Backend, args TransactionArgs, state *state.StateDB, header *types.Header, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, transientOverrides *override.TransientStorageOverride, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, error) {
blockCtx := core.NewEVMBlockContext(header, NewChainContext(ctx, b), nil)
if blockOverrides != nil {
if err := blockOverrides.Apply(&blockCtx); err != nil {
Expand All @@ -682,6 +682,9 @@ func doCall(ctx context.Context, b Backend, args TransactionArgs, state *state.S
return nil, err
}

// Apply transient storage overrides. These will be applied after Prepare() is called.
transientOverrides.Apply(state)

// Setup context so it may be cancelled the call has completed
// or, in case of unmetered gas, setup a context with a timeout.
var cancel context.CancelFunc
Expand Down Expand Up @@ -750,14 +753,14 @@ func applyMessageWithEVM(ctx context.Context, evm *vm.EVM, msg *core.Message, ti
return result, nil
}

func DoCall(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, error) {
func DoCall(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, transientOverrides *override.TransientStorageOverride, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, error) {
defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now())

state, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash)
if state == nil || err != nil {
return nil, err
}
return doCall(ctx, b, args, state, header, overrides, blockOverrides, timeout, globalGasCap)
return doCall(ctx, b, args, state, header, overrides, blockOverrides, transientOverrides, timeout, globalGasCap)
}

// Call executes the given transaction on the state for the given block number.
Expand All @@ -766,12 +769,12 @@ func DoCall(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash
//
// Note, this function doesn't make and changes in the state/blockchain and is
// useful to execute and retrieve values.
func (api *BlockChainAPI) Call(ctx context.Context, args TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *override.StateOverride, blockOverrides *override.BlockOverrides) (hexutil.Bytes, error) {
func (api *BlockChainAPI) Call(ctx context.Context, args TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, transientOverrides *override.TransientStorageOverride) (hexutil.Bytes, error) {
if blockNrOrHash == nil {
latest := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber)
blockNrOrHash = &latest
}
result, err := DoCall(ctx, api.b, args, *blockNrOrHash, overrides, blockOverrides, api.b.RPCEVMTimeout(), api.b.RPCGasCap())
result, err := DoCall(ctx, api.b, args, *blockNrOrHash, overrides, blockOverrides, transientOverrides, api.b.RPCEVMTimeout(), api.b.RPCGasCap())
if err != nil {
return nil, err
}
Expand Down
38 changes: 30 additions & 8 deletions internal/ethapi/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -967,13 +967,14 @@ func TestCall(t *testing.T) {
}))
randomAccounts := newAccounts(3)
var testSuite = []struct {
name string
blockNumber rpc.BlockNumber
overrides override.StateOverride
call TransactionArgs
blockOverrides override.BlockOverrides
expectErr error
want string
name string
blockNumber rpc.BlockNumber
overrides override.StateOverride
call TransactionArgs
blockOverrides override.BlockOverrides
transientOverrides override.TransientStorageOverride
expectErr error
want string
}{
// transfer on genesis
{
Expand Down Expand Up @@ -1226,9 +1227,30 @@ func TestCall(t *testing.T) {
},
expectErr: errors.New(`block override "withdrawals" is not supported for this RPC method`),
},
// Test transient storage override
{
name: "transient storage override takes effect",
blockNumber: rpc.LatestBlockNumber,
call: TransactionArgs{
From: &accounts[1].addr,
To: &randomAccounts[2].addr,
},
overrides: override.StateOverride{
randomAccounts[2].addr: override.OverrideAccount{
// PUSH1 0x00 TLOAD PUSH1 0x00 MSTORE PUSH1 0x20 PUSH1 0x00 RETURN
Code: hex2Bytes("0x60005c60005260206000f3"),
},
},
transientOverrides: override.TransientStorageOverride{
randomAccounts[2].addr: map[common.Hash]common.Hash{
common.Hash{}: common.HexToHash("0xabcd"),
},
},
want: "0x000000000000000000000000000000000000000000000000000000000000abcd",
},
}
for _, tc := range testSuite {
result, err := api.Call(context.Background(), tc.call, &rpc.BlockNumberOrHash{BlockNumber: &tc.blockNumber}, &tc.overrides, &tc.blockOverrides)
result, err := api.Call(context.Background(), tc.call, &rpc.BlockNumberOrHash{BlockNumber: &tc.blockNumber}, &tc.overrides, &tc.blockOverrides, &tc.transientOverrides)
if tc.expectErr != nil {
if err == nil {
t.Errorf("test %s: want error %v, have nothing", tc.name, tc.expectErr)
Expand Down
13 changes: 12 additions & 1 deletion internal/ethapi/override/override.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
// OverrideAccount indicates the overriding fields of account during the execution
// of a message call.
// Note, state and stateDiff can't be specified at the same time. If state is
// set, message execution will only use the data in the given state. Otherwise
// set, message execution will only use the data in the given state. Otherwise,
// if stateDiff is set, all diff will be applied first and then execute the call
// message.
type OverrideAccount struct {
Expand All @@ -48,6 +48,9 @@ type OverrideAccount struct {
// StateOverride is the collection of overridden accounts.
type StateOverride map[common.Address]OverrideAccount

// TransientStorageOverride is the collection of transient storage overrides.
type TransientStorageOverride map[common.Address]map[common.Hash]common.Hash

func (diff *StateOverride) has(address common.Address) bool {
_, ok := (*diff)[address]
return ok
Expand Down Expand Up @@ -119,6 +122,14 @@ func (diff *StateOverride) Apply(statedb *state.StateDB, precompiles vm.Precompi
return nil
}

// Apply stores transient storage overrides to be applied after Prepare().
func (diff *TransientStorageOverride) Apply(statedb *state.StateDB) {
if diff == nil || len(*diff) == 0 {
return
}
statedb.SetPendingTransientOverrides(*diff)
}

// BlockOverrides is a set of header fields to override.
type BlockOverrides struct {
Number *hexutil.Big
Expand Down
54 changes: 54 additions & 0 deletions internal/ethapi/override/override_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/triedb"
)

Expand Down Expand Up @@ -128,3 +129,56 @@ func hex2Bytes(str string) *hexutil.Bytes {
rpcBytes := hexutil.Bytes(common.FromHex(str))
return &rpcBytes
}

func TestStateOverrideTransientStorage(t *testing.T) {
db := state.NewDatabase(triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil), nil)
statedb, err := state.New(types.EmptyRootHash, db)
if err != nil {
t.Fatalf("failed to create statedb: %v", err)
}

addr := common.BytesToAddress([]byte{0x1})
key1 := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000001")
key2 := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000002")
value1 := common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111")
value2 := common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222")

// Verify initial state is empty
if got := statedb.GetTransientState(addr, key1); got != (common.Hash{}) {
t.Fatalf("expected initial transient state to be empty, got %s", got.Hex())
}
if got := statedb.GetTransientState(addr, key2); got != (common.Hash{}) {
t.Fatalf("expected initial transient state to be empty, got %s", got.Hex())
}

// Apply transient storage override
transientOverride := TransientStorageOverride{
addr: map[common.Hash]common.Hash{
key1: value1,
key2: value2,
},
}

transientOverride.Apply(statedb)

statedb.Prepare(params.Rules{}, common.Address{}, common.Address{}, nil, nil, nil)

// Verify transient storage was set
if got := statedb.GetTransientState(addr, key1); got != value1 {
t.Errorf("expected transient state for key1 to be %s, got %s", value1.Hex(), got.Hex())
}
if got := statedb.GetTransientState(addr, key2); got != value2 {
t.Errorf("expected transient state for key2 to be %s, got %s", value2.Hex(), got.Hex())
}

// Verify other addresses/keys remain empty
otherAddr := common.BytesToAddress([]byte{0x2})
if got := statedb.GetTransientState(otherAddr, key1); got != (common.Hash{}) {
t.Errorf("expected transient state for different address to be empty, got %s", got.Hex())
}

otherKey := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000003")
if got := statedb.GetTransientState(addr, otherKey); got != (common.Hash{}) {
t.Errorf("expected transient state for different key to be empty, got %s", got.Hex())
}
}