Skip to content

Commit aee5be0

Browse files
committed
feat: types.HeaderHooks
1 parent bd44839 commit aee5be0

10 files changed

+252
-51
lines changed

core/state/state.libevm.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,46 +19,47 @@ package state
1919
import (
2020
"github.com/ava-labs/libevm/common"
2121
"github.com/ava-labs/libevm/core/types"
22+
"github.com/ava-labs/libevm/libevm/pseudo"
2223
)
2324

2425
// GetExtra returns the extra payload from the [types.StateAccount] associated
25-
// with the address, or a zero-value `SA` if not found. The
26-
// [types.ExtraPayloads] MUST be sourced from [types.RegisterExtras].
27-
func GetExtra[SA any](s *StateDB, p types.ExtraPayloads[SA], addr common.Address) SA {
26+
// with the address, or a zero-value `SA` if not found. The [pseudo.Accessor]
27+
// MUST be sourced from [types.RegisterExtras].
28+
func GetExtra[SA any](s *StateDB, a pseudo.Accessor[types.StateOrSlimAccount, SA], addr common.Address) SA {
2829
stateObject := s.getStateObject(addr)
2930
if stateObject != nil {
30-
return p.StateAccount.Get(&stateObject.data)
31+
return a.Get(&stateObject.data)
3132
}
3233
var zero SA
3334
return zero
3435
}
3536

3637
// SetExtra sets the extra payload for the address. See [GetExtra] for details.
37-
func SetExtra[SA any](s *StateDB, p types.ExtraPayloads[SA], addr common.Address, extra SA) {
38+
func SetExtra[SA any](s *StateDB, a pseudo.Accessor[types.StateOrSlimAccount, SA], addr common.Address, extra SA) {
3839
stateObject := s.getOrNewStateObject(addr)
3940
if stateObject != nil {
40-
setExtraOnObject(stateObject, p, addr, extra)
41+
setExtraOnObject(stateObject, a, addr, extra)
4142
}
4243
}
4344

44-
func setExtraOnObject[SA any](s *stateObject, p types.ExtraPayloads[SA], addr common.Address, extra SA) {
45+
func setExtraOnObject[SA any](s *stateObject, a pseudo.Accessor[types.StateOrSlimAccount, SA], addr common.Address, extra SA) {
4546
s.db.journal.append(extraChange[SA]{
46-
payloads: p,
47+
accessor: a,
4748
account: &addr,
48-
prev: p.StateAccount.Get(&s.data),
49+
prev: a.Get(&s.data),
4950
})
50-
p.StateAccount.Set(&s.data, extra)
51+
a.Set(&s.data, extra)
5152
}
5253

5354
// extraChange is a [journalEntry] for [SetExtra] / [setExtraOnObject].
5455
type extraChange[SA any] struct {
55-
payloads types.ExtraPayloads[SA]
56+
accessor pseudo.Accessor[types.StateOrSlimAccount, SA]
5657
account *common.Address
5758
prev SA
5859
}
5960

6061
func (e extraChange[SA]) dirtied() *common.Address { return e.account }
6162

6263
func (e extraChange[SA]) revert(s *StateDB) {
63-
e.payloads.StateAccount.Set(&s.getStateObject(*e.account).data, e.prev)
64+
e.accessor.Set(&s.getStateObject(*e.account).data, e.prev)
6465
}

core/state/state.libevm_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func TestGetSetExtra(t *testing.T) {
4545
t.Cleanup(types.TestOnlyClearRegisteredExtras)
4646
// Just as its Data field is a pointer, the registered type is a pointer to
4747
// test deep copying.
48-
payloads := types.RegisterExtras[*accountExtra]()
48+
payloads := types.RegisterExtras[types.NOOPHeaderHooks, *types.NOOPHeaderHooks, *accountExtra]().StateAccount
4949

5050
rng := ethtest.NewPseudoRand(42)
5151
addr := rng.Address()
@@ -87,7 +87,7 @@ func TestGetSetExtra(t *testing.T) {
8787
Root: types.EmptyRootHash,
8888
CodeHash: types.EmptyCodeHash[:],
8989
}
90-
payloads.StateAccount.Set(want, extra)
90+
payloads.Set(want, extra)
9191

9292
if diff := cmp.Diff(want, got); diff != "" {
9393
t.Errorf("types.FullAccount(%T.Account()) diff (-want +got):\n%s", iter, diff)

core/state/state_object.libevm_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,21 @@ func TestStateObjectEmpty(t *testing.T) {
4646
{
4747
name: "explicit false bool",
4848
registerAndSet: func(acc *types.StateAccount) {
49-
types.RegisterExtras[bool]().StateAccount.Set(acc, false)
49+
types.RegisterExtras[types.NOOPHeaderHooks, *types.NOOPHeaderHooks, bool]().StateAccount.Set(acc, false)
5050
},
5151
wantEmpty: true,
5252
},
5353
{
5454
name: "implicit false bool",
5555
registerAndSet: func(*types.StateAccount) {
56-
types.RegisterExtras[bool]()
56+
types.RegisterExtras[types.NOOPHeaderHooks, *types.NOOPHeaderHooks, bool]()
5757
},
5858
wantEmpty: true,
5959
},
6060
{
6161
name: "true bool",
6262
registerAndSet: func(acc *types.StateAccount) {
63-
types.RegisterExtras[bool]().StateAccount.Set(acc, true)
63+
types.RegisterExtras[types.NOOPHeaderHooks, *types.NOOPHeaderHooks, bool]().StateAccount.Set(acc, true)
6464
},
6565
wantEmpty: false,
6666
},

core/types/block.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828

2929
"github.com/ava-labs/libevm/common"
3030
"github.com/ava-labs/libevm/common/hexutil"
31+
"github.com/ava-labs/libevm/libevm/pseudo"
3132
"github.com/ava-labs/libevm/rlp"
3233
)
3334

@@ -93,6 +94,8 @@ type Header struct {
9394

9495
// ParentBeaconRoot was added by EIP-4788 and is ignored in legacy headers.
9596
ParentBeaconRoot *common.Hash `json:"parentBeaconBlockRoot" rlp:"optional"`
97+
98+
extra *pseudo.Type // See RegisterExtras()
9699
}
97100

98101
// field type overrides for gencodec

core/types/block.libevm.go

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,74 @@
1717
package types
1818

1919
import (
20+
"fmt"
2021
"io"
2122

23+
"github.com/ava-labs/libevm/libevm/pseudo"
2224
"github.com/ava-labs/libevm/rlp"
2325
)
2426

27+
// HeaderHooks are required for all types registered with [RegisterExtras] for
28+
// [Header] payloads.
29+
type HeaderHooks interface {
30+
EncodeRLP(*Header, io.Writer) error
31+
DecodeRLP(*Header, *rlp.Stream) error
32+
}
33+
34+
var _ interface {
35+
rlp.Encoder
36+
rlp.Decoder
37+
} = (*Header)(nil)
38+
39+
// EncodeRLP implements the [rlp.Encoder] interface.
2540
func (h *Header) EncodeRLP(w io.Writer) error {
41+
if r := &registeredExtras; r.Registered() {
42+
return r.Get().hooks.hooksFromHeader(h).EncodeRLP(h, w)
43+
}
2644
return h.encodeRLP(w)
2745
}
2846

29-
var _ rlp.Encoder = (*Header)(nil)
47+
// decodeHeaderRLPDirectly bypasses the [Header.DecodeRLP] method to avoid
48+
// infinite recursion.
49+
func decodeHeaderRLPDirectly(h *Header, s *rlp.Stream) error {
50+
type withoutMethods Header
51+
return s.Decode((*withoutMethods)(h))
52+
}
53+
54+
// DecodeRLP implements the [rlp.Decoder] interface.
55+
func (h *Header) DecodeRLP(s *rlp.Stream) error {
56+
if r := &registeredExtras; r.Registered() {
57+
return r.Get().hooks.hooksFromHeader(h).DecodeRLP(h, s)
58+
}
59+
return decodeHeaderRLPDirectly(h, s)
60+
}
61+
62+
func (e ExtraPayloads[HPtr, SA]) hooksFromHeader(h *Header) HeaderHooks {
63+
return e.Header.Get(h)
64+
}
65+
66+
func (h *Header) extraPayload() *pseudo.Type {
67+
r := &registeredExtras
68+
if !r.Registered() {
69+
// See params.ChainConfig.extraPayload() for panic rationale.
70+
panic(fmt.Sprintf("%T.extraPayload() called before RegisterExtras()", r))
71+
}
72+
if h.extra == nil {
73+
h.extra = r.Get().newHeader()
74+
}
75+
return h.extra
76+
}
77+
78+
// NOOPHeaderHooks implements [HeaderHooks] such that they are equivalent to
79+
// no type having been registered.
80+
type NOOPHeaderHooks struct{}
81+
82+
var _ HeaderHooks = (*NOOPHeaderHooks)(nil)
83+
84+
func (*NOOPHeaderHooks) EncodeRLP(h *Header, w io.Writer) error {
85+
return h.encodeRLP(w)
86+
}
87+
88+
func (*NOOPHeaderHooks) DecodeRLP(h *Header, s *rlp.Stream) error {
89+
return decodeHeaderRLPDirectly(h, s)
90+
}

core/types/block.libevm_test.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Copyright 2024 the libevm authors.
2+
//
3+
// The libevm additions to go-ethereum are free software: you can redistribute
4+
// them and/or modify them under the terms of the GNU Lesser General Public License
5+
// as published by the Free Software Foundation, either version 3 of the License,
6+
// or (at your option) any later version.
7+
//
8+
// The libevm additions are distributed in the hope that they will be useful,
9+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
10+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
11+
// General Public License for more details.
12+
//
13+
// You should have received a copy of the GNU Lesser General Public License
14+
// along with the go-ethereum library. If not, see
15+
// <http://www.gnu.org/licenses/>.
16+
17+
package types_test
18+
19+
import (
20+
"errors"
21+
"io"
22+
"testing"
23+
24+
"github.com/stretchr/testify/assert"
25+
"github.com/stretchr/testify/require"
26+
27+
. "github.com/ava-labs/libevm/core/types"
28+
"github.com/ava-labs/libevm/crypto"
29+
"github.com/ava-labs/libevm/libevm/ethtest"
30+
"github.com/ava-labs/libevm/rlp"
31+
)
32+
33+
type stubHeaderHooks struct {
34+
rlpSuffix []byte
35+
gotRawRLPToDecode []byte
36+
setHeaderToOnDecode Header
37+
38+
errEncode, errDecode error
39+
}
40+
41+
func fakeHeaderRLP(h *Header, suffix []byte) []byte {
42+
return append(crypto.Keccak256(h.ParentHash[:]), suffix...)
43+
}
44+
45+
func (hh *stubHeaderHooks) EncodeRLP(h *Header, w io.Writer) error {
46+
if _, err := w.Write(fakeHeaderRLP(h, hh.rlpSuffix)); err != nil {
47+
return err
48+
}
49+
return hh.errEncode
50+
}
51+
52+
func (hh *stubHeaderHooks) DecodeRLP(h *Header, s *rlp.Stream) error {
53+
r, err := s.Raw()
54+
if err != nil {
55+
return err
56+
}
57+
hh.gotRawRLPToDecode = r
58+
*h = hh.setHeaderToOnDecode
59+
return hh.errDecode
60+
}
61+
62+
func TestHeaderHooks(t *testing.T) {
63+
TestOnlyClearRegisteredExtras()
64+
defer TestOnlyClearRegisteredExtras()
65+
66+
extras := RegisterExtras[stubHeaderHooks, *stubHeaderHooks, struct{}]()
67+
rng := ethtest.NewPseudoRand(13579)
68+
69+
t.Run("EncodeRLP", func(t *testing.T) {
70+
suffix := rng.Bytes(8)
71+
72+
hdr := &Header{
73+
ParentHash: rng.Hash(),
74+
}
75+
extras.Header.Get(hdr).rlpSuffix = append([]byte{}, suffix...)
76+
77+
got, err := rlp.EncodeToBytes(hdr)
78+
require.NoError(t, err, "rlp.EncodeToBytes(%T)", hdr)
79+
assert.Equal(t, fakeHeaderRLP(hdr, suffix), got)
80+
})
81+
82+
t.Run("DecodeRLP", func(t *testing.T) {
83+
input, err := rlp.EncodeToBytes(rng.Bytes(8))
84+
require.NoError(t, err)
85+
86+
hdr := new(Header)
87+
stub := &stubHeaderHooks{
88+
setHeaderToOnDecode: Header{
89+
Extra: []byte("arr4n was here"),
90+
},
91+
}
92+
extras.Header.Set(hdr, stub)
93+
require.NoErrorf(t, rlp.DecodeBytes(input, hdr), "rlp.DecodeBytes(%#x)", input)
94+
95+
assert.Equal(t, input, stub.gotRawRLPToDecode, "raw RLP received by hooks")
96+
assert.Equalf(t, &stub.setHeaderToOnDecode, hdr, "%T after RLP decoding with hook")
97+
})
98+
99+
t.Run("error propagation", func(t *testing.T) {
100+
errEncode := errors.New("uh oh")
101+
errDecode := errors.New("something bad happened")
102+
103+
hdr := new(Header)
104+
extras.Header.Set(hdr, &stubHeaderHooks{
105+
errEncode: errEncode,
106+
errDecode: errDecode,
107+
})
108+
109+
assert.Equal(t, errEncode, rlp.Encode(io.Discard, hdr), "via rlp.Encode()")
110+
assert.Equal(t, errDecode, rlp.DecodeBytes([]byte{0}, hdr), "via rlp.DecodeBytes()")
111+
})
112+
}

0 commit comments

Comments
 (0)