diff --git a/consensus/dummy/consensus.go b/consensus/dummy/consensus.go index 0c570400c5..1512269351 100644 --- a/consensus/dummy/consensus.go +++ b/consensus/dummy/consensus.go @@ -41,10 +41,39 @@ type Mode struct { ModeSkipCoinbase bool } -type DummyEngine struct { - consensusMode Mode - desiredDelayExcess *acp226.DelayExcess -} +type ( + OnFinalizeAndAssembleCallbackType = func( + header *types.Header, + parent *types.Header, + state *state.StateDB, + txs []*types.Transaction, + ) ( + extraData []byte, + blockFeeContribution *big.Int, + extDataGasUsed *big.Int, + err error, + ) + + OnExtraStateChangeType = func( + block *types.Block, + parent *types.Header, + statedb *state.StateDB, + ) ( + blockFeeContribution *big.Int, + extDataGasUsed *big.Int, + err error, + ) + + ConsensusCallbacks struct { + OnFinalizeAndAssemble OnFinalizeAndAssembleCallbackType + OnExtraStateChange OnExtraStateChangeType + } + + DummyEngine struct { + consensusMode Mode + desiredDelayExcess *acp226.DelayExcess + } +) func NewDummyEngine( mode Mode, diff --git a/plugin/evm/block_test.go b/plugin/evm/block_test.go index 0c8bebcf25..6b0a0e2465 100644 --- a/plugin/evm/block_test.go +++ b/plugin/evm/block_test.go @@ -16,6 +16,7 @@ import ( "github.com/ava-labs/subnet-evm/params" "github.com/ava-labs/subnet-evm/params/extras" + "github.com/ava-labs/subnet-evm/plugin/evm/extension" "github.com/ava-labs/subnet-evm/precompile/precompileconfig" ) @@ -26,8 +27,9 @@ func TestHandlePrecompileAccept(t *testing.T) { db := rawdb.NewMemoryDatabase() vm := &VM{ - chaindb: db, - chainConfig: params.TestChainConfig, + chaindb: db, + chainConfig: params.TestChainConfig, + extensionConfig: &extension.Config{}, } precompileAddr := common.Address{0x05} diff --git a/plugin/evm/config/config.go b/plugin/evm/config/config.go index f750abbd6d..19fe6333b3 100644 --- a/plugin/evm/config/config.go +++ b/plugin/evm/config/config.go @@ -9,6 +9,7 @@ import ( "fmt" "time" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/libevm/common" "github.com/ava-labs/libevm/common/hexutil" "github.com/spf13/cast" @@ -129,7 +130,7 @@ type Config struct { MaxOutboundActiveRequests int64 `json:"max-outbound-active-requests"` // Sync settings - StateSyncEnabled bool `json:"state-sync-enabled"` + StateSyncEnabled *bool `json:"state-sync-enabled"` // Pointer distinguishes false (no state sync) and not set (state sync only at genesis). StateSyncSkipResume bool `json:"state-sync-skip-resume"` // Forces state sync to use the highest available summary block StateSyncServerTrieCache int `json:"state-sync-server-trie-cache"` StateSyncIDs string `json:"state-sync-ids"` @@ -236,7 +237,18 @@ func (d Duration) MarshalJSON() ([]byte, error) { } // validate returns an error if this is an invalid config. 
-func (c *Config) validate(_ uint32) error {
+func (c *Config) validate(networkID uint32) error {
+	// Ensure that a non-standard commit interval is not allowed for production networks
+	if constants.ProductionNetworkIDs.Contains(networkID) {
+		defaultConfig := NewDefaultConfig()
+		if c.CommitInterval != defaultConfig.CommitInterval {
+			return fmt.Errorf("cannot start non-local network with commit interval %d different than %d", c.CommitInterval, defaultConfig.CommitInterval)
+		}
+		if c.StateSyncCommitInterval != defaultConfig.StateSyncCommitInterval {
+			return fmt.Errorf("cannot start non-local network with syncable interval %d different than %d", c.StateSyncCommitInterval, defaultConfig.StateSyncCommitInterval)
+		}
+	}
+
 	if c.PopulateMissingTries != nil && (c.OfflinePruning || c.Pruning) {
 		return fmt.Errorf("cannot enable populate missing tries while offline pruning (enabled: %t)/pruning (enabled: %t) are enabled", c.OfflinePruning, c.Pruning)
 	}
diff --git a/plugin/evm/config/config_test.go b/plugin/evm/config/config_test.go
index f27d92ff81..75f852dd5c 100644
--- a/plugin/evm/config/config_test.go
+++ b/plugin/evm/config/config_test.go
@@ -15,6 +15,12 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
+// newTrue returns a pointer to a bool that is true
+func newTrue() *bool {
+	b := true
+	return &b
+}
+
 func TestUnmarshalConfig(t *testing.T) {
 	tests := []struct {
 		name string
@@ -64,7 +70,7 @@ func TestUnmarshalConfig(t *testing.T) {
 		{
 			"state sync enabled",
 			[]byte(`{"state-sync-enabled":true}`),
-			Config{StateSyncEnabled: true},
+			Config{StateSyncEnabled: newTrue()},
 			false,
 		},
 		{
diff --git a/plugin/evm/extension/config.go b/plugin/evm/extension/config.go
index 8dbdb7437d..83f07500ac 100644
--- a/plugin/evm/extension/config.go
+++ b/plugin/evm/extension/config.go
@@ -4,12 +4,30 @@ package extension
 
 import (
+	"context"
 	"errors"
 
+	"github.com/ava-labs/avalanchego/database"
+	"github.com/ava-labs/avalanchego/database/versiondb"
+	"github.com/ava-labs/avalanchego/ids"
+	"github.com/ava-labs/avalanchego/network/p2p"
+	"github.com/ava-labs/avalanchego/snow/consensus/snowman"
+	"github.com/ava-labs/avalanchego/snow/engine/snowman/block"
 	"github.com/ava-labs/avalanchego/utils/timer/mockable"
+	"github.com/ava-labs/libevm/common"
+	"github.com/ava-labs/libevm/core/types"
+	"github.com/prometheus/client_golang/prometheus"
 
+	"github.com/ava-labs/subnet-evm/consensus/dummy"
+	"github.com/ava-labs/subnet-evm/core"
+	"github.com/ava-labs/subnet-evm/params"
+	"github.com/ava-labs/subnet-evm/params/extras"
+	"github.com/ava-labs/subnet-evm/plugin/evm/config"
 	"github.com/ava-labs/subnet-evm/plugin/evm/message"
 	"github.com/ava-labs/subnet-evm/plugin/evm/sync"
+	"github.com/ava-labs/subnet-evm/sync/handlers"
+
+	avalanchecommon "github.com/ava-labs/avalanchego/snow/engine/common"
 )
 
 var (
@@ -19,15 +37,121 @@ var (
 	errNilClock = errors.New("nil clock")
 )
 
+type ExtensibleVM interface {
+	// SetExtensionConfig sets the configuration for the VM extension
+	// Should be called before any other method and only once
+	SetExtensionConfig(config *Config) error
+	// NewClient returns a client to send messages with for the given protocol
+	NewClient(protocol uint64) *p2p.Client
+	// AddHandler registers a server handler for an application protocol
+	AddHandler(protocol uint64, handler p2p.Handler) error
+	// GetExtendedBlock returns the ExtendedBlock for the given ID or an error if the block is not found
+	GetExtendedBlock(context.Context, ids.ID) (ExtendedBlock, error)
+	// LastAcceptedExtendedBlock returns the last accepted extended block
+	LastAcceptedExtendedBlock() ExtendedBlock
+	// ChainConfig returns the chain config for the VM
+	ChainConfig() *params.ChainConfig
+	// P2PValidators returns the validators for the network
+	P2PValidators() *p2p.Validators
+	// Blockchain returns the blockchain client
+	Blockchain() *core.BlockChain
+	// Config returns the configuration for the VM
+	Config() config.Config
+	// MetricRegistry returns the metric registry for the VM
+	MetricRegistry() *prometheus.Registry
+	// ReadLastAccepted returns the last accepted block hash and height
+	ReadLastAccepted() (common.Hash, uint64, error)
+	// VersionDB returns the versioned database for the VM
+	VersionDB() *versiondb.Database
+}
+
+// InnerVM is the interface that must be implemented by the VM
+// that's being wrapped by the extension
+type InnerVM interface {
+	ExtensibleVM
+	avalanchecommon.VM
+	block.ChainVM
+	block.BuildBlockWithContextChainVM
+	block.StateSyncableVM
+}
+
+// ExtendedBlock is a block that can be used by the extension
+type ExtendedBlock interface {
+	snowman.Block
+	GetEthBlock() *types.Block
+	GetBlockExtension() BlockExtension
+}
+
+type BlockExtender interface {
+	// NewBlockExtension is called when a new block is created
+	NewBlockExtension(b ExtendedBlock) (BlockExtension, error)
+}
+
+// BlockExtension allows the VM extension to handle block processing events.
+type BlockExtension interface {
+	// SyntacticVerify verifies the block syntactically
+	// it can be implemented to extend inner block verification
+	SyntacticVerify(rules extras.Rules) error
+	// SemanticVerify verifies the block semantically
+	// it can be implemented to extend inner block verification
+	SemanticVerify() error
+	// CleanupVerified is called when a block has passed SemanticVerify and SyntacticVerify,
+	// and should be cleaned up, either because of an error or because verification ran in
+	// non-write mode. This does not return an error because the block has already been verified.
+	CleanupVerified()
+	// Accept is called when a block is accepted by the block manager. Accept takes a
+	// database.Batch that contains the changes that were made to the database as a result
+	// of accepting the block. The changes in the batch should be flushed to the database in this method.
+	Accept(acceptedBatch database.Batch) error
+	// Reject is called when a block is rejected by the block manager
+	Reject() error
+}
+
+// BuilderMempool is a mempool that's used in the block builder
+type BuilderMempool interface {
+	// PendingLen returns the number of pending transactions
+	// that are waiting to be included in a block
+	PendingLen() int
+	// SubscribePendingTxs returns a channel that's signaled when there are pending transactions
+	SubscribePendingTxs() <-chan struct{}
+}
+
+// LeafRequestConfig is the configuration to handle leaf requests
+// in the network and syncer
+type LeafRequestConfig struct {
+	// LeafType is the type of the leaf node
+	LeafType message.NodeType
+	// MetricName is the name of the metric to use for the leaf request
+	MetricName string
+	// Handler is the handler to use for the leaf request
+	Handler handlers.LeafRequestHandler
+}
+
 // Config is the configuration for the VM extension
 type Config struct {
+	// ConsensusCallbacks are the consensus callbacks to use
+	// for the VM in the consensus engine.
+	// Callback functions can be nil.
+	ConsensusCallbacks dummy.ConsensusCallbacks
 	// SyncSummaryProvider is the sync summary provider to use
 	// for the VM to be used in syncer.
// It's required and should be non-nil SyncSummaryProvider sync.SummaryProvider + // SyncExtender can extend the syncer to handle custom sync logic. + // It's optional and can be nil + SyncExtender sync.Extender // SyncableParser is to parse summary messages from the network. // It's required and should be non-nil SyncableParser message.SyncableParser + // BlockExtender allows the VM extension to create an extension to handle block processing events. + // It's optional and can be nil + BlockExtender BlockExtender + // ExtraSyncLeafHandlerConfig is the extra configuration to handle leaf requests + // in the network and syncer. It's optional and can be nil + ExtraSyncLeafHandlerConfig *LeafRequestConfig + // ExtraMempool is the mempool to be used in the block builder. + // It's optional and can be nil + ExtraMempool BuilderMempool // Clock is the clock to use for time related operations. // It's optional and can be nil Clock *mockable.Clock diff --git a/plugin/evm/message/block_sync_summary_test.go b/plugin/evm/message/block_sync_summary_test.go new file mode 100644 index 0000000000..40aa926d7c --- /dev/null +++ b/plugin/evm/message/block_sync_summary_test.go @@ -0,0 +1,44 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package message + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/ava-labs/avalanchego/snow/engine/snowman/block" + "github.com/ava-labs/libevm/common" + "github.com/stretchr/testify/require" +) + +func TestMarshalBlockSyncSummary(t *testing.T) { + blockSyncSummary, err := NewBlockSyncSummary(common.Hash{1}, 2, common.Hash{3}) + require.NoError(t, err) + + require.Equal(t, common.Hash{1}, blockSyncSummary.GetBlockHash()) + require.Equal(t, uint64(2), blockSyncSummary.Height()) + require.Equal(t, common.Hash{3}, blockSyncSummary.GetBlockRoot()) + + expectedBase64Bytes := "AAAAAAAAAAAAAgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" + require.Equal(t, expectedBase64Bytes, base64.StdEncoding.EncodeToString(blockSyncSummary.Bytes())) + + parser := NewBlockSyncSummaryParser() + called := false + acceptImplTest := func(Syncable) (block.StateSyncMode, error) { + called = true + return block.StateSyncSkipped, nil + } + s, err := parser.Parse(blockSyncSummary.Bytes(), acceptImplTest) + require.NoError(t, err) + require.Equal(t, blockSyncSummary.GetBlockHash(), s.GetBlockHash()) + require.Equal(t, blockSyncSummary.Height(), s.Height()) + require.Equal(t, blockSyncSummary.GetBlockRoot(), s.GetBlockRoot()) + require.Equal(t, blockSyncSummary.Bytes(), s.Bytes()) + + mode, err := s.Accept(context.TODO()) + require.NoError(t, err) + require.Equal(t, block.StateSyncSkipped, mode) + require.True(t, called) +} diff --git a/plugin/evm/message/codec.go b/plugin/evm/message/codec.go index 59a2632582..ba71af2f32 100644 --- a/plugin/evm/message/codec.go +++ b/plugin/evm/message/codec.go @@ -22,13 +22,11 @@ func init() { c := linearcodec.NewDefault() // Skip registration to keep registeredTypes unchanged after legacy gossip deprecation - c.SkipRegistrations(1) + // Gossip types and sync summary type removed from codec + c.SkipRegistrations(2) errs := wrappers.Errs{} errs.Add( - // Types for state sync frontier consensus - c.RegisterType(BlockSyncSummary{}), - // state sync types c.RegisterType(BlockRequest{}), c.RegisterType(BlockResponse{}), diff --git a/plugin/evm/message/handler.go b/plugin/evm/message/handler.go index a6b0306cd9..42a5319249 100644 --- 
a/plugin/evm/message/handler.go +++ b/plugin/evm/message/handler.go @@ -15,9 +15,8 @@ var _ RequestHandler = NoopRequestHandler{} // Must have methods in format of handleType(context.Context, ids.NodeID, uint32, request Type) error // so that the Request object of relevant Type can invoke its respective handle method // on this struct. -// Also see GossipHandler for implementation style. type RequestHandler interface { - HandleStateTrieLeafsRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, leafsRequest LeafsRequest) ([]byte, error) + HandleLeafsRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, leafsRequest LeafsRequest) ([]byte, error) HandleBlockRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, request BlockRequest) ([]byte, error) HandleCodeRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, codeRequest CodeRequest) ([]byte, error) } @@ -33,7 +32,7 @@ type ResponseHandler interface { type NoopRequestHandler struct{} -func (NoopRequestHandler) HandleStateTrieLeafsRequest(context.Context, ids.NodeID, uint32, LeafsRequest) ([]byte, error) { +func (NoopRequestHandler) HandleLeafsRequest(context.Context, ids.NodeID, uint32, LeafsRequest) ([]byte, error) { return nil, nil } diff --git a/plugin/evm/message/leafs_request.go b/plugin/evm/message/leafs_request.go index 2e345949e7..e21e94988a 100644 --- a/plugin/evm/message/leafs_request.go +++ b/plugin/evm/message/leafs_request.go @@ -15,25 +15,37 @@ const MaxCodeHashesPerRequest = 5 var _ Request = LeafsRequest{} +// NodeType outlines the trie that a leaf node belongs to +// handlers.LeafsRequestHandler uses this information to determine +// which trie type to fetch the information from +type NodeType uint8 + +const ( + StateTrieNode = NodeType(1) + StateTrieKeyLength = common.HashLength +) + // LeafsRequest is a request to receive trie leaves at specified Root within Start and End byte range // Limit outlines maximum number of leaves to returns starting at Start +// NodeType outlines which trie to read from state/atomic. 
type LeafsRequest struct { - Root common.Hash `serialize:"true"` - Account common.Hash `serialize:"true"` - Start []byte `serialize:"true"` - End []byte `serialize:"true"` - Limit uint16 `serialize:"true"` + Root common.Hash `serialize:"true"` + Account common.Hash `serialize:"true"` + Start []byte `serialize:"true"` + End []byte `serialize:"true"` + Limit uint16 `serialize:"true"` + NodeType NodeType `serialize:"true"` } func (l LeafsRequest) String() string { return fmt.Sprintf( - "LeafsRequest(Root=%s, Account=%s, Start=%s, End %s, Limit=%d)", - l.Root, l.Account, common.Bytes2Hex(l.Start), common.Bytes2Hex(l.End), l.Limit, + "LeafsRequest(Root=%s, Account=%s, Start=%s, End=%s, Limit=%d, NodeType=%d)", + l.Root, l.Account, common.Bytes2Hex(l.Start), common.Bytes2Hex(l.End), l.Limit, l.NodeType, ) } func (l LeafsRequest) Handle(ctx context.Context, nodeID ids.NodeID, requestID uint32, handler RequestHandler) ([]byte, error) { - return handler.HandleStateTrieLeafsRequest(ctx, nodeID, requestID, l) + return handler.HandleLeafsRequest(ctx, nodeID, requestID, l) } // LeafsResponse is a response to a LeafsRequest diff --git a/plugin/evm/message/leafs_request_test.go b/plugin/evm/message/leafs_request_test.go index efe5920ecb..123d69302f 100644 --- a/plugin/evm/message/leafs_request_test.go +++ b/plugin/evm/message/leafs_request_test.go @@ -4,13 +4,10 @@ package message import ( - "bytes" - "context" "encoding/base64" "math/rand" "testing" - "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/libevm/common" "github.com/stretchr/testify/assert" ) @@ -38,7 +35,7 @@ func TestMarshalLeafsRequest(t *testing.T) { Limit: 1024, } - base64LeafsRequest := "AAAAAAAAAAAAAAAAAAAAAABpbSBST09UaW5nIGZvciB5YQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIFL9/AchgmVPFj9fD5piHXKVZsdNEAN8TXu7BAfR4sZJAAAAIIGFWthoHQ2G0ekeABZ5OctmlNLEIqzSCKAHKTlIf2mZBAA=" + base64LeafsRequest := "AAAAAAAAAAAAAAAAAAAAAABpbSBST09UaW5nIGZvciB5YQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIFL9/AchgmVPFj9fD5piHXKVZsdNEAN8TXu7BAfR4sZJAAAAIIGFWthoHQ2G0ekeABZ5OctmlNLEIqzSCKAHKTlIf2mZBAAA" leafsRequestBytes, err := Codec.Marshal(Version, leafsRequest) assert.NoError(t, err) @@ -105,62 +102,3 @@ func TestMarshalLeafsResponse(t *testing.T) { assert.False(t, l.More) // make sure it is not serialized assert.Equal(t, leafsResponse.ProofVals, l.ProofVals) } - -func TestLeafsRequestValidation(t *testing.T) { - mockRequestHandler := &mockHandler{} - - tests := map[string]struct { - request LeafsRequest - assertResponse func(t *testing.T) - }{ - "node type StateTrieNode": { - request: LeafsRequest{ - Root: common.BytesToHash([]byte("some hash goes here")), - Start: bytes.Repeat([]byte{0x00}, common.HashLength), - End: bytes.Repeat([]byte{0xff}, common.HashLength), - Limit: 10, - }, - assertResponse: func(t *testing.T) { - assert.True(t, mockRequestHandler.handleStateTrieCalled) - assert.False(t, mockRequestHandler.handleBlockRequestCalled) - assert.False(t, mockRequestHandler.handleCodeRequestCalled) - }, - }, - } - for name, test := range tests { - t.Run(name, func(t *testing.T) { - _, _ = test.request.Handle(context.Background(), ids.GenerateTestNodeID(), 1, mockRequestHandler) - test.assertResponse(t) - mockRequestHandler.reset() - }) - } -} - -var _ RequestHandler = (*mockHandler)(nil) - -type mockHandler struct { - handleStateTrieCalled, - handleBlockRequestCalled, - handleCodeRequestCalled bool -} - -func (m *mockHandler) HandleStateTrieLeafsRequest(context.Context, ids.NodeID, uint32, LeafsRequest) ([]byte, error) { - 
m.handleStateTrieCalled = true - return nil, nil -} - -func (m *mockHandler) HandleBlockRequest(context.Context, ids.NodeID, uint32, BlockRequest) ([]byte, error) { - m.handleBlockRequestCalled = true - return nil, nil -} - -func (m *mockHandler) HandleCodeRequest(context.Context, ids.NodeID, uint32, CodeRequest) ([]byte, error) { - m.handleCodeRequestCalled = true - return nil, nil -} - -func (m *mockHandler) reset() { - m.handleStateTrieCalled = false - m.handleBlockRequestCalled = false - m.handleCodeRequestCalled = false -} diff --git a/plugin/evm/network_handler.go b/plugin/evm/network_handler.go index 4fe51a98df..74952f8375 100644 --- a/plugin/evm/network_handler.go +++ b/plugin/evm/network_handler.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/libevm/ethdb" - "github.com/ava-labs/libevm/metrics" + "github.com/ava-labs/libevm/log" "github.com/ava-labs/libevm/triedb" "github.com/ava-labs/subnet-evm/plugin/evm/message" @@ -20,29 +20,44 @@ import ( var _ message.RequestHandler = (*networkHandler)(nil) +type LeafHandlers map[message.NodeType]syncHandlers.LeafRequestHandler + type networkHandler struct { - leafRequestHandler *syncHandlers.LeafsRequestHandler + leafRequestHandlers LeafHandlers blockRequestHandler *syncHandlers.BlockRequestHandler codeRequestHandler *syncHandlers.CodeRequestHandler } +type LeafRequestTypeConfig struct { + NodeType message.NodeType + NodeKeyLen int + TrieDB *triedb.Database + UseSnapshots bool + MetricName string +} + // newNetworkHandler constructs the handler for serving network requests. func newNetworkHandler( provider syncHandlers.SyncDataProvider, diskDB ethdb.KeyValueReader, - evmTrieDB *triedb.Database, networkCodec codec.Manager, -) message.RequestHandler { - syncStats := syncStats.NewHandlerStats(metrics.Enabled) + leafRequestHandlers LeafHandlers, + syncStats syncStats.HandlerStats, +) *networkHandler { return &networkHandler{ - leafRequestHandler: syncHandlers.NewLeafsRequestHandler(evmTrieDB, nil, networkCodec, syncStats), + leafRequestHandlers: leafRequestHandlers, blockRequestHandler: syncHandlers.NewBlockRequestHandler(provider, networkCodec, syncStats), codeRequestHandler: syncHandlers.NewCodeRequestHandler(diskDB, networkCodec, syncStats), } } -func (n networkHandler) HandleStateTrieLeafsRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, leafsRequest message.LeafsRequest) ([]byte, error) { - return n.leafRequestHandler.OnLeafsRequest(ctx, nodeID, requestID, leafsRequest) +func (n networkHandler) HandleLeafsRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, leafsRequest message.LeafsRequest) ([]byte, error) { + handler, ok := n.leafRequestHandlers[leafsRequest.NodeType] + if !ok { + log.Debug("node type is not recognised, dropping request", "nodeID", nodeID, "requestID", requestID, "nodeType", leafsRequest.NodeType) + return nil, nil + } + return handler.OnLeafsRequest(ctx, nodeID, requestID, leafsRequest) } func (n networkHandler) HandleBlockRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, blockRequest message.BlockRequest) ([]byte, error) { diff --git a/plugin/evm/vm.go b/plugin/evm/vm.go index f922e41668..9fcc95f7a2 100644 --- a/plugin/evm/vm.go +++ b/plugin/evm/vm.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "math/big" "net/http" "os" @@ -31,7 +32,6 @@ import ( "github.com/ava-labs/avalanchego/utils/profiler" "github.com/ava-labs/avalanchego/utils/timer/mockable" 
"github.com/ava-labs/avalanchego/utils/units" - "github.com/ava-labs/avalanchego/version" "github.com/ava-labs/avalanchego/vms/components/chain" "github.com/ava-labs/avalanchego/vms/evm/acp226" "github.com/ava-labs/firewood-go-ethhash/ffi" @@ -76,6 +76,7 @@ import ( "github.com/ava-labs/subnet-evm/precompile/precompileconfig" "github.com/ava-labs/subnet-evm/rpc" "github.com/ava-labs/subnet-evm/sync/client/stats" + "github.com/ava-labs/subnet-evm/sync/handlers" "github.com/ava-labs/subnet-evm/triedb/hashdb" "github.com/ava-labs/subnet-evm/warp" @@ -88,6 +89,7 @@ import ( subnetevmlog "github.com/ava-labs/subnet-evm/plugin/evm/log" vmsync "github.com/ava-labs/subnet-evm/plugin/evm/sync" statesyncclient "github.com/ava-labs/subnet-evm/sync/client" + handlerstats "github.com/ava-labs/subnet-evm/sync/handlers/stats" avalancheRPC "github.com/gorilla/rpc/v2" ) @@ -151,6 +153,8 @@ var ( errPathStateUnsupported = errors.New("path state scheme is not supported") ) +var originalStderr *os.File + // legacyApiNames maps pre geth v1.10.20 api names to their updated counterparts. // used in attachEthService for backward configuration compatibility. var legacyApiNames = map[string]string{ @@ -184,6 +188,7 @@ type VM struct { config config.Config + chainID *big.Int genesisHash common.Hash chainConfig *params.ChainConfig ethConfig ethconfig.Config @@ -240,6 +245,7 @@ type VM struct { sdkMetrics *prometheus.Registry bootstrapped avalancheUtils.Atomic[bool] + IsPlugin bool stateSyncDone chan struct{} @@ -276,7 +282,21 @@ func (vm *VM) Initialize( appSender commonEng.AppSender, ) error { vm.ctx = chainCtx - vm.stateSyncDone = make(chan struct{}) + + // Initialize extension config if not already set + if vm.extensionConfig == nil { + // Initialize clock if not already set + if vm.clock == nil { + vm.clock = &mockable.Clock{} + } + vm.extensionConfig = &extension.Config{ + SyncSummaryProvider: &message.BlockSyncSummaryProvider{}, + SyncableParser: message.NewBlockSyncSummaryParser(), + } + // Provide a clock to the extension config before validation + vm.extensionConfig.Clock = vm.clock + } + cfg, deprecateMsg, err := config.GetConfig(configBytes, vm.ctx.NetworkID) if err != nil { return fmt.Errorf("failed to get config: %w", err) @@ -297,7 +317,12 @@ func (vm *VM) Initialize( } vm.chainAlias = alias - subnetEVMLogger, err := subnetevmlog.InitLogger(vm.chainAlias, vm.config.LogLevel, vm.config.LogJSONFormat, vm.ctx.Log) + var writer io.Writer = vm.ctx.Log + if vm.IsPlugin { + writer = originalStderr + } + + subnetEVMLogger, err := subnetevmlog.InitLogger(vm.chainAlias, vm.config.LogLevel, vm.config.LogJSONFormat, writer) if err != nil { return fmt.Errorf("%w: %w ", errInitializingLogger, err) } @@ -338,14 +363,22 @@ func (vm *VM) Initialize( return err } + // vm.ChainConfig() should be available for wrapping VMs before vm.initializeChain() + vm.chainConfig = g.Config + vm.chainID = g.Config.ChainID + vm.ethConfig = ethconfig.NewDefaultConfig() vm.ethConfig.Genesis = g - // NetworkID here is different than Avalanche's NetworkID. - // Avalanche's NetworkID represents the Avalanche network is running on - // like Fuji, Mainnet, Local, etc. - // The NetworkId here is kept same as ChainID to be compatible with - // Ethereum tooling. 
- vm.ethConfig.NetworkId = g.Config.ChainID.Uint64() + vm.ethConfig.NetworkId = vm.chainID.Uint64() + vm.genesisHash = vm.ethConfig.Genesis.ToBlock().Hash() // must create genesis hash before [vm.ReadLastAccepted] + lastAcceptedHash, lastAcceptedHeight, err := vm.ReadLastAccepted() + if err != nil { + return err + } + log.Info("read last accepted", + "hash", lastAcceptedHash, + "height", lastAcceptedHeight, + ) // Set minimum price for mining and default gas price oracle value to the min // gas price to prevent so transactions and blocks all use the correct fees @@ -353,7 +386,6 @@ func (vm *VM) Initialize( vm.ethConfig.RPCEVMTimeout = vm.config.APIMaxDuration.Duration vm.ethConfig.RPCTxFeeCap = vm.config.RPCTxFeeCap - vm.ethConfig.TxPool.Locals = vm.config.PriorityRegossipAddresses vm.ethConfig.TxPool.NoLocals = !vm.config.LocalTxsEnabled vm.ethConfig.TxPool.PriceLimit = vm.config.TxPoolPriceLimit vm.ethConfig.TxPool.PriceBump = vm.config.TxPoolPriceBump @@ -380,7 +412,7 @@ func (vm *VM) Initialize( vm.ethConfig.PopulateMissingTries = vm.config.PopulateMissingTries vm.ethConfig.PopulateMissingTriesParallelism = vm.config.PopulateMissingTriesParallelism vm.ethConfig.AllowMissingTries = vm.config.AllowMissingTries - vm.ethConfig.SnapshotDelayInit = vm.config.StateSyncEnabled + vm.ethConfig.SnapshotDelayInit = vm.stateSyncEnabled(lastAcceptedHeight) vm.ethConfig.SnapshotWait = vm.config.SnapshotWait vm.ethConfig.SnapshotVerify = vm.config.SnapshotVerify vm.ethConfig.HistoricalProofQueryWindow = vm.config.HistoricalProofQueryWindow @@ -438,18 +470,6 @@ func (vm *VM) Initialize( vm.chainConfig = g.Config - // create genesisHash after applying upgradeBytes in case - // upgradeBytes modifies genesis. - vm.genesisHash = vm.ethConfig.Genesis.ToBlock().Hash() // must create genesis hash before [vm.readLastAccepted] - lastAcceptedHash, lastAcceptedHeight, err := vm.readLastAccepted() - if err != nil { - return err - } - log.Info("read last accepted", - "hash", lastAcceptedHash, - "height", lastAcceptedHeight, - ) - vm.networkCodec = message.Codec vm.Network, err = network.NewNetwork(vm.ctx, appSender, vm.networkCodec, vm.config.MaxOutboundActiveRequests, vm.sdkMetrics) if err != nil { @@ -492,7 +512,6 @@ func (vm *VM) Initialize( if err != nil { return err } - if err := vm.initializeChain(lastAcceptedHash, vm.ethConfig); err != nil { return err } @@ -503,11 +522,9 @@ func (vm *VM) Initialize( warpHandler := acp118.NewCachedHandler(meteredCache, vm.warpBackend, vm.ctx.WarpSigner) vm.Network.AddHandler(p2p.SignatureRequestHandlerID, warpHandler) - vm.setAppRequestHandlers() - vm.stateSyncDone = make(chan struct{}) - return vm.initializeStateSyncClient(lastAcceptedHeight) + return vm.initializeStateSync(lastAcceptedHeight) } func parseGenesis(ctx *snow.Context, genesisBytes []byte, upgradeBytes []byte, airdropFile string) (*core.Genesis, error) { @@ -651,10 +668,48 @@ func (vm *VM) initializeChain(lastAcceptedHash common.Hash, ethConfig ethconfig. // initializeStateSyncClient initializes the client for performing state sync. // If state sync is disabled, this function will wipe any ongoing summary from // disk to ensure that we do not continue syncing from an invalid snapshot. -func (vm *VM) initializeStateSyncClient(lastAcceptedHeight uint64) error { +func (vm *VM) initializeStateSync(lastAcceptedHeight uint64) error { + // Create standalone EVM TrieDB (read only) for serving leafs requests. 
+ // We create a standalone TrieDB here, so that it has a standalone cache from the one + // used by the node when processing blocks. + evmTrieDB := triedb.NewDatabase( + vm.chaindb, + &triedb.Config{ + DBOverride: hashdb.Config{ + CleanCacheSize: vm.config.StateSyncServerTrieCache * units.MiB, + }.BackendConstructor, + }, + ) + + // register default leaf request handler for state trie + syncStats := handlerstats.GetOrRegisterHandlerStats(metrics.Enabled) + stateLeafRequestConfig := &extension.LeafRequestConfig{ + LeafType: message.StateTrieNode, + MetricName: "sync_state_trie_leaves", + Handler: handlers.NewLeafsRequestHandler(evmTrieDB, + message.StateTrieKeyLength, + vm.blockChain, vm.networkCodec, + syncStats, + ), + } + + leafHandlers := make(LeafHandlers) + leafHandlers[stateLeafRequestConfig.LeafType] = stateLeafRequestConfig.Handler + + networkHandler := newNetworkHandler( + vm.blockChain, + vm.chaindb, + vm.networkCodec, + leafHandlers, + syncStats, + ) + vm.Network.SetRequestHandler(networkHandler) + + vm.Server = vmsync.NewServer(vm.blockChain, vm.extensionConfig.SyncSummaryProvider, vm.config.StateSyncCommitInterval) + stateSyncEnabled := vm.stateSyncEnabled(lastAcceptedHeight) // parse nodeIDs from state sync IDs in vm config var stateSyncIDs []ids.NodeID - if vm.config.StateSyncEnabled && len(vm.config.StateSyncIDs) > 0 { + if stateSyncEnabled && len(vm.config.StateSyncIDs) > 0 { nodeIDs := strings.Split(vm.config.StateSyncIDs, ",") stateSyncIDs = make([]ids.NodeID, len(nodeIDs)) for i, nodeIDString := range nodeIDs { @@ -666,20 +721,24 @@ func (vm *VM) initializeStateSyncClient(lastAcceptedHeight uint64) error { } } + // Initialize the state sync client + leafMetricsNames := make(map[message.NodeType]string) + leafMetricsNames[stateLeafRequestConfig.LeafType] = stateLeafRequestConfig.MetricName + vm.Client = vmsync.NewClient(&vmsync.ClientConfig{ + StateSyncDone: vm.stateSyncDone, Chain: vm.eth, State: vm.State, - StateSyncDone: vm.stateSyncDone, Client: statesyncclient.NewClient( &statesyncclient.ClientConfig{ NetworkClient: vm.Network, Codec: vm.networkCodec, - Stats: stats.NewClientSyncerStats(), + Stats: stats.NewClientSyncerStats(leafMetricsNames), StateSyncNodeIDs: stateSyncIDs, BlockParser: vm, }, ), - Enabled: vm.config.StateSyncEnabled, + Enabled: stateSyncEnabled, SkipResume: vm.config.StateSyncSkipResume, MinBlocks: vm.config.StateSyncMinBlocks, RequestSize: vm.config.StateSyncRequestSize, @@ -689,11 +748,12 @@ func (vm *VM) initializeStateSyncClient(lastAcceptedHeight uint64) error { MetadataDB: vm.metadataDB, Acceptor: vm, Parser: vm.extensionConfig.SyncableParser, + Extender: nil, }) // If StateSync is disabled, clear any ongoing summary so that we will not attempt to resume // sync using a snapshot that has been modified by the node running normal operations. - if !vm.config.StateSyncEnabled { + if !stateSyncEnabled { return vm.Client.ClearOngoingSummary() } @@ -762,7 +822,6 @@ func (vm *VM) onBootstrapStarted() error { // Ensure snapshots are initialized before bootstrapping (i.e., if state sync is skipped). // Note calling this function has no effect if snapshots are already initialized. vm.blockChain.InitializeSnapshots() - return nil } @@ -895,27 +954,6 @@ func (vm *VM) onNormalOperationsStarted() error { return nil } -// setAppRequestHandlers sets the request handlers for the VM to serve state sync -// requests. -func (vm *VM) setAppRequestHandlers() { - // Create standalone EVM TrieDB (read only) for serving leafs requests. 
- // We create a standalone TrieDB here, so that it has a standalone cache from the one - // used by the node when processing blocks. - evmTrieDB := triedb.NewDatabase( - vm.chaindb, - &triedb.Config{ - DBOverride: hashdb.Config{ - CleanCacheSize: vm.config.StateSyncServerTrieCache * units.MiB, - }.BackendConstructor, - }, - ) - - networkHandler := newNetworkHandler(vm.blockChain, vm.chaindb, evmTrieDB, vm.networkCodec) - vm.Network.SetRequestHandler(networkHandler) - - vm.Server = vmsync.NewServer(vm.blockChain, vm.extensionConfig.SyncSummaryProvider, vm.config.StateSyncCommitInterval) -} - func (vm *VM) WaitForEvent(ctx context.Context) (commonEng.Message, error) { vm.builderLock.Lock() builder := vm.builder @@ -1271,7 +1309,7 @@ func (vm *VM) startContinuousProfiler() { // last accepted block hash and height by reading directly from [vm.chaindb] instead of relying // on [chain]. // Note: assumes [vm.chaindb] and [vm.genesisHash] have been initialized. -func (vm *VM) readLastAccepted() (common.Hash, uint64, error) { +func (vm *VM) ReadLastAccepted() (common.Hash, uint64, error) { // Attempt to load last accepted block to determine if it is necessary to // initialize state with the genesis block. lastAcceptedBytes, lastAcceptedErr := vm.acceptedBlockDB.Get(lastAcceptedKey) @@ -1339,25 +1377,14 @@ func attachEthService(handler *rpc.Server, apis []rpc.API, names []string) error return nil } -func (vm *VM) Connected(ctx context.Context, nodeID ids.NodeID, version *version.Application) error { - vm.vmLock.Lock() - defer vm.vmLock.Unlock() - - if err := vm.validatorsManager.Connect(nodeID); err != nil { - return fmt.Errorf("uptime manager failed to connect node %s: %w", nodeID, err) - } - return vm.Network.Connected(ctx, nodeID, version) -} - -func (vm *VM) Disconnected(ctx context.Context, nodeID ids.NodeID) error { - vm.vmLock.Lock() - defer vm.vmLock.Unlock() - - if err := vm.validatorsManager.Disconnect(nodeID); err != nil { - return fmt.Errorf("uptime manager failed to disconnect node %s: %w", nodeID, err) +func (vm *VM) stateSyncEnabled(lastAcceptedHeight uint64) bool { + if vm.config.StateSyncEnabled != nil { + // if the config is set, use that + return *vm.config.StateSyncEnabled } - return vm.Network.Disconnected(ctx, nodeID) + // enable state sync by default if the chain is empty. + return lastAcceptedHeight == 0 } func (vm *VM) PutLastAcceptedID(id ids.ID) error { diff --git a/plugin/evm/vm_extensible.go b/plugin/evm/vm_extensible.go new file mode 100644 index 0000000000..2fae0e156e --- /dev/null +++ b/plugin/evm/vm_extensible.go @@ -0,0 +1,92 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+
+package evm
+
+import (
+	"context"
+	"errors"
+
+	"github.com/ava-labs/avalanchego/database/versiondb"
+	"github.com/ava-labs/avalanchego/ids"
+	"github.com/ava-labs/avalanchego/network/p2p"
+	"github.com/prometheus/client_golang/prometheus"
+
+	"github.com/ava-labs/subnet-evm/core"
+	"github.com/ava-labs/subnet-evm/params"
+	"github.com/ava-labs/subnet-evm/plugin/evm/config"
+	"github.com/ava-labs/subnet-evm/plugin/evm/extension"
+
+	vmsync "github.com/ava-labs/subnet-evm/plugin/evm/sync"
+)
+
+var _ extension.InnerVM = (*VM)(nil)
+
+var (
+	errVMAlreadyInitialized      = errors.New("vm already initialized")
+	errExtensionConfigAlreadySet = errors.New("extension config already set")
+)
+
+func (vm *VM) SetExtensionConfig(config *extension.Config) error {
+	if vm.ctx != nil {
+		return errVMAlreadyInitialized
+	}
+	if vm.extensionConfig != nil {
+		return errExtensionConfigAlreadySet
+	}
+	vm.extensionConfig = config
+	return nil
+}
+
+// All the methods below assume that the VM is already initialized
+
+func (vm *VM) GetExtendedBlock(ctx context.Context, blkID ids.ID) (extension.ExtendedBlock, error) {
+	// Since each internal handler used by [vm.State] always returns a block
+	// with non-nil ethBlock value, GetBlockInternal should never return a
+	// (*wrappedBlock) with a nil ethBlock value.
+	blk, err := vm.GetBlockInternal(ctx, blkID)
+	if err != nil {
+		return nil, err
+	}
+
+	return blk.(*wrappedBlock), nil
+}
+
+func (vm *VM) LastAcceptedExtendedBlock() extension.ExtendedBlock {
+	lastAcceptedBlock := vm.LastAcceptedBlockInternal()
+	if lastAcceptedBlock == nil {
+		return nil
+	}
+	return lastAcceptedBlock.(*wrappedBlock)
+}
+
+// ChainConfig returns the chain config for the VM.
+// Even though this is also available through Blockchain().Config(),
+// ChainConfig() here will be available before the blockchain is initialized.
+func (vm *VM) ChainConfig() *params.ChainConfig { + return vm.chainConfig +} + +func (vm *VM) Blockchain() *core.BlockChain { + return vm.blockChain +} + +func (vm *VM) Config() config.Config { + return vm.config +} + +func (vm *VM) MetricRegistry() *prometheus.Registry { + return vm.sdkMetrics +} + +func (vm *VM) Validators() *p2p.Validators { + return vm.P2PValidators() +} + +func (vm *VM) VersionDB() *versiondb.Database { + return vm.versiondb +} + +func (vm *VM) SyncerClient() vmsync.Client { + return vm.Client +} diff --git a/plugin/evm/vm_test.go b/plugin/evm/vm_test.go index 2331bab992..d8bf871467 100644 --- a/plugin/evm/vm_test.go +++ b/plugin/evm/vm_test.go @@ -59,6 +59,7 @@ import ( "github.com/ava-labs/subnet-evm/plugin/evm/customheader" "github.com/ava-labs/subnet-evm/plugin/evm/customrawdb" "github.com/ava-labs/subnet-evm/plugin/evm/customtypes" + "github.com/ava-labs/subnet-evm/plugin/evm/extension" "github.com/ava-labs/subnet-evm/plugin/evm/vmerrors" "github.com/ava-labs/subnet-evm/precompile/allowlist" "github.com/ava-labs/subnet-evm/precompile/contracts/deployerallowlist" @@ -1682,7 +1683,7 @@ func testEmptyBlock(t *testing.T, scheme string) { } // Create empty block from blkA - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() emptyEthBlock := types.NewBlock( types.CopyHeader(ethBlock.Header()), @@ -1905,7 +1906,7 @@ func testAcceptReorg(t *testing.T, scheme string) { t.Fatalf("Block failed verification on VM1: %s", err) } - blkBHash := vm1BlkB.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Hash() + blkBHash := vm1BlkB.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Hash() if b := vm1.blockChain.CurrentBlock(); b.Hash() != blkBHash { t.Fatalf("expected current block to have hash %s but got %s", blkBHash.Hex(), b.Hash().Hex()) } @@ -1914,7 +1915,7 @@ func testAcceptReorg(t *testing.T, scheme string) { t.Fatal(err) } - blkCHash := vm1BlkC.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Hash() + blkCHash := vm1BlkC.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Hash() if b := vm1.blockChain.CurrentBlock(); b.Hash() != blkCHash { t.Fatalf("expected current block to have hash %s but got %s", blkCHash.Hex(), b.Hash().Hex()) } @@ -1925,7 +1926,7 @@ func testAcceptReorg(t *testing.T, scheme string) { if err := vm1BlkD.Accept(context.Background()); err != nil { t.Fatal(err) } - blkDHash := vm1BlkD.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Hash() + blkDHash := vm1BlkD.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Hash() if b := vm1.blockChain.CurrentBlock(); b.Hash() != blkDHash { t.Fatalf("expected current block to have hash %s but got %s", blkDHash.Hex(), b.Hash().Hex()) } @@ -2161,7 +2162,7 @@ func testLastAcceptedBlockNumberAllow(t *testing.T, scheme string) { } blkHeight := blk.Height() - blkHash := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Hash() + blkHash := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Hash() tvm.vm.eth.APIBackend.SetAllowUnfinalizedQueries(true) @@ -2255,7 +2256,7 @@ func testBuildAllowListActivationBlock(t *testing.T, scheme string) { } // Verify that the allow list config activation was handled correctly in the first block. 
- blkState, err := tvm.vm.blockChain.StateAt(blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock.Root()) + blkState, err := tvm.vm.blockChain.StateAt(blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock().Root()) if err != nil { t.Fatal(err) } @@ -2371,7 +2372,7 @@ func TestTxAllowListSuccessfulTx(t *testing.T) { require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) // Verify that the constructed block only has the whitelisted tx - block := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() txs := block.Transactions() @@ -2395,7 +2396,7 @@ func TestTxAllowListSuccessfulTx(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - block = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() blkState, err := tvm.vm.blockChain.StateAt(block.Root()) require.NoError(t, err) @@ -2421,7 +2422,7 @@ func TestTxAllowListSuccessfulTx(t *testing.T) { require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) // Verify that the constructed block only has the whitelisted tx - block = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() txs = block.Transactions() require.Len(t, txs, 1) @@ -2576,7 +2577,7 @@ func TestTxAllowListDisablePrecompile(t *testing.T) { blk := issueAndAccept(t, tvm.vm) // Verify that the constructed block only has the whitelisted tx - block := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() txs := block.Transactions() if txs.Len() != 1 { t.Fatalf("Expected number of txs to be %d, but found %d", 1, txs.Len()) @@ -2598,7 +2599,7 @@ func TestTxAllowListDisablePrecompile(t *testing.T) { blk = issueAndAccept(t, tvm.vm) // Verify that the constructed block only has the previously rejected tx - block = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() txs = block.Transactions() if txs.Len() != 1 { t.Fatalf("Expected number of txs to be %d, but found %d", 1, txs.Len()) @@ -2704,7 +2705,7 @@ func TestFeeManagerChangeFee(t *testing.T) { t.Fatalf("Expected new block to match") } - block := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() feeConfig, lastChangedAt, err = tvm.vm.blockChain.GetFeeConfigAt(block.Header()) require.NoError(t, err) @@ -2786,16 +2787,16 @@ func testAllowFeeRecipientDisabled(t *testing.T, scheme string) { blk, err := tvm.vm.BuildBlock(context.Background()) require.NoError(t, err) // this won't return an error since miner will set the etherbase to blackhole address - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, constants.BlackholeAddr, ethBlock.Coinbase()) // Create empty block from blk - internalBlk := blk.(*chain.BlockWrapper).Block.(*wrappedBlock) - modifiedHeader := types.CopyHeader(internalBlk.ethBlock.Header()) + internalBlk := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock) + modifiedHeader := types.CopyHeader(internalBlk.GetEthBlock().Header()) modifiedHeader.Coinbase = 
common.HexToAddress("0x0123456789") // set non-blackhole address by force modifiedBlock := types.NewBlock( modifiedHeader, - internalBlk.ethBlock.Transactions(), + internalBlk.GetEthBlock().Transactions(), nil, nil, trie.NewStackTrie(nil), @@ -2860,7 +2861,7 @@ func TestAllowFeeRecipientEnabled(t *testing.T) { if newHead.Head.Hash() != common.Hash(blk.ID()) { t.Fatalf("Expected new block to match") } - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, etherBase, ethBlock.Coinbase()) // Verify that etherBase has received fees blkState, err := tvm.vm.blockChain.StateAt(ethBlock.Root()) @@ -2939,7 +2940,7 @@ func TestRewardManagerPrecompileSetRewardAddress(t *testing.T) { blk := issueAndAccept(t, tvm.vm) newHead := <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, etherBase, ethBlock.Coinbase()) // reward address is activated at this block so this is fine tx1 := types.NewTransaction(uint64(0), testEthAddrs[0], big.NewInt(2), 21000, big.NewInt(testMinGasPrice*3), nil) @@ -2954,7 +2955,7 @@ func TestRewardManagerPrecompileSetRewardAddress(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, testAddr, ethBlock.Coinbase()) // reward address was activated at previous block // Verify that etherBase has received fees blkState, err := tvm.vm.blockChain.StateAt(ethBlock.Root()) @@ -2981,7 +2982,7 @@ func TestRewardManagerPrecompileSetRewardAddress(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() // Reward manager deactivated at this block, so we expect the parent state // to determine the coinbase for this block before full deactivation in the // next block. 
@@ -3002,7 +3003,7 @@ func TestRewardManagerPrecompileSetRewardAddress(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() // reward manager was disabled at previous block // so this block should revert back to enabling fee recipients require.Equal(t, etherBase, ethBlock.Coinbase()) @@ -3080,7 +3081,7 @@ func TestRewardManagerPrecompileAllowFeeRecipients(t *testing.T) { blk := issueAndAccept(t, tvm.vm) newHead := <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, constants.BlackholeAddr, ethBlock.Coinbase()) // reward address is activated at this block so this is fine tx1 := types.NewTransaction(uint64(0), testEthAddrs[0], big.NewInt(2), 21000, big.NewInt(testMinGasPrice*3), nil) @@ -3095,7 +3096,7 @@ func TestRewardManagerPrecompileAllowFeeRecipients(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, etherBase, ethBlock.Coinbase()) // reward address was activated at previous block // Verify that etherBase has received fees blkState, err := tvm.vm.blockChain.StateAt(ethBlock.Root()) @@ -3121,7 +3122,7 @@ func TestRewardManagerPrecompileAllowFeeRecipients(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, etherBase, ethBlock.Coinbase()) // reward address was activated at previous block require.GreaterOrEqual(t, int64(ethBlock.Time()), disableTime.Unix()) @@ -3138,7 +3139,7 @@ func TestRewardManagerPrecompileAllowFeeRecipients(t *testing.T) { blk = issueAndAccept(t, tvm.vm) newHead = <-newTxPoolHeadChan require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) - ethBlock = blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock = blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Equal(t, constants.BlackholeAddr, ethBlock.Coinbase()) // reward address was activated at previous block require.Greater(t, int64(ethBlock.Time()), disableTime.Unix()) @@ -3298,7 +3299,7 @@ func TestParentBeaconRootBlock(t *testing.T) { } // Modify the block to have a parent beacon root - ethBlock := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() header := types.CopyHeader(ethBlock.Header()) header.ParentBeaconRoot = test.beaconRoot parentBeaconEthBlock := ethBlock.WithSeal(header) @@ -3497,7 +3498,7 @@ func TestFeeManagerRegressionMempoolMinFeeAfterRestart(t *testing.T) { require.Equal(t, newHead.Head.Hash(), common.Hash(blk.ID())) // check that the fee config is updated - block := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + block := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() feeConfig, 
lastChangedAt, err = restartedVM.blockChain.GetFeeConfigAt(block.Header()) require.NoError(t, err) require.EqualValues(t, restartedVM.blockChain.CurrentBlock().Number, lastChangedAt) diff --git a/plugin/evm/vm_warp_test.go b/plugin/evm/vm_warp_test.go index c379b737e9..1e6066cd36 100644 --- a/plugin/evm/vm_warp_test.go +++ b/plugin/evm/vm_warp_test.go @@ -43,6 +43,7 @@ import ( "github.com/ava-labs/subnet-evm/params/extras" "github.com/ava-labs/subnet-evm/params/paramstest" "github.com/ava-labs/subnet-evm/plugin/evm/customheader" + "github.com/ava-labs/subnet-evm/plugin/evm/extension" "github.com/ava-labs/subnet-evm/precompile/contract" "github.com/ava-labs/subnet-evm/utils" "github.com/ava-labs/subnet-evm/warp" @@ -139,7 +140,7 @@ func testSendWarpMessage(t *testing.T, scheme string) { require.NoError(blk.Verify(context.Background())) // Verify that the constructed block contains the expected log with an unsigned warp message in the log data - ethBlock1 := blk.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock1 := blk.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() require.Len(ethBlock1.Transactions(), 1) receipts := rawdb.ReadReceipts(tvm.vm.chaindb, ethBlock1.Hash(), ethBlock1.NumberU64(), ethBlock1.Time(), tvm.vm.chainConfig) require.Len(receipts, 1) @@ -466,7 +467,7 @@ func testWarpVMTransaction(t *testing.T, scheme string, unsignedMessage *avalanc require.NoError(warpBlock.Accept(context.Background())) tvm.vm.blockChain.DrainAcceptorQueue() - ethBlock := warpBlock.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := warpBlock.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() verifiedMessageReceipts := tvm.vm.blockChain.GetReceiptsByHash(ethBlock.Hash()) require.Len(verifiedMessageReceipts, 2) for i, receipt := range verifiedMessageReceipts { @@ -756,7 +757,7 @@ func testReceiveWarpMessage( require.NoError(err) // Require the block was built with a successful predicate result - ethBlock := block2.(*chain.BlockWrapper).Block.(*wrappedBlock).ethBlock + ethBlock := block2.(*chain.BlockWrapper).Block.(extension.ExtendedBlock).GetEthBlock() headerPredicateResultsBytes := customheader.PredicateBytesFromExtra(ethBlock.Extra()) blockResults, err := predicate.ParseBlockResults(headerPredicateResultsBytes) require.NoError(err) diff --git a/plugin/evm/wrapped_block.go b/plugin/evm/wrapped_block.go index ebe2cff6ef..04628561e3 100644 --- a/plugin/evm/wrapped_block.go +++ b/plugin/evm/wrapped_block.go @@ -26,12 +26,14 @@ import ( "github.com/ava-labs/subnet-evm/params/extras" "github.com/ava-labs/subnet-evm/plugin/evm/customheader" "github.com/ava-labs/subnet-evm/plugin/evm/customtypes" + "github.com/ava-labs/subnet-evm/plugin/evm/extension" "github.com/ava-labs/subnet-evm/precompile/precompileconfig" ) var ( _ snowman.Block = (*wrappedBlock)(nil) _ block.WithVerifyContext = (*wrappedBlock)(nil) + _ extension.ExtendedBlock = (*wrappedBlock)(nil) errInvalidParent = errors.New("parent header not found") errMissingParentBlock = errors.New("missing parent block") @@ -54,18 +56,26 @@ var ( // wrappedBlock implements the snowman.wrappedBlock interface type wrappedBlock struct { - id ids.ID - ethBlock *types.Block - vm *VM + id ids.ID + ethBlock *types.Block + extension extension.BlockExtension + vm *VM } // wrapBlock returns a new Block wrapping the ethBlock type and implementing the snowman.Block interface -func wrapBlock(ethBlock *types.Block, vm *VM) (*wrappedBlock, error) { //nolint:unparam // this just makes the function compatible 
with the future syncs I'll do, it's temporary!! +func wrapBlock(ethBlock *types.Block, vm *VM) (*wrappedBlock, error) { b := &wrappedBlock{ id: ids.ID(ethBlock.Hash()), ethBlock: ethBlock, vm: vm, } + if vm.extensionConfig.BlockExtender != nil { + extension, err := vm.extensionConfig.BlockExtender.NewBlockExtension(b) + if err != nil { + return nil, fmt.Errorf("failed to create block extension: %w", err) + } + b.extension = extension + } return b, nil } @@ -308,6 +318,11 @@ func (b *wrappedBlock) semanticVerify() error { return err } + if b.extension != nil { + if err := b.extension.SemanticVerify(); err != nil { + return err + } + } return nil } @@ -421,6 +436,11 @@ func (b *wrappedBlock) syntacticVerify() error { } } + if b.extension != nil { + if err := b.extension.SyntacticVerify(*rulesExtra); err != nil { + return err + } + } return nil } @@ -468,3 +488,5 @@ func (b *wrappedBlock) Bytes() []byte { func (b *wrappedBlock) String() string { return fmt.Sprintf("EVM block, ID = %s", b.ID()) } func (b *wrappedBlock) GetEthBlock() *types.Block { return b.ethBlock } + +func (b *wrappedBlock) GetBlockExtension() extension.BlockExtension { return b.extension } diff --git a/sync/README.md b/sync/README.md index ad991ada53..ce1530b28f 100644 --- a/sync/README.md +++ b/sync/README.md @@ -42,7 +42,7 @@ When a new node wants to join the network via state sync, it will need a few pie - Number (height) and hash of the latest available syncable block, - Root of the account trie, -The above information is called a _state summary_, and each syncable block corresponds to one such summary (see `message.SyncSummary`). The engine and VM interact as follows to find a syncable state summary: +The above information is called a _state summary_, and each syncable block corresponds to one such summary (see `message.Summary`). The engine and VM interact as follows to find a syncable state summary: 1. The engine calls `StateSyncEnabled`. The VM returns `true` to initiate state sync, or `false` to start bootstrapping. In `subnet-evm`, this is controlled by the `state-sync-enabled` flag. @@ -60,6 +60,8 @@ The following steps are executed by the VM to sync its state from peers (see `st 1. Update in-memory and on-disk pointers. Steps 3 and 4 involve syncing tries. To sync trie data, the VM will send a series of `LeafRequests` to its peers. Each request specifies: +- Type of trie (`NodeType`): + - `statesync.StateTrieNode` (account trie and storage tries share the same database) - `Root` of the trie to sync, - `Start` and `End` specify a range of keys. 
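To make the request routing described in the README section above concrete, here is a minimal, self-contained Go sketch of how a typed `LeafsRequest` is dispatched to a per-trie handler, mirroring the `LeafHandlers` map and `HandleLeafsRequest` logic introduced in `plugin/evm/network_handler.go`. The standalone `NodeType`, `LeafsRequest`, `leafHandler`, and `stubHandler` names are local stand-ins for this sketch, not part of the diff:

```go
package main

import "fmt"

// NodeType mirrors message.NodeType: it tags which trie a request targets.
type NodeType uint8

// StateTrieNode mirrors message.StateTrieNode.
const StateTrieNode = NodeType(1)

// LeafsRequest mirrors the wire type, which now carries a NodeType.
type LeafsRequest struct {
	Limit    uint16
	NodeType NodeType
}

// leafHandler plays the role of handlers.LeafRequestHandler.
type leafHandler interface {
	OnLeafsRequest(req LeafsRequest) ([]byte, error)
}

// stubHandler is a hypothetical handler used only for illustration.
type stubHandler struct{ name string }

func (s stubHandler) OnLeafsRequest(LeafsRequest) ([]byte, error) {
	return []byte(s.name), nil
}

func main() {
	// One handler per trie type, as in the LeafHandlers map.
	leafHandlers := map[NodeType]leafHandler{
		StateTrieNode: stubHandler{name: "state trie"},
	}

	req := LeafsRequest{Limit: 1024, NodeType: StateTrieNode}
	handler, ok := leafHandlers[req.NodeType]
	if !ok {
		// Unrecognized node types are dropped (nil response, nil error),
		// as in networkHandler.HandleLeafsRequest.
		fmt.Println("unrecognized node type, dropping request")
		return
	}
	resp, _ := handler.OnLeafsRequest(req)
	fmt.Printf("served leaves from the %s handler\n", resp)
}
```

This lookup-then-drop pattern is what lets a wrapping VM register extra tries (via `ExtraSyncLeafHandlerConfig`) without the default handler needing to know about them.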
diff --git a/sync/client/client_test.go b/sync/client/client_test.go index b357c87272..dc2eee41a9 100644 --- a/sync/client/client_test.go +++ b/sync/client/client_test.go @@ -98,7 +98,7 @@ func TestGetCode(t *testing.T) { Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) for name, test := range tests { @@ -166,7 +166,7 @@ func TestGetBlocks(t *testing.T) { Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) blocksRequestHandler := handlers.NewBlockRequestHandler(buildGetter(blocks), message.Codec, handlerstats.NewNoopHandlerStats()) @@ -420,13 +420,13 @@ func TestGetLeafs(t *testing.T) { largeTrieRoot, largeTrieKeys, _ := statesynctest.GenerateTrie(t, r, trieDB, 100_000, common.HashLength) smallTrieRoot, _, _ := statesynctest.GenerateTrie(t, r, trieDB, leafsLimit, common.HashLength) - handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()) + handler := handlers.NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats()) client := NewClient(&ClientConfig{ NetworkClient: &mockNetwork{}, Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) tests := map[string]struct { @@ -789,7 +789,7 @@ func TestGetLeafsRetries(t *testing.T) { trieDB := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil) root, _, _ := statesynctest.GenerateTrie(t, r, trieDB, 100_000, common.HashLength) - handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()) + handler := handlers.NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats()) mockNetClient := &mockNetwork{} const maxAttempts = 8 @@ -798,7 +798,7 @@ func TestGetLeafsRetries(t *testing.T) { Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: nil, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) request := message.LeafsRequest{ @@ -859,7 +859,7 @@ func TestStateSyncNodes(t *testing.T) { Codec: message.Codec, Stats: clientstats.NewNoOpStats(), StateSyncNodeIDs: stateSyncNodes, - BlockParser: mockBlockParser, + BlockParser: newTestBlockParser(), }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/sync/client/leaf_syncer.go b/sync/client/leaf_syncer.go index 2c5e3ad491..c430b3d85c 100644 --- a/sync/client/leaf_syncer.go +++ b/sync/client/leaf_syncer.go @@ -96,10 +96,11 @@ func (c *CallbackLeafSyncer) syncTask(ctx context.Context, task LeafSyncTask) er } leafsResponse, err := c.client.GetLeafs(ctx, message.LeafsRequest{ - Root: root, - Account: task.Account(), - Start: start, - Limit: c.requestSize, + Root: root, + Account: task.Account(), + Start: start, + Limit: c.requestSize, + NodeType: message.StateTrieNode, }) if err != nil { return fmt.Errorf("%w: %w", errFailedToFetchLeafs, err) diff --git a/sync/client/stats/stats.go b/sync/client/stats/stats.go index 834dbcb193..6d146f7f2c 100644 --- a/sync/client/stats/stats.go +++ b/sync/client/stats/stats.go @@ -76,17 +76,21 @@ func (m *messageMetric) UpdateRequestLatency(duration time.Duration) { } type clientSyncerStats struct { - stateTrieLeavesMetric, - codeRequestMetric, + leafMetrics map[message.NodeType]MessageMetric + 
diff --git a/sync/client/stats/stats.go b/sync/client/stats/stats.go
index 834dbcb193..6d146f7f2c 100644
--- a/sync/client/stats/stats.go
+++ b/sync/client/stats/stats.go
@@ -76,17 +76,21 @@ func (m *messageMetric) UpdateRequestLatency(duration time.Duration) {
 }
 
 type clientSyncerStats struct {
-	stateTrieLeavesMetric,
-	codeRequestMetric,
+	leafMetrics        map[message.NodeType]MessageMetric
+	codeRequestMetric  MessageMetric
 	blockRequestMetric MessageMetric
 }
 
 // NewClientSyncerStats returns stats for the client syncer
-func NewClientSyncerStats() ClientSyncerStats {
+func NewClientSyncerStats(leafMetricNames map[message.NodeType]string) *clientSyncerStats {
+	leafMetrics := make(map[message.NodeType]MessageMetric, len(leafMetricNames))
+	for nodeType, name := range leafMetricNames {
+		leafMetrics[nodeType] = NewMessageMetric(name)
+	}
 	return &clientSyncerStats{
-		stateTrieLeavesMetric: NewMessageMetric("sync_state_trie_leaves"),
-		codeRequestMetric:     NewMessageMetric("sync_code"),
-		blockRequestMetric:    NewMessageMetric("sync_blocks"),
+		leafMetrics:        leafMetrics,
+		codeRequestMetric:  NewMessageMetric("sync_code"),
+		blockRequestMetric: NewMessageMetric("sync_blocks"),
 	}
 }
 
@@ -98,7 +102,11 @@ func (c *clientSyncerStats) GetMetric(msgIntf message.Request) (MessageMetric, e
 	case message.CodeRequest:
 		return c.codeRequestMetric, nil
 	case message.LeafsRequest:
-		return c.stateTrieLeavesMetric, nil
+		metric, ok := c.leafMetrics[msg.NodeType]
+		if !ok {
+			return nil, fmt.Errorf("invalid leafs request for node type: %d", msg.NodeType)
+		}
+		return metric, nil
 	default:
 		return nil, fmt.Errorf("attempted to get metric for invalid request with type %T", msg)
 	}
@@ -125,12 +133,3 @@ func NewNoOpStats() ClientSyncerStats {
 func (n noopStats) GetMetric(_ message.Request) (MessageMetric, error) {
 	return n.noop, nil
 }
-
-// NewStats returns syncer stats if enabled or a no-op version if disabled.
-func NewStats(enabled bool) ClientSyncerStats {
-	if enabled {
-		return NewClientSyncerStats()
-	} else {
-		return NewNoOpStats()
-	}
-}
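With the hard-coded `sync_state_trie_leaves` metric removed, callers now name one leaf metric per node type; `%d` on `msg.NodeType` reports the offending value (the previous `%T` would only ever print the type name). A sketch of the wiring, assuming the import paths as they appear in this diff and keeping the old metric name for continuity:

```go
package main

import (
	"github.com/ava-labs/subnet-evm/plugin/evm/message"
	clientstats "github.com/ava-labs/subnet-evm/sync/client/stats"
)

func main() {
	// One MessageMetric is registered per trie namespace; leaf requests with
	// an unknown node type now surface as an error from GetMetric instead of
	// being counted against the wrong metric.
	leafMetricNames := map[message.NodeType]string{
		message.StateTrieNode: "sync_state_trie_leaves",
	}
	_ = clientstats.NewClientSyncerStats(leafMetricNames)
}
```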
diff --git a/sync/client/mock_client.go b/sync/client/test_client.go
similarity index 77%
rename from sync/client/mock_client.go
rename to sync/client/test_client.go
index a43630a175..cc81e5bad4 100644
--- a/sync/client/mock_client.go
+++ b/sync/client/test_client.go
@@ -19,37 +19,36 @@ import (
 )
 
 var (
-	_ Client = &MockClient{}
-	mockBlockParser EthBlockParser = &testBlockParser{}
+	_ Client         = (*TestClient)(nil)
+	_ EthBlockParser = (*testBlockParser)(nil)
 )
 
-// TODO replace with gomock library
-type MockClient struct {
+type TestClient struct {
 	codec          codec.Manager
-	leafsHandler   *handlers.LeafsRequestHandler
+	leafsHandler   handlers.LeafRequestHandler
 	leavesReceived int32
 	codesHandler   *handlers.CodeRequestHandler
 	codeReceived   int32
 	blocksHandler  *handlers.BlockRequestHandler
 	blocksReceived int32
 	// GetLeafsIntercept is called on every GetLeafs request if set to a non-nil callback.
-	// The returned response will be returned by MockClient to the caller.
+	// The returned response will be returned by TestClient to the caller.
 	GetLeafsIntercept func(req message.LeafsRequest, res message.LeafsResponse) (message.LeafsResponse, error)
 	// GetCodesIntercept is called on every GetCode request if set to a non-nil callback.
-	// The returned response will be returned by MockClient to the caller.
+	// The returned response will be returned by TestClient to the caller.
 	GetCodeIntercept func(hashes []common.Hash, codeBytes [][]byte) ([][]byte, error)
 	// GetBlocksIntercept is called on every GetBlocks request if set to a non-nil callback.
-	// The returned response will be returned by MockClient to the caller.
+	// The returned response will be returned by TestClient to the caller.
 	GetBlocksIntercept func(blockReq message.BlockRequest, blocks types.Blocks) (types.Blocks, error)
 }
 
-func NewMockClient(
+func NewTestClient(
 	codec codec.Manager,
-	leafHandler *handlers.LeafsRequestHandler,
+	leafHandler handlers.LeafRequestHandler,
 	codesHandler *handlers.CodeRequestHandler,
 	blocksHandler *handlers.BlockRequestHandler,
-) *MockClient {
-	return &MockClient{
+) *TestClient {
+	return &TestClient{
 		codec:         codec,
 		leafsHandler:  leafHandler,
 		codesHandler:  codesHandler,
@@ -57,7 +56,7 @@ func NewMockClient(
 	}
 }
 
-func (ml *MockClient) GetLeafs(ctx context.Context, request message.LeafsRequest) (message.LeafsResponse, error) {
+func (ml *TestClient) GetLeafs(ctx context.Context, request message.LeafsRequest) (message.LeafsResponse, error) {
 	response, err := ml.leafsHandler.OnLeafsRequest(ctx, ids.GenerateTestNodeID(), 1, request)
 	if err != nil {
 		return message.LeafsResponse{}, err
@@ -71,18 +70,18 @@ func (ml *MockClient) GetLeafs(ctx context.Context, request message.LeafsRequest
 	if ml.GetLeafsIntercept != nil {
 		leafsResponse, err = ml.GetLeafsIntercept(request, leafsResponse)
 	}
-	// Increment the number of leaves received by the mock client
+	// Increment the number of leaves received by the test client
 	atomic.AddInt32(&ml.leavesReceived, int32(numLeaves))
 	return leafsResponse, err
 }
 
-func (ml *MockClient) LeavesReceived() int32 {
+func (ml *TestClient) LeavesReceived() int32 {
 	return atomic.LoadInt32(&ml.leavesReceived)
 }
 
-func (ml *MockClient) GetCode(ctx context.Context, hashes []common.Hash) ([][]byte, error) {
+func (ml *TestClient) GetCode(ctx context.Context, hashes []common.Hash) ([][]byte, error) {
 	if ml.codesHandler == nil {
-		panic("no code handler for mock client")
+		panic("no code handler for test client")
 	}
 	request := message.CodeRequest{Hashes: hashes}
 	response, err := ml.codesHandler.OnCodeRequest(ctx, ids.GenerateTestNodeID(), 1, request)
@@ -104,13 +103,13 @@ func (ml *MockClient) GetCode(ctx context.Context, hashes []common.Hash) ([][]by
 	return code, err
 }
 
-func (ml *MockClient) CodeReceived() int32 {
+func (ml *TestClient) CodeReceived() int32 {
 	return atomic.LoadInt32(&ml.codeReceived)
 }
 
-func (ml *MockClient) GetBlocks(ctx context.Context, blockHash common.Hash, height uint64, numParents uint16) ([]*types.Block, error) {
+func (ml *TestClient) GetBlocks(ctx context.Context, blockHash common.Hash, height uint64, numParents uint16) ([]*types.Block, error) {
 	if ml.blocksHandler == nil {
-		panic("no blocks handler for mock client")
+		panic("no blocks handler for test client")
 	}
 	request := message.BlockRequest{
 		Hash:    blockHash,
@@ -122,7 +121,7 @@ func (ml *MockClient) GetBlocks(ctx context.Context, blockHash common.Hash, heig
 		return nil, err
 	}
 
-	client := &client{blockParser: mockBlockParser} // Hack to avoid duplicate code
+	client := &client{blockParser: newTestBlockParser()} // Hack to avoid duplicate code
 	blocksRes, numBlocks, err := client.parseBlocks(ml.codec, request, response)
 	if err != nil {
 		return nil, err
@@ -135,12 +134,16 @@ func (ml *MockClient) GetBlocks(ctx context.Context, blockHash common.Hash, heig
 	return blocks, err
 }
 
-func (ml *MockClient) BlocksReceived() int32 {
+func (ml *TestClient) BlocksReceived() int32 {
 	return atomic.LoadInt32(&ml.blocksReceived)
 }
 
 type testBlockParser struct{}
 
+func newTestBlockParser() *testBlockParser {
+	return &testBlockParser{}
+}
+
 func (*testBlockParser) ParseEthBlock(b []byte) (*types.Block, error) {
 	block := new(types.Block)
 	if err := rlp.DecodeBytes(b, block); err != nil {
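The rename from `MockClient` to `TestClient` is mechanical, but the intercept hooks are the interesting part: they let a test observe or mutate responses between handler and syncer. A hypothetical test sketch (not part of this PR) using only helpers that appear elsewhere in this diff; the import aliases and test name are assumptions:

```go
package statesync

import (
	"context"
	"math/rand"
	"testing"

	"github.com/ava-labs/libevm/common"
	"github.com/ava-labs/libevm/core/rawdb"
	"github.com/ava-labs/libevm/triedb"
	"github.com/stretchr/testify/require"

	"github.com/ava-labs/subnet-evm/plugin/evm/message"
	statesyncclient "github.com/ava-labs/subnet-evm/sync/client"
	"github.com/ava-labs/subnet-evm/sync/handlers"
	handlerstats "github.com/ava-labs/subnet-evm/sync/handlers/stats"
	"github.com/ava-labs/subnet-evm/sync/statesync/statesynctest"
)

func TestClientInterceptSketch(t *testing.T) {
	r := rand.New(rand.NewSource(1))
	trieDB := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)
	root, _, _ := statesynctest.GenerateTrie(t, r, trieDB, 100, common.HashLength)

	leafsHandler := handlers.NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats())
	client := statesyncclient.NewTestClient(message.Codec, leafsHandler, nil, nil)

	// Count every leaf request that reaches the handler.
	requests := 0
	client.GetLeafsIntercept = func(_ message.LeafsRequest, res message.LeafsResponse) (message.LeafsResponse, error) {
		requests++
		return res, nil // pass the response through unmodified
	}

	_, err := client.GetLeafs(context.Background(), message.LeafsRequest{
		Root:     root,
		Limit:    64,
		NodeType: message.StateTrieNode,
	})
	require.NoError(t, err)
	require.Equal(t, 1, requests)
}
```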
diff --git a/sync/handlers/leafs_request.go b/sync/handlers/leafs_request.go
index 73595e227d..ec340c44fc 100644
--- a/sync/handlers/leafs_request.go
+++ b/sync/handlers/leafs_request.go
@@ -26,6 +26,8 @@ import (
 	"github.com/ava-labs/subnet-evm/utils"
 )
 
+var _ LeafRequestHandler = (*leafsRequestHandler)(nil)
+
 const (
 	// Maximum number of leaves to return in a message.LeafsResponse
 	// This parameter overrides any other Limit specified
@@ -40,22 +42,28 @@ const (
 	keyLength = common.HashLength // length of the keys of the trie to sync
 )
 
-// LeafsRequestHandler is a peer.RequestHandler for types.LeafsRequest
+type LeafRequestHandler interface {
+	OnLeafsRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, leafsRequest message.LeafsRequest) ([]byte, error)
+}
+
+// leafsRequestHandler is a peer.RequestHandler for types.LeafsRequest
 // serving requested trie data
-type LeafsRequestHandler struct {
+type leafsRequestHandler struct {
 	trieDB           *triedb.Database
 	snapshotProvider SnapshotProvider
 	codec            codec.Manager
 	stats            stats.LeafsRequestHandlerStats
 	pool             sync.Pool
+	trieKeyLength    int
 }
 
-func NewLeafsRequestHandler(trieDB *triedb.Database, snapshotProvider SnapshotProvider, codec codec.Manager, syncerStats stats.LeafsRequestHandlerStats) *LeafsRequestHandler {
-	return &LeafsRequestHandler{
+func NewLeafsRequestHandler(trieDB *triedb.Database, trieKeyLength int, snapshotProvider SnapshotProvider, codec codec.Manager, syncerStats stats.LeafsRequestHandlerStats) *leafsRequestHandler {
+	return &leafsRequestHandler{
 		trieDB:           trieDB,
 		snapshotProvider: snapshotProvider,
 		codec:            codec,
 		stats:            syncerStats,
+		trieKeyLength:    trieKeyLength,
 		pool: sync.Pool{
 			New: func() interface{} { return make([][]byte, 0, maxLeavesLimit) },
 		},
@@ -70,9 +78,9 @@ func NewLeafsRequestHandler(trieDB *triedb.Database, snapshotProvider SnapshotPr
 // Specified Limit in message.LeafsRequest is overridden to maxLeavesLimit if it is greater than maxLeavesLimit
 // Expects returned errors to be treated as FATAL
 // Never returns errors
-// Returns nothing if the requested trie root is not found
+// Returns nothing if NodeType is invalid or requested trie root is not found
 // Assumes ctx is active
-func (lrh *LeafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, leafsRequest message.LeafsRequest) ([]byte, error) {
+func (lrh *leafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, leafsRequest message.LeafsRequest) ([]byte, error) {
 	startTime := time.Now()
 	lrh.stats.IncLeafsRequest()
 
@@ -84,13 +92,12 @@ func (lrh *LeafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N
 		lrh.stats.IncInvalidLeafsRequest()
 		return nil, nil
 	}
-	if len(leafsRequest.Start) != 0 && len(leafsRequest.Start) != keyLength ||
-		len(leafsRequest.End) != 0 && len(leafsRequest.End) != keyLength {
-		log.Debug("invalid length for leafs request range, dropping request", "startLen", len(leafsRequest.Start), "endLen", len(leafsRequest.End), "expected", keyLength)
+	if (len(leafsRequest.Start) != 0 && len(leafsRequest.Start) != lrh.trieKeyLength) ||
+		(len(leafsRequest.End) != 0 && len(leafsRequest.End) != lrh.trieKeyLength) {
+		log.Debug("invalid length for leafs request range, dropping request", "startLen", len(leafsRequest.Start), "endLen", len(leafsRequest.End), "expected", lrh.trieKeyLength)
 		lrh.stats.IncInvalidLeafsRequest()
 		return nil, nil
 	}
-
 	// TODO: We should know the state root that accounts correspond to,
 	// as this information will be necessary to access storage tries when
 	// the trie is path based.
@@ -106,7 +113,6 @@ func (lrh *LeafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N
 	if limit > maxLeavesLimit {
 		limit = maxLeavesLimit
 	}
-
 	var leafsResponse message.LeafsResponse
 	// pool response's key/val allocations
 	leafsResponse.Keys = lrh.pool.Get().([][]byte)
@@ -121,12 +127,11 @@ func (lrh *LeafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N
 		lrh.pool.Put(leafsResponse.Keys[:0])
 		lrh.pool.Put(leafsResponse.Vals[:0])
 	}()
-
 	responseBuilder := &responseBuilder{
 		request:   &leafsRequest,
 		response:  &leafsResponse,
 		t:         t,
-		keyLength: keyLength,
+		keyLength: lrh.trieKeyLength,
 		limit:     limit,
 		stats:     lrh.stats,
 	}
@@ -135,7 +140,6 @@ func (lrh *LeafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N
 		responseBuilder.snap = lrh.snapshotProvider.Snapshots()
 	}
 	err = responseBuilder.handleRequest(ctx)
-
 	// ensure metrics are captured properly on all return paths
 	defer func() {
 		lrh.stats.UpdateLeafsRequestProcessingTime(time.Since(startTime))
@@ -152,13 +156,11 @@ func (lrh *LeafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N
 		log.Debug("context err set before any leafs were iterated", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "ctxErr", ctx.Err())
 		return nil, nil
 	}
-
 	responseBytes, err := lrh.codec.Marshal(message.Version, leafsResponse)
 	if err != nil {
 		log.Debug("failed to marshal LeafsResponse, dropping request", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "err", err)
 		return nil, nil
 	}
-
 	log.Debug("handled leafsRequest", "time", time.Since(startTime), "leafs", len(leafsResponse.Keys), "proofLen", len(leafsResponse.ProofVals))
 	return responseBytes, nil
 }
diff --git a/sync/handlers/leafs_request_test.go b/sync/handlers/leafs_request_test.go
index 0bfe48cbae..ced9c49da4 100644
--- a/sync/handlers/leafs_request_test.go
+++ b/sync/handlers/leafs_request_test.go
@@ -76,7 +76,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) {
 		}
 	}
 	snapshotProvider := &TestSnapshotProvider{}
-	leafsHandler := NewLeafsRequestHandler(trieDB, snapshotProvider, message.Codec, mockHandlerStats)
+	leafsHandler := NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, snapshotProvider, message.Codec, mockHandlerStats)
 	snapConfig := snapshot.Config{
 		CacheSize:  64,
 		AsyncBuild: false,
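`NewLeafsRequestHandler` is now parameterized on the key length of the trie it serves, replacing the package-level `keyLength` constant in both the range validation and the response builder. A usage sketch; `message.StateTrieKeyLength` presumably matches the 32-byte hash keys (`common.HashLength`) the old constant enforced:

```go
package main

import (
	"github.com/ava-labs/libevm/core/rawdb"
	"github.com/ava-labs/libevm/triedb"

	"github.com/ava-labs/subnet-evm/plugin/evm/message"
	"github.com/ava-labs/subnet-evm/sync/handlers"
	handlerstats "github.com/ava-labs/subnet-evm/sync/handlers/stats"
)

func main() {
	trieDB := triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)
	_ = handlers.NewLeafsRequestHandler(
		trieDB,
		message.StateTrieKeyLength, // valid length for Start/End range keys
		nil,                        // no SnapshotProvider, as in this PR's tests
		message.Codec,
		handlerstats.NewNoopHandlerStats(),
	)
}
```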
diff --git a/sync/handlers/stats/stats.go b/sync/handlers/stats/stats.go
index d15edd160c..e8aa14841b 100644
--- a/sync/handlers/stats/stats.go
+++ b/sync/handlers/stats/stats.go
@@ -166,7 +166,10 @@ func (h *handlerStats) IncSnapshotReadSuccess() { h.snapshotReadSuccess.Inc(1
 func (h *handlerStats) IncSnapshotSegmentValid()   { h.snapshotSegmentValid.Inc(1) }
 func (h *handlerStats) IncSnapshotSegmentInvalid() { h.snapshotSegmentInvalid.Inc(1) }
 
-func NewHandlerStats(enabled bool) HandlerStats {
+// GetOrRegisterHandlerStats returns a [HandlerStats] to track state sync handler metrics.
+// If `enabled` is false, a no-op implementation is returned.
+// If `enabled` is true, calling this multiple times will return the same registered metrics.
+func GetOrRegisterHandlerStats(enabled bool) HandlerStats {
 	if !enabled {
 		return NewNoopHandlerStats()
 	}
diff --git a/sync/statesync/code_syncer_test.go b/sync/statesync/code_syncer_test.go
index 972c056095..966bd4248a 100644
--- a/sync/statesync/code_syncer_test.go
+++ b/sync/statesync/code_syncer_test.go
@@ -44,7 +44,7 @@ func testCodeSyncer(t *testing.T, test codeSyncerTest) {
 	// Set up mockClient
 	codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, message.Codec, handlerstats.NewNoopHandlerStats())
-	mockClient := statesyncclient.NewMockClient(message.Codec, nil, codeRequestHandler, nil)
+	mockClient := statesyncclient.NewTestClient(message.Codec, nil, codeRequestHandler, nil)
 	mockClient.GetCodeIntercept = test.getCodeIntercept
 
 	clientDB := rawdb.NewMemoryDatabase()
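The rename from `NewHandlerStats` to `GetOrRegisterHandlerStats` encodes the contract in the name: the metrics live in a shared registry, so repeated construction must hand back the already-registered counters rather than registering duplicates. A sketch of what the new doc comment promises:

```go
package main

import handlerstats "github.com/ava-labs/subnet-evm/sync/handlers/stats"

func main() {
	// With enabled=true, both calls return the same registered metrics;
	// with enabled=false, a no-op implementation is returned instead.
	a := handlerstats.GetOrRegisterHandlerStats(true)
	b := handlerstats.GetOrRegisterHandlerStats(true)
	_, _ = a, b // a and b update the same underlying counters
}
```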
diff --git a/sync/statesync/statesynctest/test_sync.go b/sync/statesync/statesynctest/test_sync.go
index e380ca9d60..403670185d 100644
--- a/sync/statesync/statesynctest/test_sync.go
+++ b/sync/statesync/statesynctest/test_sync.go
@@ -89,26 +89,6 @@ func AssertDBConsistency(t testing.TB, root common.Hash, clientDB ethdb.Database
 	assert.Equal(t, trieAccountLeaves, numSnapshotAccounts)
 }
 
-func FillAccountsWithStorage(t *testing.T, r *rand.Rand, serverDB ethdb.Database, serverTrieDB *triedb.Database, root common.Hash, numAccounts int) common.Hash {
-	newRoot, _ := FillAccounts(t, r, serverTrieDB, root, numAccounts, func(t *testing.T, _ int, account types.StateAccount) types.StateAccount {
-		codeBytes := make([]byte, 256)
-		_, err := r.Read(codeBytes)
-		if err != nil {
-			t.Fatalf("error reading random code bytes: %v", err)
-		}
-
-		codeHash := crypto.Keccak256Hash(codeBytes)
-		rawdb.WriteCode(serverDB, codeHash, codeBytes)
-		account.CodeHash = codeHash[:]
-
-		// now create state trie
-		numKeys := 16
-		account.Root, _, _ = GenerateTrie(t, r, serverTrieDB, numKeys, common.HashLength)
-		return account
-	})
-	return newRoot
-}
-
 // FillAccountsWithOverlappingStorage adds [numAccounts] randomly generated accounts to the secure trie at [root]
 // and commits it to [trieDB]. For each 3 accounts created:
 // - One does not have a storage trie,
diff --git a/sync/statesync/sync_test.go b/sync/statesync/sync_test.go
index 63f1ed21a0..bcaf953743 100644
--- a/sync/statesync/sync_test.go
+++ b/sync/statesync/sync_test.go
@@ -21,9 +21,11 @@ import (
 	"github.com/ava-labs/libevm/rlp"
 	"github.com/ava-labs/libevm/trie"
 	"github.com/ava-labs/libevm/triedb"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
 	"github.com/ava-labs/subnet-evm/core/state/snapshot"
+	"github.com/ava-labs/subnet-evm/plugin/evm/customrawdb"
 	"github.com/ava-labs/subnet-evm/plugin/evm/message"
 	"github.com/ava-labs/subnet-evm/sync/handlers"
 	"github.com/ava-labs/subnet-evm/sync/statesync/statesynctest"
@@ -52,9 +54,9 @@ func testSync(t *testing.T, test syncTest) {
 	}
 	r := rand.New(rand.NewSource(1))
 	clientDB, serverDB, serverTrieDB, root := test.prepareForTest(t, r)
-	leafsRequestHandler := handlers.NewLeafsRequestHandler(serverTrieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats())
+	leafsRequestHandler := handlers.NewLeafsRequestHandler(serverTrieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats())
 	codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, message.Codec, handlerstats.NewNoopHandlerStats())
-	mockClient := statesyncclient.NewMockClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil)
+	mockClient := statesyncclient.NewTestClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil)
 	// Set intercept functions for the mock client
 	mockClient.GetLeafsIntercept = test.GetLeafsIntercept
 	mockClient.GetCodeIntercept = test.GetCodeIntercept
@@ -69,14 +71,17 @@ func testSync(t *testing.T, test syncTest) {
 		RequestSize:              1024,
 	})
 	require.NoError(t, err, "failed to create state syncer")
-	// begin sync
-	s.Start(ctx)
+
+	require.NoError(t, s.Start(ctx), "failed to start state syncer")
+
+	waitFor(t, context.Background(), s.Wait, test.expectedError, testSyncTimeout)
+
+	// Only assert database consistency if the sync was expected to succeed.
 	if test.expectedError != nil {
 		return
 	}
 
-	statesynctest.AssertDBConsistency(t, root, clientDB, serverTrieDB, triedb.NewDatabase(clientDB, nil))
+	assertDBConsistency(t, root, clientDB, serverTrieDB, triedb.NewDatabase(clientDB, nil))
 }
 
 // testSyncResumes tests a series of syncTests work as expected, invoking a callback function after each
@@ -144,7 +149,7 @@ func TestSimpleSyncCases(t *testing.T) {
 			prepareForTest: func(t *testing.T, r *rand.Rand) (ethdb.Database, ethdb.Database, *triedb.Database, common.Hash) {
 				serverDB := rawdb.NewMemoryDatabase()
 				serverTrieDB := triedb.NewDatabase(serverDB, nil)
-				root := statesynctest.FillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, numAccounts)
+				root := fillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, numAccounts)
 				return rawdb.NewMemoryDatabase(), serverDB, serverTrieDB, root
 			},
 		},
@@ -186,7 +191,7 @@ func TestSimpleSyncCases(t *testing.T) {
 			prepareForTest: func(t *testing.T, r *rand.Rand) (ethdb.Database, ethdb.Database, *triedb.Database, common.Hash) {
 				serverDB := rawdb.NewMemoryDatabase()
 				serverTrieDB := triedb.NewDatabase(serverDB, nil)
-				root := statesynctest.FillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, numAccountsSmall)
+				root := fillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, numAccountsSmall)
 				return rawdb.NewMemoryDatabase(), serverDB, serverTrieDB, root
 			},
 			GetCodeIntercept: func(_ []common.Hash, _ [][]byte) ([][]byte, error) {
@@ -208,7 +213,7 @@ func TestCancelSync(t *testing.T) {
 	serverDB := rawdb.NewMemoryDatabase()
 	serverTrieDB := triedb.NewDatabase(serverDB, nil)
 	// Create trie with 2000 accounts (more than one leaf request)
-	root := statesynctest.FillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, 2000)
+	root := fillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, 2000)
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	testSync(t, syncTest{
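`testSync` now surfaces `Start`'s error instead of discarding it, and waits with a context that can differ from the one passed to `Start`. A sketch of that calling convention; the `Wait(ctx) error` signature is inferred from the `waitFor(t, context.Background(), s.Wait, ...)` call above:

```go
package main

import (
	"context"
	"fmt"
)

// syncer captures just the two state syncer methods used here.
type syncer interface {
	Start(ctx context.Context) error
	Wait(ctx context.Context) error
}

func runSync(startCtx, waitCtx context.Context, s syncer) error {
	// Surface Start's error instead of dropping it, as the updated test does.
	if err := s.Start(startCtx); err != nil {
		return fmt.Errorf("failed to start state syncer: %w", err)
	}
	// Waiting may use a different context than Start; TestDifferentWaitContext
	// below exercises exactly that.
	if err := s.Wait(waitCtx); err != nil {
		return fmt.Errorf("state sync failed: %w", err)
	}
	return nil
}

func main() {}
```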
@@ -516,20 +521,106 @@ func testSyncerSyncsToNewRoot(t *testing.T, deleteBetweenSyncs func(*testing.T,
 	})
 }
 
+// assertDBConsistency checks [serverTrieDB] and [clientTrieDB] have the same EVM state trie at [root],
+// and that [clientTrieDB.DiskDB] has corresponding account & snapshot values.
+// Also verifies any code referenced by the EVM state is present in [clientTrieDB] and the hash is correct.
+func assertDBConsistency(t testing.TB, root common.Hash, clientDB ethdb.Database, serverTrieDB, clientTrieDB *triedb.Database) {
+	numSnapshotAccounts := 0
+	accountIt := customrawdb.IterateAccountSnapshots(clientDB)
+	defer accountIt.Release()
+	for accountIt.Next() {
+		if !bytes.HasPrefix(accountIt.Key(), rawdb.SnapshotAccountPrefix) || len(accountIt.Key()) != len(rawdb.SnapshotAccountPrefix)+common.HashLength {
+			continue
+		}
+		numSnapshotAccounts++
+	}
+	if err := accountIt.Error(); err != nil {
+		t.Fatal(err)
+	}
+	trieAccountLeaves := 0
+
+	statesynctest.AssertTrieConsistency(t, root, serverTrieDB, clientTrieDB, func(key, val []byte) error {
+		trieAccountLeaves++
+		accHash := common.BytesToHash(key)
+		var acc types.StateAccount
+		if err := rlp.DecodeBytes(val, &acc); err != nil {
+			return err
+		}
+		// check snapshot consistency
+		snapshotVal := rawdb.ReadAccountSnapshot(clientDB, accHash)
+		expectedSnapshotVal := types.SlimAccountRLP(acc)
+		assert.Equal(t, expectedSnapshotVal, snapshotVal)
+
+		// check code consistency
+		if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash[:]) {
+			codeHash := common.BytesToHash(acc.CodeHash)
+			code := rawdb.ReadCode(clientDB, codeHash)
+			actualHash := crypto.Keccak256Hash(code)
+			assert.NotZero(t, len(code))
+			assert.Equal(t, codeHash, actualHash)
+		}
+		if acc.Root == types.EmptyRootHash {
+			return nil
+		}
+
+		storageIt := rawdb.IterateStorageSnapshots(clientDB, accHash)
+		defer storageIt.Release()
+
+		snapshotStorageKeysCount := 0
+		for storageIt.Next() {
+			snapshotStorageKeysCount++
+		}
+
+		storageTrieLeavesCount := 0
+
+		// check storage trie and storage snapshot consistency
+		statesynctest.AssertTrieConsistency(t, acc.Root, serverTrieDB, clientTrieDB, func(key, val []byte) error {
+			storageTrieLeavesCount++
+			snapshotVal := rawdb.ReadStorageSnapshot(clientDB, accHash, common.BytesToHash(key))
+			assert.Equal(t, val, snapshotVal)
+			return nil
+		})
+
+		assert.Equal(t, storageTrieLeavesCount, snapshotStorageKeysCount)
+		return nil
+	})
+
+	// Check that the number of accounts in the snapshot matches the number of leaves in the accounts trie
+	assert.Equal(t, trieAccountLeaves, numSnapshotAccounts)
+}
+
+func fillAccountsWithStorage(t *testing.T, r *rand.Rand, serverDB ethdb.Database, serverTrieDB *triedb.Database, root common.Hash, numAccounts int) common.Hash { //nolint:unparam
+	newRoot, _ := statesynctest.FillAccounts(t, r, serverTrieDB, root, numAccounts, func(_ *testing.T, _ int, account types.StateAccount) types.StateAccount {
+		codeBytes := make([]byte, 256)
+		_, err := r.Read(codeBytes)
+		require.NoError(t, err, "error reading random code bytes")
+
+		codeHash := crypto.Keccak256Hash(codeBytes)
+		rawdb.WriteCode(serverDB, codeHash, codeBytes)
+		account.CodeHash = codeHash[:]
+
+		// now create state trie
+		numKeys := 16
+		account.Root, _, _ = statesynctest.GenerateTrie(t, r, serverTrieDB, numKeys, common.HashLength)
+		return account
+	})
+	return newRoot
+}
+
 func TestDifferentWaitContext(t *testing.T) {
 	r := rand.New(rand.NewSource(1))
 	serverDB := rawdb.NewMemoryDatabase()
 	serverTrieDB := triedb.NewDatabase(serverDB, nil)
 	// Create trie with many accounts to ensure sync takes time
-	root := statesynctest.FillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, 2000)
+	root := fillAccountsWithStorage(t, r, serverDB, serverTrieDB, common.Hash{}, 2000)
 	clientDB := rawdb.NewMemoryDatabase()
 
 	// Track requests to show sync continues after Wait returns
 	var requestCount int64
-	leafsRequestHandler := handlers.NewLeafsRequestHandler(serverTrieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats())
+	leafsRequestHandler := handlers.NewLeafsRequestHandler(serverTrieDB, message.StateTrieKeyLength, nil, message.Codec, handlerstats.NewNoopHandlerStats())
 	codeRequestHandler := handlers.NewCodeRequestHandler(serverDB, message.Codec, handlerstats.NewNoopHandlerStats())
-	mockClient := statesyncclient.NewMockClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil)
+	mockClient := statesyncclient.NewTestClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil)
 
 	// Intercept to track ongoing requests and add delay
 	mockClient.GetLeafsIntercept = func(_ message.LeafsRequest, resp message.LeafsResponse) (message.LeafsResponse, error) {