diff --git a/abci/client/unsync_local_client.go b/abci/client/unsync_local_client.go index cd8bab11050..36e98d457b6 100644 --- a/abci/client/unsync_local_client.go +++ b/abci/client/unsync_local_client.go @@ -44,23 +44,26 @@ func (app *unsyncLocalClient) SetResponseCallback(cb Callback) { } func (app *unsyncLocalClient) CheckTxAsync(ctx context.Context, req *types.CheckTxRequest) (*ReqRes, error) { - res, err := app.Application.CheckTx(ctx, req) - if err != nil { - return nil, err - } - return app.callback( - types.ToCheckTxRequest(req), - types.ToCheckTxResponse(res), - ), nil -} + reqres := NewReqRes(types.ToCheckTxRequest(req)) -func (app *unsyncLocalClient) callback(req *types.Request, res *types.Response) *ReqRes { - if app.Callback != nil { - app.Callback(req, res) - } - rr := newLocalReqRes(req, res) - rr.callbackInvoked = true - return rr + go func() { + res, err := app.Application.CheckTx(ctx, req) + if err != nil { + reqres.Response = types.ToExceptionResponse("") // optimistic recheck failed + } else { + reqres.Response = types.ToCheckTxResponse(res) + } + + if app.Callback != nil { + app.Callback(reqres.Request, reqres.Response) + } + + reqres.Done() + + reqres.InvokeCallback() + }() + + return reqres, nil } // ------------------------------------------------------- diff --git a/config/config.go b/config/config.go index 42d26021e82..cccdfce02f0 100644 --- a/config/config.go +++ b/config/config.go @@ -49,6 +49,7 @@ const ( v2 = "v2" MempoolTypeFlood = "flood" + MempoolTypeProxy = "proxy" MempoolTypeNop = "nop" ) diff --git a/mempool/mempool.go b/mempool/mempool.go index a2099423f90..c2e0d977485 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -181,3 +181,21 @@ type Iterator interface { // WaitNextCh returns a channel on which to wait for the next available entry. WaitNextCh() <-chan Entry } + +// MempoolTx defines the interface for a transaction in the mempool +// It provides methods to access transaction properties and sender information +type MempoolTx interface { + Height() int64 + GasWanted() int64 + Tx() types.Tx + IsSender(peerID p2p.ID) bool + AddSender(peerID p2p.ID) bool + Senders() []p2p.ID +} + +// TxBroadcastStream defines the interface for streaming transactions to broadcast. +// It provides a channel that will receive transactions to be broadcasted to peers. +type TxBroadcastStream interface { + // GetTxChannel returns a channel that will receive transactions to broadcast. + GetTxChannel() <-chan MempoolTx +} diff --git a/mempool/mempoolTx.go b/mempool/mempoolTx.go index 7517af1cbe4..2a190a23f65 100644 --- a/mempool/mempoolTx.go +++ b/mempool/mempoolTx.go @@ -10,6 +10,9 @@ import ( ) // mempoolTx is an entry in the mempool. +var _ MempoolTx = (*mempoolTx)(nil) + +// mempoolTx is an entry in the mempool type mempoolTx struct { height int64 // height that this tx had been validated in gasWanted int64 // amount of gas this tx states it will require @@ -19,10 +22,17 @@ type mempoolTx struct { timestamp time.Time // time when entry was created // ids of peers who've sent us this tx (as a map for quick lookups). - // senders: PeerID -> struct{} + // senders: PeerID -> bool senders sync.Map } +// NewMempoolTx creates a new mempoolTx using the builder pattern +func NewMempoolTx(tx types.Tx) MempoolTx { + return NewMempoolTxBuilder(). + WithTx(tx). + Build() +} + func (memTx *mempoolTx) Tx() types.Tx { return memTx.tx } @@ -32,7 +42,7 @@ func (memTx *mempoolTx) Height() int64 { } func (memTx *mempoolTx) GasWanted() int64 { - return memTx.gasWanted + return atomic.LoadInt64(&memTx.gasWanted) } func (memTx *mempoolTx) IsSender(peerID p2p.ID) bool { @@ -40,6 +50,10 @@ func (memTx *mempoolTx) IsSender(peerID p2p.ID) bool { return ok } +func (memTx *mempoolTx) AddSender(peerID p2p.ID) bool { + return memTx.addSender(peerID) +} + // Add the peer ID to the list of senders. Return true iff it exists already in the list. func (memTx *mempoolTx) addSender(peerID p2p.ID) bool { if len(peerID) == 0 { @@ -59,3 +73,54 @@ func (memTx *mempoolTx) Senders() []p2p.ID { }) return senders } + +// MempoolTxBuilder is a builder for creating mempoolTx instances +type MempoolTxBuilder struct { + height int64 + gasWanted int64 + tx types.Tx + senders []p2p.ID +} + +// NewMempoolTxBuilder creates a new builder for mempoolTx +func NewMempoolTxBuilder() *MempoolTxBuilder { + return &MempoolTxBuilder{ + senders: make([]p2p.ID, 0), + } +} + +// WithHeight sets the height for the mempoolTx +func (b *MempoolTxBuilder) WithHeight(height int64) *MempoolTxBuilder { + b.height = height + return b +} + +// WithGasWanted sets the gas wanted for the mempoolTx +func (b *MempoolTxBuilder) WithGasWanted(gasWanted int64) *MempoolTxBuilder { + b.gasWanted = gasWanted + return b +} + +// WithTx sets the transaction for the mempoolTx +func (b *MempoolTxBuilder) WithTx(tx types.Tx) *MempoolTxBuilder { + b.tx = tx + return b +} + +func (b *MempoolTxBuilder) WithSender(sender p2p.ID) *MempoolTxBuilder { + b.senders = append(b.senders, sender) + return b +} + +// Build creates the final mempoolTx instance +func (b *MempoolTxBuilder) Build() MempoolTx { + memTx := &mempoolTx{ + height: b.height, + gasWanted: b.gasWanted, + tx: b.tx, + } + for _, sender := range b.senders { + memTx.senders.Store(sender, struct{}{}) + } + return memTx +} diff --git a/mempool/mempool_interface_reactor.go b/mempool/mempool_interface_reactor.go new file mode 100644 index 00000000000..5e336df9a37 --- /dev/null +++ b/mempool/mempool_interface_reactor.go @@ -0,0 +1,311 @@ +package mempool + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" + + "fmt" + + protomem "github.com/cometbft/cometbft/api/cometbft/mempool/v2" + cfg "github.com/cometbft/cometbft/config" + "github.com/cometbft/cometbft/libs/log" + "github.com/cometbft/cometbft/p2p" + "github.com/cometbft/cometbft/types" + "golang.org/x/sync/semaphore" +) + +// Reactor handles mempool tx broadcasting amongst peers. +// It maintains a map from peer ID to counter, to prevent gossiping txs to the +// peers you received it from. +type MempoolInterfaceReactor struct { + p2p.BaseReactor + config *cfg.MempoolConfig + mempool Mempool + + waitSync atomic.Bool + waitSyncCh chan struct{} // for signaling when to start receiving and sending txs + + // Semaphores to keep track of how many connections to peers are active for broadcasting + // transactions. Each semaphore has a capacity that puts an upper bound on the number of + // connections for different groups of peers. + activePersistentPeersSemaphore *semaphore.Weighted + activeNonPersistentPeersSemaphore *semaphore.Weighted + + // Map of peer ID to their broadcast channel + peerBroadcastChannels sync.Map + + // Stream for getting transactions to broadcast + txStream TxBroadcastStream +} + +// NewMempoolInterfaceReactor returns a new MempoolInterfaceReactor with the given config and mempool. +func NewMempoolInterfaceReactor(config *cfg.MempoolConfig, mempool Mempool, txStream TxBroadcastStream, waitSync bool) *MempoolInterfaceReactor { + memR := &MempoolInterfaceReactor{ + config: config, + mempool: mempool, + peerBroadcastChannels: sync.Map{}, + txStream: txStream, + waitSync: atomic.Bool{}, + } + memR.BaseReactor = *p2p.NewBaseReactor("Mempool", memR) + if waitSync { + memR.waitSync.Store(true) + memR.waitSyncCh = make(chan struct{}) + } + memR.activePersistentPeersSemaphore = semaphore.NewWeighted(int64(memR.config.ExperimentalMaxGossipConnectionsToPersistentPeers)) + memR.activeNonPersistentPeersSemaphore = semaphore.NewWeighted(int64(memR.config.ExperimentalMaxGossipConnectionsToNonPersistentPeers)) + + return memR +} + +// InitPeer implements Reactor by creating a state for the peer. +func (memR *MempoolInterfaceReactor) InitPeer(peer p2p.Peer) p2p.Peer { + return peer +} + +// SetLogger sets the Logger on the reactor and the underlying mempool. +func (memR *MempoolInterfaceReactor) SetLogger(l log.Logger) { + memR.Logger = l +} + +// OnStart implements p2p.BaseReactor. +func (memR *MempoolInterfaceReactor) OnStart() error { + if !memR.config.Broadcast { + memR.Logger.Info("Tx broadcasting is disabled") + } else { + go memR.broadcastTxRoutine() + } + return nil +} + +// GetChannels implements Reactor by returning the list of channels for this +// reactor. +func (memR *MempoolInterfaceReactor) GetChannels() []*p2p.ChannelDescriptor { + largestTx := make([]byte, memR.config.MaxTxBytes) + batchMsg := protomem.Message{ + Sum: &protomem.Message_Txs{ + Txs: &protomem.Txs{Txs: [][]byte{largestTx}}, + }, + } + + return []*p2p.ChannelDescriptor{ + { + ID: MempoolChannel, + Priority: 5, + RecvMessageCapacity: batchMsg.Size(), + MessageType: &protomem.Message{}, + }, + } +} + +// AddPeer implements Reactor. +// It starts a broadcast routine ensuring all txs are forwarded to the given peer. +func (memR *MempoolInterfaceReactor) AddPeer(peer p2p.Peer) { + if memR.config.Broadcast { + go func() { + // Always forward transactions to unconditional peers. + if !memR.Switch.IsPeerUnconditional(peer.ID()) { + // Depending on the type of peer, we choose a semaphore to limit the gossiping peers. + var peerSemaphore *semaphore.Weighted + if peer.IsPersistent() && memR.config.ExperimentalMaxGossipConnectionsToPersistentPeers > 0 { + peerSemaphore = memR.activePersistentPeersSemaphore + } else if !peer.IsPersistent() && memR.config.ExperimentalMaxGossipConnectionsToNonPersistentPeers > 0 { + peerSemaphore = memR.activeNonPersistentPeersSemaphore + } + + if peerSemaphore != nil { + for peer.IsRunning() { + // Block on the semaphore until a slot is available to start gossiping with this peer. + // Do not block indefinitely, in case the peer is disconnected before gossiping starts. + ctxTimeout, cancel := context.WithTimeout(context.TODO(), 30*time.Second) + // Block sending transactions to peer until one of the connections become + // available in the semaphore. + err := peerSemaphore.Acquire(ctxTimeout, 1) + cancel() + + if err != nil { + continue + } + + // Release semaphore to allow other peer to start sending transactions. + defer peerSemaphore.Release(1) + break + } + } + } + + // Check if peer is still running after semaphore acquisition + if !peer.IsRunning() { + return + } + + peerChan := make(chan MempoolTx, memR.config.Size) + + // Store the channel atomically + if _, loaded := memR.peerBroadcastChannels.LoadOrStore(peer.ID(), peerChan); loaded { + // If channel already exists, close the new one and return + close(peerChan) + return + } + + // Start the broadcast routine + memR.broadcastTxPeerRoutine(peer, peerChan) + }() + } +} + +// RemovePeer implements Reactor. +func (memR *MempoolInterfaceReactor) RemovePeer(peer p2p.Peer, _ interface{}) { + if ch, exists := memR.peerBroadcastChannels.LoadAndDelete(peer.ID()); exists { + close(ch.(chan MempoolTx)) + } +} + +// Receive implements Reactor. +// It adds any received transactions to the mempool. +func (memR *MempoolInterfaceReactor) Receive(e p2p.Envelope) { + memR.Logger.Debug("Receive", "src", e.Src, "chId", e.ChannelID, "msg", e.Message) + switch msg := e.Message.(type) { + case *protomem.Txs: + protoTxs := msg.GetTxs() + if len(protoTxs) == 0 { + memR.Logger.Error("received empty txs from peer", "src", e.Src) + return + } + + var err error + for _, tx := range protoTxs { + ntx := types.Tx(tx) + _, err = memR.mempool.CheckTx(ntx, e.Src.ID()) + if err != nil { + switch { + case errors.Is(err, ErrTxInCache): + memR.Logger.Debug("Tx already exists in cache", "tx", ntx.String()) + case errors.As(err, &ErrMempoolIsFull{}): + // using debug level to avoid flooding when traffic is high + memR.Logger.Debug(err.Error()) + default: + memR.Logger.Info("Could not check tx", "tx", ntx.String(), "err", err) + } + } + } + default: + memR.Logger.Error("unknown message type", "src", e.Src, "chId", e.ChannelID, "msg", e.Message) + memR.Switch.StopPeerForError(e.Src, fmt.Errorf("mempool cannot handle message of type: %T", e.Message)) + return + } + + // broadcasting happens from go routines per peer +} + +func (memR *MempoolInterfaceReactor) EnableInOutTxs() { + memR.Logger.Info("Enabling inbound and outbound transactions") + if !memR.waitSync.CompareAndSwap(true, false) { + return + } + + // Releases all the blocked broadcastTxRoutine instances. + if memR.config.Broadcast { + close(memR.waitSyncCh) + } +} + +func (memR *MempoolInterfaceReactor) WaitSync() bool { + return memR.waitSync.Load() +} + +// Send new mempool txs to peer. +func (memR *MempoolInterfaceReactor) broadcastTxPeerRoutine(peer p2p.Peer, peerChan chan MempoolTx) { + for { + // In case of both next.NextWaitChan() and peer.Quit() are variable at the same time + if !memR.IsRunning() || !peer.IsRunning() { + return + } + + // Make sure the peer is up to date. + peerState, ok := peer.Get(types.PeerStateKey).(PeerState) + if !ok { + // Peer does not have a state yet. We set it in the consensus reactor, but + // when we add peer in Switch, the order we call reactors#AddPeer is + // different every time due to us using a map. Sometimes other reactors + // will be initialized before the consensus reactor. We should wait a few + // milliseconds and retry. + time.Sleep(PeerCatchupSleepIntervalMS * time.Millisecond) + continue + } + + select { + case memTx, ok := <-peerChan: + if !ok { + return + } + + if peerState.GetHeight() < memTx.Height()-1 { + time.Sleep(PeerCatchupSleepIntervalMS * time.Millisecond) + continue + } + + if !memTx.IsSender(peer.ID()) { + success := peer.Send(p2p.Envelope{ + ChannelID: MempoolChannel, + Message: &protomem.Txs{Txs: [][]byte{memTx.Tx()}}, + }) + if !success { + time.Sleep(PeerCatchupSleepIntervalMS * time.Millisecond) + continue + } + } + case <-peer.Quit(): + return + } + } +} + +// broadcastTxRoutine broadcasts transactions from the mempool to all peers. +func (memR *MempoolInterfaceReactor) broadcastTxRoutine() { + // If the node is catching up, don't start this routine immediately. + if memR.WaitSync() { + select { + case <-memR.waitSyncCh: + // EnableInOutTxs() has set WaitSync() to false. + case <-memR.Quit(): + return + } + } + + // Check if txStream is set + if memR.txStream == nil { + memR.Logger.Error("txStream is not set, broadcasting is disabled") + return + } + + txChan := memR.txStream.GetTxChannel() + + for { + if !memR.IsRunning() { + return + } + + select { + case tx := <-txChan: + if tx != nil { + memR.peerBroadcastChannels.Range(func(key, value interface{}) bool { + peerID := key.(p2p.ID) + ch := value.(chan MempoolTx) + + select { + case ch <- tx: + default: + memR.Logger.Debug("peer broadcast channel is full", "peerID", peerID) + } + return true + }) + } + case <-memR.Quit(): + return + } + } +} diff --git a/mempool/mempool_interface_reactor_test.go b/mempool/mempool_interface_reactor_test.go new file mode 100644 index 00000000000..ec0ec052bcd --- /dev/null +++ b/mempool/mempool_interface_reactor_test.go @@ -0,0 +1,298 @@ +package mempool + +import ( + "fmt" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cometbft/cometbft/abci/example/kvstore" + abci "github.com/cometbft/cometbft/abci/types" + cfg "github.com/cometbft/cometbft/config" + "github.com/cometbft/cometbft/internal/test" + + "github.com/cometbft/cometbft/libs/log" + "github.com/cometbft/cometbft/p2p" + "github.com/cometbft/cometbft/proxy" + "github.com/cometbft/cometbft/types" +) + +const ( + testNumTxsMempoolInterface = 100 + testTimeoutMempoolInterface = 120 * time.Second +) + +// mockTxBroadcastStream is a mock implementation of TxBroadcastStream for testing. +type mockTxBroadcastStream struct { + txChan chan MempoolTx + logger log.Logger +} + +// newMockTxBroadcastStream creates a new mockTxBroadcastStream. +func newMockTxBroadcastStream(logger log.Logger) *mockTxBroadcastStream { + return &mockTxBroadcastStream{ + txChan: make(chan MempoolTx, testNumTxsMempoolInterface), + logger: logger, + } +} + +// GetTxChannel returns the transaction channel. +func (m *mockTxBroadcastStream) GetTxChannel() <-chan MempoolTx { + m.logger.Debug("mockTxBroadcastStream.GetTxChannel called") + return m.txChan +} + +// sendTx sends a transaction to the channel. This is a helper for tests. +func (m *mockTxBroadcastStream) sendTx(tx MempoolTx) { + m.logger.Debug("mockTxBroadcastStream.sendTx called", "txHash", tx.Tx().Hash(), "height", tx.Height()) + select { + case m.txChan <- tx: + m.logger.Debug("mockTxBroadcastStream.sendTx finished sending tx", "txHash", tx.Tx().Hash()) + case <-time.After(5 * time.Second): + m.logger.Error("mockTxBroadcastStream.sendTx timeout", "txHash", tx.Tx().Hash()) + } +} + +// reactorTestPeerState is a mock peer state for testing. +type reactorTestPeerState struct { + height int64 +} + +func (ps reactorTestPeerState) GetHeight() int64 { + return ps.height +} + +// newMempoolInterfaceWithAppAndConfig is a helper to create a CListMempool for MempoolInterface tests. +func newMempoolInterfaceWithAppAndConfig(cc proxy.ClientCreator) (*CListMempool, func()) { + conf := test.ResetTestRoot("mempool_interface_test") + + appConnMem, _ := cc.NewABCIMempoolClient() + appConnMem.SetResponseCallback(func(r1 *abci.Request, r2 *abci.Response) {}) + appConnMem.SetLogger(log.TestingLogger().With("module", "abci-client")) + if err := appConnMem.Start(); err != nil { + panic(err) + } + + mp := NewCListMempool(conf.Mempool, appConnMem, nil, 0) + mp.SetLogger(log.TestingLogger().With("module", "mempool")) + + return mp, func() { os.RemoveAll(conf.RootDir) } +} + +// makeAndConnectMempoolInterfaceReactors creates and connects N MempoolInterfaceReactors. +func makeAndConnectMempoolInterfaceReactors( + t *testing.T, + conf *cfg.Config, + n int, +) ([]*MempoolInterfaceReactor, []*p2p.Switch, []*mockTxBroadcastStream, []Mempool) { + t.Helper() + reactors := make([]*MempoolInterfaceReactor, n) + txStreams := make([]*mockTxBroadcastStream, n) + mempools := make([]Mempool, n) + logger := log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("test", "makeAndConnectMempoolInterfaceReactors") + + for i := 0; i < n; i++ { + app := kvstore.NewInMemoryApplication() + + cc := proxy.NewLocalClientCreator(app) + mempool, cleanup := newMempoolInterfaceWithAppAndConfig(cc) + + t.Cleanup(cleanup) + mempool.SetLogger(logger.With("validator", i, "module", "mempool")) + mempools[i] = mempool + + txStreams[i] = newMockTxBroadcastStream(logger.With("validator", i, "module", "txstream")) + + reactors[i] = NewMempoolInterfaceReactor(conf.Mempool, mempools[i], txStreams[i], false) + reactors[i].SetLogger(logger.With("validator", i, "module", "mempool-reactor")) + } + + switches := p2p.MakeConnectedSwitches(conf.P2P, n, func(idx int, s *p2p.Switch) *p2p.Switch { + s.AddReactor("MEMPOOL", reactors[idx]) + s.SetLogger(logger.With("validator", idx, "module", "p2p")) + return s + }, p2p.Connect2Switches) + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + for switches[idx].Peers().Size() < (n - 1) { + time.Sleep(100 * time.Millisecond) + } + }(i) + } + + waitTimeout := time.After(15 * time.Second) + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-waitTimeout: + t.Fatalf("Timed out waiting for %d peers to connect (got %d for switch 0, %d for switch 1 if N=2)", n-1, switches[0].Peers().Size(), (func() int { + if n > 1 { + return switches[1].Peers().Size() + } else { + return 0 + } + }())) + case <-done: + } + + return reactors, switches, txStreams, mempools +} + +func addRandomTxsToMempoolAndStream( + t *testing.T, + mempool Mempool, + numTxs int, + senderID p2p.ID, +) (types.Txs, []MempoolTx) { + t.Helper() + txs := make(types.Txs, numTxs) + mempoolTxs := make([]MempoolTx, numTxs) + for i := 0; i < numTxs; i++ { + txKey := fmt.Sprintf("key_mempool_interface_%s_%d_%d", senderID, time.Now().UnixNano(), i) + txValue := fmt.Sprintf("value_%d", i) + tx := types.Tx(fmt.Sprintf("%s=%s", txKey, txValue)) + txs[i] = tx + + reqres, err := mempool.CheckTx(tx, senderID) + res := reqres.Response.GetCheckTx() + + if res.IsErr() { + t.Logf("CheckTx callback failed for tx %X: %s, code: %d, log: %s, info: %s", tx, res.Log, res.Code, res.Log, res.Info) + } + require.NoError(t, err, "mempool.CheckTx returned an error for tx %X. Error: %v", tx, err) + require.EqualValuesf(t, abci.CodeTypeOK, res.Code, "CheckTx callback response code is not OK for tx %X. Got %d", tx, res.Code) + + mempoolTx := NewMempoolTxBuilder(). + WithHeight(1). + WithTx(tx). + WithSender(senderID). + Build() + + mempoolTxs[i] = mempoolTx + } + return txs, mempoolTxs +} + +// checkTxsInOrderOnMempoolInterface checks if the mempool of a given reactor contains the expected transactions. +// For N=2, it also checks the order for the receiving reactor. +func checkTxsInOrderOnMempoolInterface(t *testing.T, expectedTxs types.Txs, reactor *MempoolInterfaceReactor, reactorIndex int, numReactors int) { + t.Helper() + currentMempool := reactor.mempool + // Wait for mempool to have enough transactions + for currentMempool.Size() < len(expectedTxs) { + if !reactor.IsRunning() { + t.Logf("Reactor %d is not running, stopping wait for txs", reactorIndex) + return + } + time.Sleep(100 * time.Millisecond) + } + + reapedTxs := currentMempool.ReapMaxTxs(len(expectedTxs)) + require.Lenf(t, reapedTxs, len(expectedTxs), "Mempool (reactor %d) did not reap the expected number of txs. Expected %d, got %d. Mempool size: %d", reactorIndex, len(expectedTxs), len(reapedTxs), currentMempool.Size()) + + expectedTxsMap := make(map[string]struct{}) + for _, tx := range expectedTxs { + expectedTxsMap[string(tx)] = struct{}{} + } + + reapedTxsMap := make(map[string]struct{}) + for _, tx := range reapedTxs { + reapedTxsMap[string(tx)] = struct{}{} + } + + for _, tx := range expectedTxs { + _, ok := reapedTxsMap[string(tx)] + assert.Truef(t, ok, "Expected tx %X not found in mempool of reactor %d", tx, reactorIndex) + } + + // For N=2, the order of transactions received by the second reactor should match the broadcast order. + // The first reactor (sender) will have them in the order they were added. + if numReactors == 2 && reactorIndex == 1 { // Check order for the receiver in a 2-node setup + for j, tx := range expectedTxs { + assert.Equalf(t, tx, reapedTxs[j], + "txs at index %d on reactor %d don't match: expected %X, got %X", j, reactorIndex, tx, reapedTxs[j]) + } + } else if reactorIndex == 0 { // For the sender, order should always match + for j, tx := range expectedTxs { + assert.Equalf(t, tx, reapedTxs[j], + "txs at index %d on reactor %d (sender) don't match: expected %X, got %X", j, reactorIndex, tx, reapedTxs[j]) + } + } +} + +// waitForTxsOnMempoolsInterface waits for all transactions to appear on all specified reactors' mempools. +// It's modeled after waitForTxsOnReactors from reactor_test.go. +func waitForTxsOnMempoolsInterface(t *testing.T, txs types.Txs, reactors []*MempoolInterfaceReactor) { + t.Helper() + wg := new(sync.WaitGroup) + for i, reactor := range reactors { + wg.Add(1) + go func(r *MempoolInterfaceReactor, reactorIndex int) { + defer wg.Done() + checkTxsInOrderOnMempoolInterface(t, txs, r, reactorIndex, len(reactors)) + }(reactor, i) + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + timer := time.After(testTimeoutMempoolInterface) // Use the specific timeout for this test suite + select { + case <-timer: + for i, r := range reactors { + t.Logf("Timeout: Reactor %d mempool size: %d, expected: %d", i, r.mempool.Size(), len(txs)) + } + t.Fatal("Timed out waiting for txs on mempools") + case <-done: + } +} + +func TestMempoolInterfaceReactor_BroadcastTxsMessage(t *testing.T) { + config := cfg.TestConfig() + config.Mempool.Broadcast = true + const N = 2 + + reactors, switches, txStreams, _ := makeAndConnectMempoolInterfaceReactors(t, config, N) // mempools slice is not directly used here + + defer func() { + for _, sw := range switches { + if sw.IsRunning() { + if err := sw.Stop(); err != nil { + t.Logf("Error stopping switch: %v", err) + } + } + } + }() + + for _, r := range reactors { + r.Switch.Peers().ForEach(func(peer p2p.Peer) { + peer.Set(types.PeerStateKey, reactorTestPeerState{height: 1}) + }) + } + + // Use the mempool from the first reactor + addedTxs, mempoolTxsToSend := addRandomTxsToMempoolAndStream(t, reactors[0].mempool, testNumTxsMempoolInterface, p2p.ID("xyz")) + + for _, memTx := range mempoolTxsToSend { + // Send to the txStream associated with the first reactor + txStreams[0].sendTx(memTx) + } + + // Pass the slice of reactors to the wait function + waitForTxsOnMempoolsInterface(t, addedTxs, reactors) +} diff --git a/mempool/proxy_mempool.go b/mempool/proxy_mempool.go new file mode 100644 index 00000000000..e176564f0bd --- /dev/null +++ b/mempool/proxy_mempool.go @@ -0,0 +1,19 @@ +package mempool + +// ProxyMempool is a wrapper around a Mempool and a TxBroadcastStream. +// It allows setting the underlying Mempool and TxBroadcastStream dynamically. +type ProxyMempool struct { + Mempool + TxBroadcastStream +} + +var _ Mempool = (*ProxyMempool)(nil) +var _ TxBroadcastStream = (*ProxyMempool)(nil) + +func (m *ProxyMempool) SetMempool(mp Mempool) { + m.Mempool = mp +} + +func (m *ProxyMempool) SetTxBroadcastStream(stream TxBroadcastStream) { + m.TxBroadcastStream = stream +} diff --git a/node/node.go b/node/node.go index 162a42019ac..036eb19ecb4 100644 --- a/node/node.go +++ b/node/node.go @@ -145,6 +145,18 @@ func StateProvider(stateProvider statesync.StateProvider) Option { } } +func WithCustomMempoolOnProxyMempool(customMempool mempl.Mempool, txBroadcastStream mempl.TxBroadcastStream) Option { + return func(n *Node) { + proxyMempool, ok := n.mempool.(*mempl.ProxyMempool) + if !ok { + panic("mempool is not a proxy mempool") + } + + proxyMempool.SetMempool(customMempool) + proxyMempool.SetTxBroadcastStream(txBroadcastStream) + } +} + // BootstrapState synchronizes the stores with the application after state sync // has been performed offline. It is expected that the block store and state // store are empty at the time the function is called. diff --git a/node/setup.go b/node/setup.go index 45ba5abe155..79e5ac2ef75 100644 --- a/node/setup.go +++ b/node/setup.go @@ -322,6 +322,21 @@ func createMempoolAndMempoolReactor( // Strictly speaking, there's no need to have a `mempl.NopMempoolReactor`, but // adding it leads to a cleaner code. return &mempl.NopMempool{}, mempl.NewNopMempoolReactor() + case cfg.MempoolTypeProxy: + mp := &mempl.ProxyMempool{} + reactor := mempl.NewMempoolInterfaceReactor( + config.Mempool, + mp, + mp, + waitSync, + ) + if config.Consensus.WaitForTxs() { + mp.EnableTxsAvailable() + } + reactor.SetLogger(logger) + + return mp, reactor + default: panic(fmt.Sprintf("unknown mempool type: %q", config.Mempool.Type)) }