diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index f1cd247e..0762a8a9 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -88,6 +88,7 @@ type BlockFetch struct { type Config struct { BlockFunc BlockFunc + BatchDoneFunc BatchDoneFunc RequestRangeFunc RequestRangeFunc BatchStartTimeout time.Duration BlockTimeout time.Duration @@ -102,6 +103,7 @@ type CallbackContext struct { // Callback function types type BlockFunc func(CallbackContext, uint, ledger.Block) error +type BatchDoneFunc func(CallbackContext) error type RequestRangeFunc func(CallbackContext, common.Point, common.Point) error func New(protoOptions protocol.ProtocolOptions, cfg *Config) *BlockFetch { @@ -132,6 +134,12 @@ func WithBlockFunc(blockFunc BlockFunc) BlockFetchOptionFunc { } } +func WithBatchDoneFunc(batchDoneFunc BatchDoneFunc) BlockFetchOptionFunc { + return func(c *Config) { + c.BatchDoneFunc = batchDoneFunc + } +} + func WithRequestRangeFunc( requestRangeFunc RequestRangeFunc, ) BlockFetchOptionFunc { diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index f581692a..27789e7b 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -283,6 +283,12 @@ func (c *Client) handleBatchDone() error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) + // Notify the user if requested + if c.blockUseCallback && c.config.BatchDoneFunc != nil { + if err := c.config.BatchDoneFunc(c.callbackContext); err != nil { + return err + } + } c.busyMutex.Unlock() return nil }