Skip to content

Commit 58ded5b

Browse files
committed
feat: vm.OperationEnvironment for custom instructions
1 parent 7d054b9 commit 58ded5b

File tree

2 files changed

+39
-8
lines changed

2 files changed

+39
-8
lines changed

core/vm/jump_table.libevm.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func overrideJumpTable(r params.Rules, jt *JumpTable) *JumpTable {
1818
// An OperationBuilder is a factory for a new operations to include in a
1919
// [JumpTable].
2020
type OperationBuilder struct {
21-
Execute func(pc *uint64, interpreter *EVMInterpreter, callContext *ScopeContext) ([]byte, error)
21+
Execute OperationFunc
2222
ConstantGas uint64
2323
DynamicGas func(_ *EVM, _ *Contract, _ *Stack, _ *Memory, requestedMemorySize uint64) (uint64, error)
2424
MinStack, MaxStack int
@@ -28,7 +28,7 @@ type OperationBuilder struct {
2828
// Build constructs the operation.
2929
func (b OperationBuilder) Build() *operation {
3030
o := &operation{
31-
execute: b.Execute,
31+
execute: b.Execute.internal(),
3232
constantGas: b.ConstantGas,
3333
dynamicGas: b.DynamicGas,
3434
minStack: b.MinStack,
@@ -38,6 +38,27 @@ func (b OperationBuilder) Build() *operation {
3838
return o
3939
}
4040

41+
// An OperationFunc is the execution function of a custom instruction.
42+
type OperationFunc func(_ *OperationEnvironment, pc *uint64, _ *EVMInterpreter, _ *ScopeContext) ([]byte, error)
43+
44+
// An OperationEnvironment provides information about the context in which a
45+
// custom instruction is being executed.
46+
type OperationEnvironment struct {
47+
StateDB StateDB
48+
}
49+
50+
// internal converts an exported [OperationFunc] into an un-exported
51+
// [executionFunc] as required to build an [operation].
52+
func (fn OperationFunc) internal() executionFunc {
53+
return func(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) {
54+
return fn(
55+
&OperationEnvironment{
56+
StateDB: interpreter.evm.StateDB,
57+
}, pc, interpreter, scope,
58+
)
59+
}
60+
}
61+
4162
// Hooks are arbitrary configuration functions to modify default VM behaviour.
4263
type Hooks interface {
4364
// OverrideJumpTable will only be called if

core/vm/jump_table.libevm_test.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/require"
1212

13+
"github.com/ethereum/go-ethereum/common"
1314
"github.com/ethereum/go-ethereum/core/vm"
1415
"github.com/ethereum/go-ethereum/libevm/ethtest"
1516
"github.com/ethereum/go-ethereum/libevm/hookstest"
@@ -35,6 +36,16 @@ func (s *vmHooksStub) OverrideJumpTable(_ params.Rules, jt *vm.JumpTable) *vm.Ju
3536
return jt
3637
}
3738

39+
// An opRecorder is an instruction that records its inputs.
40+
type opRecorder struct {
41+
stateVal common.Hash
42+
}
43+
44+
func (op *opRecorder) execute(env *vm.OperationEnvironment, pc *uint64, interpreter *vm.EVMInterpreter, scope *vm.ScopeContext) ([]byte, error) {
45+
op.stateVal = env.StateDB.GetState(scope.Contract.Address(), common.Hash{})
46+
return nil, nil
47+
}
48+
3849
func TestOverrideJumpTable(t *testing.T) {
3950
override := new(bool)
4051
hooks := &hookstest.Stub{
@@ -50,15 +61,12 @@ func TestOverrideJumpTable(t *testing.T) {
5061
)
5162
rng := ethtest.NewPseudoRand(142857)
5263
gasCost := 1 + rng.Uint64n(gasLimit)
53-
executed := false
64+
spy := &opRecorder{}
5465

5566
vmHooks := &vmHooksStub{
5667
replacement: &vm.JumpTable{
5768
opcode: vm.OperationBuilder{
58-
Execute: func(pc *uint64, interpreter *vm.EVMInterpreter, callContext *vm.ScopeContext) ([]byte, error) {
59-
executed = true
60-
return nil, nil
61-
},
69+
Execute: spy.execute,
6270
ConstantGas: gasCost,
6371
MemorySize: func(s *vm.Stack) (size uint64, overflow bool) {
6472
return 0, false
@@ -91,11 +99,13 @@ func TestOverrideJumpTable(t *testing.T) {
9199
contract := rng.Address()
92100
state.CreateAccount(contract)
93101
state.SetCode(contract, []byte{opcode})
102+
value := rng.Hash()
103+
state.SetState(contract, common.Hash{}, value)
94104

95105
_, gasRemaining, err := evm.Call(vm.AccountRef(rng.Address()), contract, []byte{}, gasLimit, uint256.NewInt(0))
96106
require.NoError(t, err, "evm.Call([contract with overridden opcode])")
97-
assert.True(t, executed, "executionFunc was called")
98107
assert.Equal(t, gasLimit-gasCost, gasRemaining, "gas remaining")
108+
assert.Equal(t, spy.stateVal, value, "StateDB propagated")
99109
})
100110
}
101111

0 commit comments

Comments
 (0)