diff --git a/core/types/rlp_payload.libevm.go b/core/types/rlp_payload.libevm.go index 3d3346c028e..f6d13adfa52 100644 --- a/core/types/rlp_payload.libevm.go +++ b/core/types/rlp_payload.libevm.go @@ -21,6 +21,7 @@ import ( "io" "github.com/ava-labs/libevm/libevm/pseudo" + "github.com/ava-labs/libevm/libevm/register" "github.com/ava-labs/libevm/libevm/testonly" "github.com/ava-labs/libevm/rlp" ) @@ -37,18 +38,15 @@ import ( // The payload can be accessed via the [ExtraPayloads.FromPayloadCarrier] method // of the accessor returned by RegisterExtras. func RegisterExtras[SA any]() ExtraPayloads[SA] { - if registeredExtras != nil { - panic("re-registration of Extras") - } var extra ExtraPayloads[SA] - registeredExtras = &extraConstructors{ + registeredExtras.MustRegister(&extraConstructors{ stateAccountType: func() string { var x SA return fmt.Sprintf("%T", x) }(), newStateAccount: pseudo.NewConstructor[SA]().Zero, cloneStateAccount: extra.cloneStateAccount, - } + }) return extra } @@ -59,12 +57,10 @@ func RegisterExtras[SA any]() ExtraPayloads[SA] { // defer-called afterwards, either directly or via testing.TB.Cleanup(). This is // a workaround for the single-call limitation on [RegisterExtras]. func TestOnlyClearRegisteredExtras() { - testonly.OrPanic(func() { - registeredExtras = nil - }) + registeredExtras.TestOnlyClear() } -var registeredExtras *extraConstructors +var registeredExtras register.AtMostOnce[*extraConstructors] type extraConstructors struct { stateAccountType string @@ -74,10 +70,10 @@ type extraConstructors struct { func (e *StateAccountExtra) clone() *StateAccountExtra { switch r := registeredExtras; { - case r == nil, e == nil: + case !r.Registered(), e == nil: return nil default: - return r.cloneStateAccount(e) + return r.Get().cloneStateAccount(e) } } @@ -146,7 +142,7 @@ func (a *SlimAccount) extra() *StateAccountExtra { func getOrSetNewStateAccountExtra(curr **StateAccountExtra) *StateAccountExtra { if *curr == nil { *curr = &StateAccountExtra{ - t: registeredExtras.newStateAccount(), + t: registeredExtras.Get().newStateAccount(), } } return *curr @@ -154,7 +150,7 @@ func getOrSetNewStateAccountExtra(curr **StateAccountExtra) *StateAccountExtra { func (e *StateAccountExtra) payload() *pseudo.Type { if e.t == nil { - e.t = registeredExtras.newStateAccount() + e.t = registeredExtras.Get().newStateAccount() } return e.t } @@ -196,13 +192,13 @@ var _ interface { // EncodeRLP implements the [rlp.Encoder] interface. func (e *StateAccountExtra) EncodeRLP(w io.Writer) error { switch r := registeredExtras; { - case r == nil: + case !r.Registered(): return nil case e == nil: e = &StateAccountExtra{} fallthrough case e.t == nil: - e.t = r.newStateAccount() + e.t = r.Get().newStateAccount() } return e.t.EncodeRLP(w) } @@ -210,10 +206,10 @@ func (e *StateAccountExtra) EncodeRLP(w io.Writer) error { // DecodeRLP implements the [rlp.Decoder] interface. func (e *StateAccountExtra) DecodeRLP(s *rlp.Stream) error { switch r := registeredExtras; { - case r == nil: + case !r.Registered(): return nil case e.t == nil: - e.t = r.newStateAccount() + e.t = r.Get().newStateAccount() fallthrough default: return s.Decode(e.t) @@ -224,10 +220,10 @@ func (e *StateAccountExtra) DecodeRLP(s *rlp.Stream) error { func (e *StateAccountExtra) Format(s fmt.State, verb rune) { var out string switch r := registeredExtras; { - case r == nil: + case !r.Registered(): out = "" case e == nil, e.t == nil: - out = fmt.Sprintf("[*StateAccountExtra[%s]]", r.stateAccountType) + out = fmt.Sprintf("[*StateAccountExtra[%s]]", r.Get().stateAccountType) default: e.t.Format(s, verb) return diff --git a/libevm/register/register.go b/libevm/register/register.go new file mode 100644 index 00000000000..0cf3333d413 --- /dev/null +++ b/libevm/register/register.go @@ -0,0 +1,68 @@ +// Copyright 2024 the libevm authors. +// +// The libevm additions to go-ethereum are free software: you can redistribute +// them and/or modify them under the terms of the GNU Lesser General Public License +// as published by the Free Software Foundation, either version 3 of the License, +// or (at your option) any later version. +// +// The libevm additions are distributed in the hope that they will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser +// General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see +// . + +// Package register provides functionality for optional registration of types. +package register + +import ( + "errors" + + "github.com/ava-labs/libevm/libevm/testonly" +) + +// An AtMostOnce allows zero or one registration of a T. +type AtMostOnce[T any] struct { + v *T +} + +// ErrReRegistration is returned on all but the first of calls to +// [AtMostOnce.Register]. +var ErrReRegistration = errors.New("re-registration") + +// Register registers `v` or returns [ErrReRegistration] if already called. +func (o *AtMostOnce[T]) Register(v T) error { + if o.Registered() { + return ErrReRegistration + } + o.v = &v + return nil +} + +// MustRegister is equivalent to [AtMostOnce.Register], panicking on error. +func (o *AtMostOnce[T]) MustRegister(v T) { + if err := o.Register(v); err != nil { + panic(err) + } +} + +// Registered reports whether [AtMostOnce.Register] has been called. +func (o *AtMostOnce[T]) Registered() bool { + return o.v != nil +} + +// Get returns the registered value. It MUST NOT be called before +// [AtMostOnce.Register]. +func (o *AtMostOnce[T]) Get() T { + return *o.v +} + +// TestOnlyClear clears any previously registered value, returning `o` to its +// default state. It panics if called from a non-testing call stack. +func (o *AtMostOnce[T]) TestOnlyClear() { + testonly.OrPanic(func() { + o.v = nil + }) +} diff --git a/params/config.libevm.go b/params/config.libevm.go index 479dd79b272..a16f8ae4a2d 100644 --- a/params/config.libevm.go +++ b/params/config.libevm.go @@ -22,7 +22,7 @@ import ( "reflect" "github.com/ava-labs/libevm/libevm/pseudo" - "github.com/ava-labs/libevm/libevm/testonly" + "github.com/ava-labs/libevm/libevm/register" ) // Extras are arbitrary payloads to be added as extra fields in [ChainConfig] @@ -68,20 +68,17 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct { // alter Ethereum behaviour; if this isn't desired then they can embed // [NOOPHooks] to satisfy either interface. func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPayloads[C, R] { - if registeredExtras != nil { - panic("re-registration of Extras") - } mustBeStructOrPointerToOne[C]() mustBeStructOrPointerToOne[R]() payloads := e.payloads() - registeredExtras = &extraConstructors{ + registeredExtras.MustRegister(&extraConstructors{ newChainConfig: pseudo.NewConstructor[C]().Zero, newRules: pseudo.NewConstructor[R]().Zero, reuseJSONRoot: e.ReuseJSONRoot, newForRules: e.newForRules, payloads: payloads, - } + }) return payloads } @@ -92,14 +89,12 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo // defer-called afterwards, either directly or via testing.TB.Cleanup(). This is // a workaround for the single-call limitation on [RegisterExtras]. func TestOnlyClearRegisteredExtras() { - testonly.OrPanic(func() { - registeredExtras = nil - }) + registeredExtras.TestOnlyClear() } // registeredExtras holds non-generic constructors for the [Extras] types // registered via [RegisterExtras]. -var registeredExtras *extraConstructors +var registeredExtras register.AtMostOnce[*extraConstructors] type extraConstructors struct { newChainConfig, newRules func() *pseudo.Type @@ -115,7 +110,7 @@ type extraConstructors struct { func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type { if e.NewRules == nil { - return registeredExtras.newRules() + return registeredExtras.Get().newRules() } rExtra := e.NewRules(c, r, e.payloads().FromChainConfig(c), blockNum, isMerge, timestamp) return pseudo.From(rExtra).Type @@ -209,8 +204,8 @@ func (e ExtraPayloads[C, R]) hooksFromRules(r *Rules) RulesHooks { // abstract the libevm-specific behaviour outside of original geth code. func (c *ChainConfig) addRulesExtra(r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) { r.extra = nil - if registeredExtras != nil { - r.extra = registeredExtras.newForRules(c, r, blockNum, isMerge, timestamp) + if registeredExtras.Registered() { + r.extra = registeredExtras.Get().newForRules(c, r, blockNum, isMerge, timestamp) } } @@ -218,7 +213,7 @@ func (c *ChainConfig) addRulesExtra(r *Rules, blockNum *big.Int, isMerge bool, t // already been called. If the payload hasn't been populated (typically via // unmarshalling of JSON), a nil value is constructed and returned. func (c *ChainConfig) extraPayload() *pseudo.Type { - if registeredExtras == nil { + if !registeredExtras.Registered() { // This will only happen if someone constructs an [ExtraPayloads] // directly, without a call to [RegisterExtras]. // @@ -226,19 +221,19 @@ func (c *ChainConfig) extraPayload() *pseudo.Type { panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c)) } if c.extra == nil { - c.extra = registeredExtras.newChainConfig() + c.extra = registeredExtras.Get().newChainConfig() } return c.extra } // extraPayload is equivalent to [ChainConfig.extraPayload]. func (r *Rules) extraPayload() *pseudo.Type { - if registeredExtras == nil { + if !registeredExtras.Registered() { // See ChainConfig.extraPayload() equivalent. panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r)) } if r.extra == nil { - r.extra = registeredExtras.newRules() + r.extra = registeredExtras.Get().newRules() } return r.extra } diff --git a/params/config.libevm_test.go b/params/config.libevm_test.go index 7a665a45686..24ae4ab5e4e 100644 --- a/params/config.libevm_test.go +++ b/params/config.libevm_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ava-labs/libevm/libevm/pseudo" + "github.com/ava-labs/libevm/libevm/register" ) type rawJSON struct { @@ -255,18 +256,21 @@ func TestExtrasPanic(t *testing.T) { t, func() { RegisterExtras(Extras[struct{ ChainConfigHooks }, struct{ RulesHooks }]{}) }, - "re-registration", + register.ErrReRegistration.Error(), ) } func assertPanics(t *testing.T, fn func(), wantContains string) { t.Helper() defer func() { + t.Helper() switch r := recover().(type) { case nil: - t.Error("function did not panic as expected") + t.Error("function did not panic when panic expected") case string: assert.Contains(t, r, wantContains) + case error: + assert.Contains(t, r.Error(), wantContains) default: t.Fatalf("BAD TEST SETUP: recover() got unsupported type %T", r) } diff --git a/params/hooks.libevm.go b/params/hooks.libevm.go index c67c7c5d3d9..9faf403a65d 100644 --- a/params/hooks.libevm.go +++ b/params/hooks.libevm.go @@ -69,8 +69,8 @@ type RulesAllowlistHooks interface { // Hooks returns the hooks registered with [RegisterExtras], or [NOOPHooks] if // none were registered. func (c *ChainConfig) Hooks() ChainConfigHooks { - if e := registeredExtras; e != nil { - return e.payloads.hooksFromChainConfig(c) + if e := registeredExtras; e.Registered() { + return e.Get().payloads.hooksFromChainConfig(c) } return NOOPHooks{} } @@ -78,8 +78,8 @@ func (c *ChainConfig) Hooks() ChainConfigHooks { // Hooks returns the hooks registered with [RegisterExtras], or [NOOPHooks] if // none were registered. func (r *Rules) Hooks() RulesHooks { - if e := registeredExtras; e != nil { - return e.payloads.hooksFromRules(r) + if e := registeredExtras; e.Registered() { + return e.Get().payloads.hooksFromRules(r) } return NOOPHooks{} } diff --git a/params/json.libevm.go b/params/json.libevm.go index 9f669008926..c9acefe9497 100644 --- a/params/json.libevm.go +++ b/params/json.libevm.go @@ -42,11 +42,11 @@ type chainConfigWithExportedExtra struct { // UnmarshalJSON implements the [json.Unmarshaler] interface. func (c *ChainConfig) UnmarshalJSON(data []byte) error { switch reg := registeredExtras; { - case reg != nil && !reg.reuseJSONRoot: + case reg.Registered() && !reg.Get().reuseJSONRoot: return c.unmarshalJSONWithExtra(data) - case reg != nil && reg.reuseJSONRoot: // although the latter is redundant, it's clearer - c.extra = reg.newChainConfig() + case reg.Registered() && reg.Get().reuseJSONRoot: // although the latter is redundant, it's clearer + c.extra = reg.Get().newChainConfig() if err := json.Unmarshal(data, c.extra); err != nil { c.extra = nil return err @@ -63,7 +63,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error { func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error { cc := &chainConfigWithExportedExtra{ chainConfigWithoutMethods: (*chainConfigWithoutMethods)(c), - Extra: registeredExtras.newChainConfig(), + Extra: registeredExtras.Get().newChainConfig(), } if err := json.Unmarshal(data, cc); err != nil { return err @@ -75,10 +75,10 @@ func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error { // MarshalJSON implements the [json.Marshaler] interface. func (c *ChainConfig) MarshalJSON() ([]byte, error) { switch reg := registeredExtras; { - case reg == nil: + case !reg.Registered(): return json.Marshal((*chainConfigWithoutMethods)(c)) - case !reg.reuseJSONRoot: + case !reg.Get().reuseJSONRoot: return c.marshalJSONWithExtra() default: // reg.reuseJSONRoot == true