diff --git a/protocol/txsubmission/server.go b/protocol/txsubmission/server.go index 72abd45c..c4bd558e 100644 --- a/protocol/txsubmission/server.go +++ b/protocol/txsubmission/server.go @@ -207,6 +207,12 @@ func (s *Server) handleDone() error { "role", "server", "connection_id", s.callbackContext.ConnectionId.String(), ) + // Call the user callback function + if s.config != nil && s.config.DoneFunc != nil { + if err := s.config.DoneFunc(s.callbackContext); err != nil { + return err + } + } // Restart protocol s.Stop() s.initProtocol() diff --git a/protocol/txsubmission/txsubmission.go b/protocol/txsubmission/txsubmission.go index 9eb0f3b4..2c5b9116 100644 --- a/protocol/txsubmission/txsubmission.go +++ b/protocol/txsubmission/txsubmission.go @@ -122,6 +122,7 @@ type Config struct { RequestTxIdsFunc RequestTxIdsFunc RequestTxsFunc RequestTxsFunc InitFunc InitFunc + DoneFunc DoneFunc IdleTimeout time.Duration } @@ -137,6 +138,7 @@ type ( RequestTxIdsFunc func(CallbackContext, bool, uint16, uint16) ([]TxIdAndSize, error) RequestTxsFunc func(CallbackContext, []TxId) ([]TxBody, error) InitFunc func(CallbackContext) error + DoneFunc func(CallbackContext) error ) // New returns a new TxSubmission object @@ -186,6 +188,13 @@ func WithInitFunc(initFunc InitFunc) TxSubmissionOptionFunc { } } +// WithDoneFunc specifies the Done callback function +func WithDoneFunc(doneFunc DoneFunc) TxSubmissionOptionFunc { + return func(c *Config) { + c.DoneFunc = doneFunc + } +} + // WithIdleTimeout specifies the timeout for waiting for new transactions from the remote node's mempool func WithIdleTimeout(timeout time.Duration) TxSubmissionOptionFunc { return func(c *Config) {