diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index 5ac66c3d726..a879597f3ca 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -2194,12 +2194,17 @@ func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { return builder.RpcEng, nil }). Component("requester engine", func(node *cmd.NodeConfig) (module.ReadyDoneAware, error) { + fifoStore, err := engine.NewFifoMessageStore(requester.DefaultEntityRequestCacheSize) + if err != nil { + return nil, fmt.Errorf("could not create requester store: %w", err) + } requestEng, err := requester.New( node.Logger.With().Str("entity", "collection").Logger(), node.Metrics.Engine, node.EngineRegistry, node.Me, node.State, + fifoStore, channels.RequestCollections, filter.HasRole[flow.Identity](flow.RoleCollection), func() flow.Entity { return new(flow.Collection) }, diff --git a/cmd/consensus/main.go b/cmd/consensus/main.go index cf22c928e2d..3daf6b686c5 100644 --- a/cmd/consensus/main.go +++ b/cmd/consensus/main.go @@ -30,6 +30,7 @@ import ( "github.com/onflow/flow-go/consensus/hotstuff/verification" "github.com/onflow/flow-go/consensus/hotstuff/votecollector" recovery "github.com/onflow/flow-go/consensus/recovery/protocol" + "github.com/onflow/flow-go/engine" "github.com/onflow/flow-go/engine/common/requester" synceng "github.com/onflow/flow-go/engine/common/synchronization" "github.com/onflow/flow-go/engine/consensus/approvals/tracker" @@ -487,12 +488,17 @@ func main() { return e, err }). Component("matching engine", func(node *cmd.NodeConfig) (module.ReadyDoneAware, error) { + fifoStore, err := engine.NewFifoMessageStore(requester.DefaultEntityRequestCacheSize) + if err != nil { + return nil, fmt.Errorf("could not create requester store: %w", err) + } receiptRequester, err = requester.New( node.Logger.With().Str("entity", "receipt").Logger(), node.Metrics.Engine, node.EngineRegistry, node.Me, node.State, + fifoStore, channels.RequestReceiptsByBlockID, filter.HasRole[flow.Identity](flow.RoleExecution), func() flow.Entity { return new(flow.ExecutionReceipt) }, diff --git a/cmd/execution_builder.go b/cmd/execution_builder.go index e79a8a4aea8..9dc3d2557a0 100644 --- a/cmd/execution_builder.go +++ b/cmd/execution_builder.go @@ -1102,7 +1102,12 @@ func (exeNode *ExecutionNode) LoadIngestionEngine( colFetcher = accessFetcher exeNode.collectionRequester = accessFetcher } else { + fifoStore, err := engine.NewFifoMessageStore(requester.DefaultEntityRequestCacheSize) + if err != nil { + return nil, fmt.Errorf("could not create requester store: %w", err) + } reqEng, err := requester.New(node.Logger.With().Str("entity", "collection").Logger(), node.Metrics.Engine, node.EngineRegistry, node.Me, node.State, + fifoStore, channels.RequestCollections, filter.Any, func() flow.Entity { return new(flow.Collection) }, diff --git a/engine/common/requester/engine.go b/engine/common/requester/engine.go index 1354425b1ed..67171312840 100644 --- a/engine/common/requester/engine.go +++ b/engine/common/requester/engine.go @@ -3,6 +3,7 @@ package requester import ( "fmt" "math" + "sync" "time" "github.com/rs/zerolog" @@ -14,6 +15,8 @@ import ( "github.com/onflow/flow-go/model/flow/filter" "github.com/onflow/flow-go/model/messages" "github.com/onflow/flow-go/module" + "github.com/onflow/flow-go/module/component" + "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/metrics" "github.com/onflow/flow-go/network" 
"github.com/onflow/flow-go/network/channels" @@ -22,6 +25,10 @@ import ( "github.com/onflow/flow-go/utils/rand" ) +// DefaultEntityRequestCacheSize is the default max message queue size for the provider engine. +// This equates to ~5GB of memory usage with a full queue (10M*500) +const DefaultEntityRequestCacheSize = 500 + // HandleFunc is a function provided to the requester engine to handle an entity // once it has been retrieved from a provider. The function should be non-blocking // and errors should be handled internally within the function. @@ -35,29 +42,46 @@ type CreateFunc func() flow.Entity // on the flow network. It is the `request` part of the request-reply // pattern provided by the pair of generic exchange engines. type Engine struct { - unit *engine.Unit - log zerolog.Logger - cfg Config - metrics module.EngineMetrics - me module.Local - state protocol.State - con network.Conduit - channel channels.Channel - selector flow.IdentityFilter[flow.Identity] - create CreateFunc - handle HandleFunc - - // changing the following state variables must be guarded by unit.Lock() + *component.ComponentManager + mu sync.Mutex + log zerolog.Logger + cfg Config + metrics module.EngineMetrics + me module.Local + state protocol.State + con network.Conduit + channel channels.Channel + requestHandler *engine.MessageHandler + requestQueue engine.MessageStore + selector flow.IdentityFilter[flow.Identity] + create CreateFunc + handle HandleFunc + + // changing the following state variables must be guarded by mu.Lock() items map[flow.Identifier]*Item requests map[uint64]*messages.EntityRequest forcedDispatchOngoing *atomic.Bool // to ensure only trigger dispatching logic once at any time } +var _ component.Component = (*Engine)(nil) +var _ network.MessageProcessor = (*Engine)(nil) + // New creates a new requester engine, operating on the provided network channel, and requesting entities from a node // within the set obtained by applying the provided selector filter. The options allow customization of the parameters // related to the batch and retry logic. -func New(log zerolog.Logger, metrics module.EngineMetrics, net network.EngineRegistry, me module.Local, state protocol.State, - channel channels.Channel, selector flow.IdentityFilter[flow.Identity], create CreateFunc, options ...OptionFunc) (*Engine, error) { +// No errors are expected during normal operations. +func New( + log zerolog.Logger, + metrics module.EngineMetrics, + net network.EngineRegistry, + me module.Local, + state protocol.State, + requestQueue engine.MessageStore, + channel channels.Channel, + selector flow.IdentityFilter[flow.Identity], + create CreateFunc, + options ...OptionFunc, +) (*Engine, error) { // initialize the default config cfg := Config{ @@ -102,14 +126,29 @@ func New(log zerolog.Logger, metrics module.EngineMetrics, net network.EngineReg ) } + handler := engine.NewMessageHandler( + log, + engine.NewNotifier(), + engine.Pattern{ + // Match is called on every new message coming to this engine. + // Provider engine only expects *flow.EntityResponse. + // Other message types are discarded by Match. 
+			Match: func(message *engine.Message) bool {
+				_, ok := message.Payload.(*flow.EntityResponse)
+				return ok
+			},
+			Store: requestQueue,
+		})
+
 	// initialize the propagation engine with its dependencies
 	e := &Engine{
-		unit:    engine.NewUnit(),
 		log:     log.With().Str("engine", "requester").Logger(),
 		cfg:     cfg,
 		metrics: metrics,
 		me:      me,
 		state:   state,
+		requestHandler: handler,
+		requestQueue:   requestQueue,
 		channel:  channel,
 		selector: selector,
 		create:   create,
@@ -120,12 +159,17 @@ func New(log zerolog.Logger, metrics module.EngineMetrics, net network.EngineReg
 	}

 	// register the engine with the network layer and store the conduit
-	con, err := net.Register(channels.Channel(channel), e)
+	con, err := net.Register(channel, e)
 	if err != nil {
 		return nil, fmt.Errorf("could not register engine: %w", err)
 	}
 	e.con = con

+	e.ComponentManager = component.NewComponentManagerBuilder().
+		AddWorker(e.poll).
+		AddWorker(e.processQueuedRequestsShovellerWorker).
+		Build()
+
 	return e, nil
 }

@@ -138,58 +182,89 @@ func (e *Engine) WithHandle(handle HandleFunc) {
 	e.handle = handle
 }

-// Ready returns a ready channel that is closed once the engine has fully
-// started. For consensus engine, this is true once the underlying consensus
-// algorithm has started.
-func (e *Engine) Ready() <-chan struct{} {
-	if e.handle == nil {
-		panic("must initialize requester engine with handler")
+// Process processes the given message from the node with the given origin ID in
+// a blocking manner. It returns the potential processing error when done.
+func (e *Engine) Process(channel channels.Channel, originID flow.Identifier, event interface{}) error {
+	select {
+	case <-e.ShutdownSignal():
+		e.log.Warn().
+			Hex("origin_id", logging.ID(originID)).
+			Msg("received message after shutdown")
+		return nil
+	default:
 	}
-	e.unit.Launch(e.poll)
-	return e.unit.Ready()
-}

-// Done returns a done channel that is closed once the engine has fully stopped.
-// For the consensus engine, we wait for hotstuff to finish.
-func (e *Engine) Done() <-chan struct{} {
-	return e.unit.Done()
+	e.metrics.MessageReceived(e.channel.String(), metrics.MessageEntityResponse)
+	err := e.requestHandler.Process(originID, event)
+	if err != nil {
+		if engine.IsIncompatibleInputTypeError(err) {
+			e.log.Warn().
+				Hex("origin_id", logging.ID(originID)).
+				Str("channel", channel.String()).
+				Str("event", fmt.Sprintf("%+v", event)).
+				Bool(logging.KeySuspicious, true).
+				Msg("received unsupported message type")
+			return nil
+		}
+		return fmt.Errorf("unexpected error while processing engine event: %w", err)
+	}
+	return nil
 }

-// SubmitLocal submits an message originating on the local node.
-func (e *Engine) SubmitLocal(message interface{}) {
-	e.unit.Launch(func() {
-		err := e.process(e.me.NodeID(), message)
-		if err != nil {
-			engine.LogError(e.log, err)
+// processQueuedRequestsShovellerWorker runs as a dedicated worker for [component.ComponentManager].
+// It waits for notifications of available work and processes queued entity responses.
+func (e *Engine) processQueuedRequestsShovellerWorker(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) {
+	ready()
+
+	e.log.Debug().Msg("process entity request shoveller worker started")
+
+	for {
+		select {
+		case <-e.requestHandler.GetNotifier():
+			// there is at least one message in the queue, so we try to process it.
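+			// processAvailableMessages drains the queue until it is empty
+			// or the context is cancelled.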
+			e.processAvailableMessages(ctx)
+		case <-ctx.Done():
+			return
 		}
-	})
+	}
 }

-// Submit submits the given message from the node with the given origin ID
-// for processing in a non-blocking manner. It returns instantly and logs
-// a potential processing error internally when done.
-func (e *Engine) Submit(channel channels.Channel, originID flow.Identifier, message interface{}) {
-	e.unit.Launch(func() {
-		err := e.Process(channel, originID, message)
-		if err != nil {
-			engine.LogError(e.log, err)
+// processAvailableMessages is called when there are messages in the queue that are ready to be processed.
+// All unexpected errors are reported to the SignalerContext.
+func (e *Engine) processAvailableMessages(ctx irrecoverable.SignalerContext) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		default:
 		}
-	})
-}

-// ProcessLocal processes an message originating on the local node.
-func (e *Engine) ProcessLocal(message interface{}) error {
-	return e.unit.Do(func() error {
-		return e.process(e.me.NodeID(), message)
-	})
-}
+		msg, ok := e.requestQueue.Get()
+		if !ok {
+			// no more messages, return
+			return
+		}

-// Process processes the given message from the node with the given origin ID in
-// a blocking manner. It returns the potential processing error when done.
-func (e *Engine) Process(channel channels.Channel, originID flow.Identifier, message interface{}) error {
-	return e.unit.Do(func() error {
-		return e.process(originID, message)
-	})
+		res, ok := msg.Payload.(*flow.EntityResponse)
+		if !ok {
+			// should never happen, as we only put *flow.EntityResponse in the queue;
+			// if it does happen, it means there is a bug in the queue implementation.
+			ctx.Throw(fmt.Errorf("invalid message type in entity response queue: %T", msg.Payload))
+		}
+
+		err := e.onEntityResponse(msg.OriginID, res)
+		if err != nil {
+			if engine.IsInvalidInputError(err) {
+				e.log.Err(err).
+					Str("origin_id", msg.OriginID.String()).
+					Uint64("nonce", res.Nonce).
+					Bool(logging.KeySuspicious, true).
+					Msg("invalid response detected")
+				continue
+			}
+			ctx.Throw(err)
+		}
+	}
 }

 // EntityByID adds an entity to the list of entities to be requested from the
@@ -214,9 +289,11 @@ func (e *Engine) Query(key flow.Identifier, selector flow.IdentityFilter[flow.Id
 	e.addEntityRequest(key, selector, false)
 }

+// addEntityRequest adds a request to the in-memory store of pending items to be requested.
+// Concurrency safe.
 func (e *Engine) addEntityRequest(entityID flow.Identifier, selector flow.IdentityFilter[flow.Identity], checkIntegrity bool) {
-	e.unit.Lock()
-	defer e.unit.Unlock()
+	e.mu.Lock()
+	defer e.mu.Unlock()

 	// check if we already have an item for this entity
 	_, duplicate := e.items[entityID]
@@ -245,7 +322,7 @@ func (e *Engine) Force() {
 	}

-	// using Launch to ensure the caller won't be blocked
-	e.unit.Launch(func() {
+	// using a dedicated goroutine to ensure the caller won't be blocked
+	go func() {
 		// using atomic bool to ensure there is at most one caller would trigger dispatching requests
 		if e.forcedDispatchOngoing.CompareAndSwap(false, true) {
 			count := uint(0)
@@ -263,35 +340,38 @@
 			}
 			e.forcedDispatchOngoing.Store(false)
 		}
-	})
+	}()
 }

-func (e *Engine) poll() {
-	ticker := time.NewTicker(e.cfg.BatchInterval)
+// poll runs as a dedicated worker for [component.ComponentManager]. It dispatches pending requests on a timer.
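+// On startup, it verifies that a handler has been configured via WithHandle and throws an
+// irrecoverable error otherwise.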
+func (e *Engine) poll(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) {
+	if e.handle == nil {
+		ctx.Throw(fmt.Errorf("must initialize requester engine with handler"))
+	}
+
+	ready()

-PollLoop:
+	ticker := time.NewTicker(e.cfg.BatchInterval)
+	defer ticker.Stop()
 	for {
 		select {
-		case <-e.unit.Quit():
-			break PollLoop
+		case <-e.ShutdownSignal():
+			return
 		case <-ticker.C:
 			if e.forcedDispatchOngoing.Load() {
-				return
+				continue
 			}
 			dispatched, err := e.dispatchRequest()
 			if err != nil {
-				e.log.Error().Err(err).Msg("could not dispatch requests")
-				continue PollLoop
+				ctx.Throw(err)
 			}
 			if dispatched {
 				e.log.Debug().Uint("requests", 1).Msg("regular request dispatch")
 			}
 		}
 	}
-
-	ticker.Stop()
 }

 // dispatchRequest dispatches a subset of requests (selection based on internal heuristic).
@@ -299,10 +379,10 @@ PollLoop:
 // if and only if there is something to request. In other words it cannot happen that
 // `dispatchRequest` sends no request, but there is something to be requested.
 // The boolean return value indicates whether a request was dispatched at all.
+// No errors are expected during normal operations.
 func (e *Engine) dispatchRequest() (bool, error) {
-
-	e.unit.Lock()
-	defer e.unit.Unlock()
+	e.mu.Lock()
+	defer e.mu.Unlock()

 	e.log.Debug().Int("num_entities", len(e.items)).Msg("selecting entities")

@@ -352,7 +432,8 @@ func (e *Engine) dispatchRequest() (bool, error) {
 	if providerID == flow.ZeroID {
 		filteredProviders := providers.Filter(item.ExtraSelector)
 		if len(filteredProviders) == 0 {
-			return false, fmt.Errorf("no valid providers available for item %s, total providers: %v", entityID.String(), len(providers))
+			e.log.Error().Msgf("could not dispatch requests: no valid providers available for item %s, total providers: %v", entityID.String(), len(providers))
+			return false, nil
 		}
 		// ramdonly select a provider from the filtered set
 		// to send as many item requests as possible.
@@ -417,7 +498,8 @@ func (e *Engine) dispatchRequest() (bool, error) {

 	err = e.con.Unicast(req, providerID)
 	if err != nil {
-		return true, fmt.Errorf("could not send request for entities %v: %w", logging.IDs(entityIDs), err)
+		e.log.Error().Err(err).Msgf("could not dispatch requests: could not send request for entities %v", logging.IDs(entityIDs))
+		return false, nil
 	}
 	e.requests[req.Nonce] = req

@@ -429,9 +511,9 @@ func (e *Engine) dispatchRequest() (bool, error) {
 	go func() {
 		<-time.After(e.cfg.RetryInitial)
-		e.unit.Lock()
-		defer e.unit.Unlock()
+		e.mu.Lock()
 		delete(e.requests, req.Nonce)
+		e.mu.Unlock()
 	}()

 	if e.log.Debug().Enabled() {
@@ -447,27 +529,19 @@ func (e *Engine) dispatchRequest() (bool, error) {
 	return true, nil
 }

-// process processes events for the propagation engine on the consensus node.
-func (e *Engine) process(originID flow.Identifier, message interface{}) error {
-
-	e.metrics.MessageReceived(e.channel.String(), metrics.MessageEntityResponse)
-	defer e.metrics.MessageHandled(e.channel.String(), metrics.MessageEntityResponse)
-
-	switch msg := message.(type) {
-	case *flow.EntityResponse:
-		return e.onEntityResponse(originID, msg)
-	default:
-		return engine.NewInvalidInputErrorf("invalid message type (%T)", message)
-	}
-}
-
+// onEntityResponse handles a response to a request originally made by this engine.
+// For each successful response, this function spawns a dedicated goroutine to perform handling of the parsed response.
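+// The handle callback therefore runs concurrently; per the HandleFunc contract it must be
+// non-blocking and handle errors internally.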
+// Because we only process responses to requests that we previously made, an adversary cannot
+// force this function to spawn an arbitrary number of goroutines.
+// Expected errors during normal operations:
+// - [engine.InvalidInputError] if the provided response is malformed
 func (e *Engine) onEntityResponse(originID flow.Identifier, res *flow.EntityResponse) error {
+	defer e.metrics.MessageHandled(e.channel.String(), metrics.MessageEntityResponse)

 	lg := e.log.With().Str("origin_id", originID.String()).Uint64("nonce", res.Nonce).Logger()

 	lg.Debug().Strs("entity_ids", flow.IdentifierList(res.EntityIDs).Strings()).Msg("entity response received")

 	if e.cfg.ValidateStaking {
-
 		// check that the response comes from a valid provider
 		providers, err := e.state.Final().Identities(filter.And(
 			e.selector,
@@ -489,8 +563,8 @@ func (e *Engine) onEntityResponse(originID flow.Identifier, res *flow.EntityResp
 			Msg("onEntityResponse entries received")
 	}

-	e.unit.Lock()
-	defer e.unit.Unlock()
+	e.mu.Lock()
+	defer e.mu.Unlock()

 	// build a list of needed entities; if not available, process anyway,
 	// but in that case we can't re-queue missing items
@@ -526,7 +600,7 @@ func (e *Engine) onEntityResponse(originID flow.Identifier, res *flow.EntityResp
 		entity := e.create()
 		err := msgpack.Unmarshal(blob, &entity)
 		if err != nil {
-			return fmt.Errorf("could not decode entity: %w", err)
+			return engine.NewInvalidInputErrorf("could not decode entity: %s", err.Error())
 		}

 		if item.checkIntegrity {
diff --git a/engine/common/requester/engine_test.go b/engine/common/requester/engine_test.go
index e10555e19ba..6fd2b58c52c 100644
--- a/engine/common/requester/engine_test.go
+++ b/engine/common/requester/engine_test.go
@@ -9,6 +9,7 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/require"
+	"github.com/stretchr/testify/suite"
 	"github.com/vmihailenco/msgpack/v4"
 	"go.uber.org/atomic"

@@ -18,51 +19,86 @@ import (
 	"github.com/onflow/flow-go/model/flow"
 	"github.com/onflow/flow-go/model/flow/filter"
 	"github.com/onflow/flow-go/model/messages"
+	"github.com/onflow/flow-go/module/mempool/queue"
 	"github.com/onflow/flow-go/module/metrics"
 	mocknetwork "github.com/onflow/flow-go/network/mock"
 	protocol "github.com/onflow/flow-go/state/protocol/mock"
 	"github.com/onflow/flow-go/utils/unittest"
 )

-func TestEntityByID(t *testing.T) {
+func TestRequesterEngine(t *testing.T) {
+	suite.Run(t, new(RequesterEngineSuite))
+}

-	request := Engine{
-		unit:  engine.NewUnit(),
-		items: make(map[flow.Identifier]*Item),
-	}
+// RequesterEngineSuite is a test suite for the requester engine that holds the minimal state needed for testing.
+type RequesterEngineSuite struct { + suite.Suite + con *mocknetwork.Conduit + final *protocol.Snapshot + + engine *Engine +} + +func (s *RequesterEngineSuite) SetupTest() { + s.final = protocol.NewSnapshot(s.T()) + + state := protocol.NewState(s.T()) + state.On("Final").Return(s.final).Maybe() + + me := module.NewLocal(s.T()) + localID := unittest.IdentifierFixture() + me.On("NodeID").Return(localID).Maybe() + + s.con = mocknetwork.NewConduit(s.T()) + + network := mocknetwork.NewEngineRegistry(s.T()) + network.On("Register", mock.Anything, mock.Anything).Return(s.con, nil) + requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector()) + var err error + s.engine, err = New( + zerolog.Nop(), + metrics.NewNoopCollector(), + network, + me, + state, + requestQueue, + "", + filter.Any, + func() flow.Entity { return &flow.Collection{} }, + ) + require.NoError(s.T(), err) +} + +func (s *RequesterEngineSuite) TestEntityByID() { now := time.Now().UTC() entityID := unittest.IdentifierFixture() selector := filter.Any - request.EntityByID(entityID, selector) + s.engine.EntityByID(entityID, selector) - assert.Len(t, request.items, 1) - item, contains := request.items[entityID] - if assert.True(t, contains) { - assert.Equal(t, item.EntityID, entityID) - assert.Equal(t, item.NumAttempts, uint(0)) + assert.Len(s.T(), s.engine.items, 1) + item, contains := s.engine.items[entityID] + if assert.True(s.T(), contains) { + assert.Equal(s.T(), item.EntityID, entityID) + assert.Equal(s.T(), item.NumAttempts, uint(0)) cutoff := item.LastRequested.Add(item.RetryAfter) - assert.True(t, cutoff.Before(now)) // make sure we push out immediately + assert.True(s.T(), cutoff.Before(now)) // make sure we push out immediately } } -func TestDispatchRequestVarious(t *testing.T) { +func (s *RequesterEngineSuite) TestDispatchRequestVarious() { identities := unittest.IdentityListFixture(16) targetID := identities[0].NodeID - final := &protocol.Snapshot{} - final.On("Identities", mock.Anything).Return( + s.final.On("Identities", mock.Anything).Return( func(selector flow.IdentityFilter[flow.Identity]) flow.IdentityList { return identities.Filter(selector) }, nil, ) - state := &protocol.State{} - state.On("Final").Return(final) - cfg := Config{ BatchInterval: 200 * time.Millisecond, BatchThreshold: 999, @@ -112,67 +148,52 @@ func TestDispatchRequestVarious(t *testing.T) { items[triedAnciently.EntityID] = triedAnciently items[triedRecently.EntityID] = triedRecently items[triedTwice.EntityID] = triedTwice + s.engine.cfg = cfg + s.engine.items = items + s.engine.selector = filter.HasNodeID[flow.Identity](targetID) var nonce uint64 - con := &mocknetwork.Conduit{} - con.On("Unicast", mock.Anything, mock.Anything).Run( + s.con.On("Unicast", mock.Anything, mock.Anything).Run( func(args mock.Arguments) { request := args.Get(0).(*messages.EntityRequest) originID := args.Get(1).(flow.Identifier) nonce = request.Nonce - assert.Equal(t, originID, targetID) - assert.ElementsMatch(t, request.EntityIDs, []flow.Identifier{justAdded.EntityID, triedAnciently.EntityID}) + assert.Equal(s.T(), originID, targetID) + assert.ElementsMatch(s.T(), request.EntityIDs, []flow.Identifier{justAdded.EntityID, triedAnciently.EntityID}) }, - ).Return(nil) - - request := Engine{ - unit: engine.NewUnit(), - metrics: metrics.NewNoopCollector(), - cfg: cfg, - state: state, - con: con, - items: items, - requests: make(map[uint64]*messages.EntityRequest), - selector: filter.HasNodeID[flow.Identity](targetID), - } - dispatched, err := 
request.dispatchRequest() - require.NoError(t, err) - require.True(t, dispatched) + ).Return(nil).Once() - con.AssertExpectations(t) + dispatched, err := s.engine.dispatchRequest() + require.NoError(s.T(), err) + require.True(s.T(), dispatched) - request.unit.Lock() - assert.Contains(t, request.requests, nonce) - request.unit.Unlock() + s.engine.mu.Lock() + assert.Contains(s.T(), s.engine.requests, nonce) + s.engine.mu.Unlock() // TODO: racy/slow test time.Sleep(2 * cfg.RetryInitial) - request.unit.Lock() - assert.NotContains(t, request.requests, nonce) - request.unit.Unlock() + s.engine.mu.Lock() + assert.NotContains(s.T(), s.engine.requests, nonce) + s.engine.mu.Unlock() } -func TestDispatchRequestBatchSize(t *testing.T) { +func (s *RequesterEngineSuite) TestDispatchRequestBatchSize() { batchLimit := uint(16) totalItems := uint(99) identities := unittest.IdentityListFixture(16) - - final := &protocol.Snapshot{} - final.On("Identities", mock.Anything).Return( + s.final.On("Identities", mock.Anything).Return( func(selector flow.IdentityFilter[flow.Identity]) flow.IdentityList { return identities.Filter(selector) }, nil, ) - state := &protocol.State{} - state.On("Final").Return(final) - - cfg := Config{ + s.engine.cfg = Config{ BatchInterval: 24 * time.Hour, BatchThreshold: batchLimit, RetryInitial: 24 * time.Hour, @@ -182,59 +203,41 @@ func TestDispatchRequestBatchSize(t *testing.T) { } // item that has just been added, should be included - items := make(map[flow.Identifier]*Item) for i := uint(0); i < totalItems; i++ { item := &Item{ EntityID: unittest.IdentifierFixture(), NumAttempts: 0, LastRequested: time.Time{}, - RetryAfter: cfg.RetryInitial, + RetryAfter: s.engine.cfg.RetryInitial, ExtraSelector: filter.Any, } - items[item.EntityID] = item + s.engine.items[item.EntityID] = item } - con := &mocknetwork.Conduit{} - con.On("Unicast", mock.Anything, mock.Anything).Run( + s.con.On("Unicast", mock.Anything, mock.Anything).Run( func(args mock.Arguments) { request := args.Get(0).(*messages.EntityRequest) - assert.Len(t, request.EntityIDs, int(batchLimit)) + assert.Len(s.T(), request.EntityIDs, int(batchLimit)) }, - ).Return(nil) - - request := Engine{ - unit: engine.NewUnit(), - metrics: metrics.NewNoopCollector(), - cfg: cfg, - state: state, - con: con, - items: items, - requests: make(map[uint64]*messages.EntityRequest), - selector: filter.Any, - } - dispatched, err := request.dispatchRequest() - require.NoError(t, err) - require.True(t, dispatched) + ).Return(nil).Once() - con.AssertExpectations(t) + dispatched, err := s.engine.dispatchRequest() + require.NoError(s.T(), err) + require.True(s.T(), dispatched) } -func TestOnEntityResponseValid(t *testing.T) { +func (s *RequesterEngineSuite) TestOnEntityResponseValid() { identities := unittest.IdentityListFixture(16) targetID := identities[0].NodeID - final := &protocol.Snapshot{} - final.On("Identities", mock.Anything).Return( + s.final.On("Identities", mock.Anything).Return( func(selector flow.IdentityFilter[flow.Identity]) flow.IdentityList { return identities.Filter(selector) }, nil, ) - state := &protocol.State{} - state.On("Final").Return(final) - nonce := rand.Uint64() wanted1 := unittest.CollectionFixture(1) @@ -277,62 +280,49 @@ func TestOnEntityResponseValid(t *testing.T) { done := make(chan struct{}) called := *atomic.NewUint64(0) - request := Engine{ - unit: engine.NewUnit(), - metrics: metrics.NewNoopCollector(), - state: state, - items: make(map[flow.Identifier]*Item), - requests: make(map[uint64]*messages.EntityRequest), - 
selector: filter.HasNodeID[flow.Identity](targetID), - create: func() flow.Entity { return &flow.Collection{} }, - handle: func(flow.Identifier, flow.Entity) { - if called.Inc() >= 2 { - close(done) - } - }, - } + s.engine.WithHandle(func(flow.Identifier, flow.Entity) { + if called.Inc() >= 2 { + close(done) + } + }) - request.items[iwanted1.EntityID] = iwanted1 - request.items[iwanted2.EntityID] = iwanted2 - request.items[iunavailable.EntityID] = iunavailable + s.engine.items[iwanted1.EntityID] = iwanted1 + s.engine.items[iwanted2.EntityID] = iwanted2 + s.engine.items[iunavailable.EntityID] = iunavailable - request.requests[req.Nonce] = req + s.engine.requests[req.Nonce] = req - err := request.onEntityResponse(targetID, res) - assert.NoError(t, err) + err := s.engine.onEntityResponse(targetID, res) + assert.NoError(s.T(), err) // check that the request was removed - assert.NotContains(t, request.requests, nonce) + assert.NotContains(s.T(), s.engine.requests, nonce) // check that the provided items were removed - assert.NotContains(t, request.items, wanted1.ID()) - assert.NotContains(t, request.items, wanted2.ID()) + assert.NotContains(s.T(), s.engine.items, wanted1.ID()) + assert.NotContains(s.T(), s.engine.items, wanted2.ID()) // check that the missing item is still there - assert.Contains(t, request.items, unavailable.ID()) + assert.Contains(s.T(), s.engine.items, unavailable.ID()) // make sure we processed two items - unittest.AssertClosesBefore(t, done, time.Second) + unittest.AssertClosesBefore(s.T(), done, time.Second) // check that the missing items timestamp was reset - assert.Equal(t, iunavailable.LastRequested, time.Time{}) + assert.Equal(s.T(), iunavailable.LastRequested, time.Time{}) } -func TestOnEntityIntegrityCheck(t *testing.T) { +func (s *RequesterEngineSuite) TestOnEntityIntegrityCheck() { identities := unittest.IdentityListFixture(16) targetID := identities[0].NodeID - final := &protocol.Snapshot{} - final.On("Identities", mock.Anything).Return( + s.final.On("Identities", mock.Anything).Return( func(selector flow.IdentityFilter[flow.Identity]) flow.IdentityList { return identities.Filter(selector) }, nil, ) - state := &protocol.State{} - state.On("Final").Return(final) - nonce := rand.Uint64() wanted := unittest.CollectionFixture(1) @@ -347,7 +337,7 @@ func TestOnEntityIntegrityCheck(t *testing.T) { checkIntegrity: true, } - assert.NotEqual(t, wanted, wanted2) + assert.NotEqual(s.T(), wanted, wanted2) // prepare payload from different entity bwanted, _ := msgpack.Marshal(wanted2) @@ -364,63 +354,44 @@ func TestOnEntityIntegrityCheck(t *testing.T) { } called := make(chan struct{}) - request := Engine{ - unit: engine.NewUnit(), - metrics: metrics.NewNoopCollector(), - state: state, - items: make(map[flow.Identifier]*Item), - requests: make(map[uint64]*messages.EntityRequest), - selector: filter.HasNodeID[flow.Identity](targetID), - create: func() flow.Entity { return &flow.Collection{} }, - handle: func(flow.Identifier, flow.Entity) { close(called) }, - } + s.engine.WithHandle(func(flow.Identifier, flow.Entity) { close(called) }) - request.items[iwanted.EntityID] = iwanted + s.engine.items[iwanted.EntityID] = iwanted - request.requests[req.Nonce] = req + s.engine.requests[req.Nonce] = req - err := request.onEntityResponse(targetID, res) - assert.NoError(t, err) + err := s.engine.onEntityResponse(targetID, res) + assert.NoError(s.T(), err) // check that the request was removed - assert.NotContains(t, request.requests, nonce) + assert.NotContains(s.T(), s.engine.requests, 
nonce) // check that the provided item wasn't removed - assert.Contains(t, request.items, wanted.ID()) + assert.Contains(s.T(), s.engine.items, wanted.ID()) iwanted.checkIntegrity = false - request.items[iwanted.EntityID] = iwanted - request.requests[req.Nonce] = req + s.engine.items[iwanted.EntityID] = iwanted + s.engine.requests[req.Nonce] = req - err = request.onEntityResponse(targetID, res) - assert.NoError(t, err) + err = s.engine.onEntityResponse(targetID, res) + assert.NoError(s.T(), err) // make sure we process item without checking integrity - unittest.AssertClosesBefore(t, called, time.Second) + unittest.AssertClosesBefore(s.T(), called, time.Second) } // Verify that the origin should not be checked when ValidateStaking config is set to false -func TestOriginValidation(t *testing.T) { +func (s *RequesterEngineSuite) TestOriginValidation() { identities := unittest.IdentityListFixture(16) targetID := identities[0].NodeID - wrongID := identities[1].NodeID - meID := identities[3].NodeID + wrongID := unittest.IdentifierFixture() - final := &protocol.Snapshot{} - final.On("Identities", mock.Anything).Return( + s.final.On("Identities", mock.Anything).Return( func(selector flow.IdentityFilter[flow.Identity]) flow.IdentityList { return identities.Filter(selector) }, nil, ) - - state := &protocol.State{} - state.On("Final").Return(final) - - me := &module.Local{} - - me.On("NodeID").Return(meID) - nonce := rand.Uint64() wanted := unittest.CollectionFixture(1) @@ -451,38 +422,26 @@ func TestOriginValidation(t *testing.T) { network := &mocknetwork.EngineRegistry{} network.On("Register", mock.Anything, mock.Anything).Return(nil, nil) - e, err := New( - zerolog.Nop(), - metrics.NewNoopCollector(), - network, - me, - state, - "", - filter.HasNodeID[flow.Identity](targetID), - func() flow.Entity { return &flow.Collection{} }, - ) - assert.NoError(t, err) - called := make(chan struct{}) - e.WithHandle(func(origin flow.Identifier, _ flow.Entity) { + s.engine.WithHandle(func(origin flow.Identifier, _ flow.Entity) { // we expect wrong origin to propagate here with validation disabled - assert.Equal(t, wrongID, origin) + assert.Equal(s.T(), wrongID, origin) close(called) }) - e.items[iwanted.EntityID] = iwanted - e.requests[req.Nonce] = req + s.engine.items[iwanted.EntityID] = iwanted + s.engine.requests[req.Nonce] = req - err = e.onEntityResponse(wrongID, res) - assert.Error(t, err) - assert.IsType(t, engine.InvalidInputError{}, err) + err := s.engine.onEntityResponse(wrongID, res) + assert.Error(s.T(), err) + assert.IsType(s.T(), engine.InvalidInputError{}, err) - e.cfg.ValidateStaking = false + s.engine.cfg.ValidateStaking = false - err = e.onEntityResponse(wrongID, res) - assert.NoError(t, err) + err = s.engine.onEntityResponse(wrongID, res) + assert.NoError(s.T(), err) // handler are called async, but this should be extremely quick - unittest.AssertClosesBefore(t, called, time.Second) + unittest.AssertClosesBefore(s.T(), called, time.Second) } diff --git a/engine/execution/ingestion/machine.go b/engine/execution/ingestion/machine.go index 194c12b8fea..3074989a65a 100644 --- a/engine/execution/ingestion/machine.go +++ b/engine/execution/ingestion/machine.go @@ -35,6 +35,7 @@ type Machine struct { type CollectionRequester interface { module.ReadyDoneAware + module.Startable WithHandle(requester.HandleFunc) } @@ -42,7 +43,6 @@ func NewMachine( logger zerolog.Logger, protocolEvents *events.Distributor, collectionRequester CollectionRequester, - collectionFetcher CollectionFetcher, headers 
storage.Headers, blocks storage.Blocks, diff --git a/engine/fifoqueue.go b/engine/fifoqueue.go index 459e5951a78..ef45ce5a31d 100644 --- a/engine/fifoqueue.go +++ b/engine/fifoqueue.go @@ -9,6 +9,16 @@ type FifoMessageStore struct { *fifoqueue.FifoQueue } +// NewFifoMessageStore creates a FifoMessageStore backed by a fifoqueue.FifoQueue. +// No errors are expected during normal operations. +func NewFifoMessageStore(maxCapacity int) (*FifoMessageStore, error) { + queue, err := fifoqueue.NewFifoQueue(maxCapacity) + if err != nil { + return nil, err + } + return &FifoMessageStore{FifoQueue: queue}, nil +} + func (s *FifoMessageStore) Put(msg *Message) bool { return s.Push(msg) } diff --git a/engine/testutil/mock/nodes.go b/engine/testutil/mock/nodes.go index e41a4242be9..a45f8a6369e 100644 --- a/engine/testutil/mock/nodes.go +++ b/engine/testutil/mock/nodes.go @@ -243,6 +243,7 @@ func (en ExecutionNode) Ready(t *testing.T, ctx context.Context) { en.FollowerCore.Start(irctx) en.FollowerEngine.Start(irctx) en.SyncEngine.Start(irctx) + en.RequestEngine.Start(irctx) <-util.AllReady( en.Ledger, diff --git a/engine/testutil/nodes.go b/engine/testutil/nodes.go index 6dbb9f33f3c..384d04b8087 100644 --- a/engine/testutil/nodes.go +++ b/engine/testutil/nodes.go @@ -459,8 +459,19 @@ func ConsensusNode(t *testing.T, hub *stub.Hub, identity bootstrap.NodeInfo, ide ingestionEngine, err := consensusingest.New(node.Log, node.Metrics, node.Net, node.Me, ingestionCore) require.NoError(t, err) + requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector()) // request receipts from execution nodes - receiptRequester, err := requester.New(node.Log.With().Str("entity", "receipt").Logger(), node.Metrics, node.Net, node.Me, node.State, channels.RequestReceiptsByBlockID, filter.Any, func() flow.Entity { return new(flow.ExecutionReceipt) }) + receiptRequester, err := requester.New( + node.Log.With().Str("entity", "receipt").Logger(), + node.Metrics, + node.Net, + node.Me, + node.State, + requestQueue, + channels.RequestReceiptsByBlockID, + filter.Any, + func() flow.Entity { return new(flow.ExecutionReceipt) }, + ) require.NoError(t, err) assigner, err := chunks.NewChunkAssigner(flow.DefaultChunkAssignmentAlpha, node.State) @@ -666,8 +677,10 @@ func ExecutionNode(t *testing.T, hub *stub.Hub, identity bootstrap.NodeInfo, ide node.LockManager, ) + requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector()) requestEngine, err := requester.New( node.Log.With().Str("entity", "collection").Logger(), node.Metrics, node.Net, node.Me, node.State, + requestQueue, channels.RequestCollections, filter.HasRole[flow.Identity](flow.RoleCollection), func() flow.Entity { return new(flow.Collection) },
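Taken together, the call sites above all follow the same wiring pattern: build a message store, hand it to requester.New as the sixth argument, and start the engine as a component. The sketch below illustrates that pattern; the names myLogger, myMetrics, myNet, myMe, myState, handleCollection, and signalerCtx are hypothetical placeholders for the real node-builder dependencies:

	// create the FIFO store backing the engine's inbound entity response queue
	fifoStore, err := engine.NewFifoMessageStore(requester.DefaultEntityRequestCacheSize)
	if err != nil {
		return nil, fmt.Errorf("could not create requester store: %w", err)
	}
	req, err := requester.New(
		myLogger, myMetrics, myNet, myMe, myState,
		fifoStore, // new parameter: backing store for queued entity responses
		channels.RequestCollections,
		filter.HasRole[flow.Identity](flow.RoleCollection),
		func() flow.Entity { return new(flow.Collection) },
	)
	if err != nil {
		return nil, fmt.Errorf("could not create requester engine: %w", err)
	}
	// the handler must be set before Start; the poll worker throws an
	// irrecoverable error if no handler is configured
	req.WithHandle(handleCollection)
	// the engine is now a component.Component and must be started explicitly,
	// as done for en.RequestEngine in engine/testutil/mock/nodes.go above
	req.Start(signalerCtx)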