Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions protocol/blockfetch/blockfetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ type BlockFetch struct {

type Config struct {
BlockFunc BlockFunc
BlockRawFunc BlockRawFunc
BatchDoneFunc BatchDoneFunc
RequestRangeFunc RequestRangeFunc
BatchStartTimeout time.Duration
Expand All @@ -103,6 +104,7 @@ type CallbackContext struct {

// Callback function types
type BlockFunc func(CallbackContext, uint, ledger.Block) error
type BlockRawFunc func(CallbackContext, uint, []byte) error
type BatchDoneFunc func(CallbackContext) error
type RequestRangeFunc func(CallbackContext, common.Point, common.Point) error

Expand Down Expand Up @@ -134,6 +136,12 @@ func WithBlockFunc(blockFunc BlockFunc) BlockFetchOptionFunc {
}
}

func WithBlockRawFunc(blockRawFunc BlockRawFunc) BlockFetchOptionFunc {
return func(c *Config) {
c.BlockRawFunc = blockRawFunc
}
}

func WithBatchDoneFunc(batchDoneFunc BatchDoneFunc) BlockFetchOptionFunc {
return func(c *Config) {
c.BatchDoneFunc = batchDoneFunc
Expand Down
32 changes: 23 additions & 9 deletions protocol/blockfetch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,16 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error {
if _, err := cbor.Decode(msg.WrappedBlock, &wrappedBlock); err != nil {
return fmt.Errorf("%s: decode error: %s", ProtocolName, err)
}
blk, err := ledger.NewBlockFromCbor(
wrappedBlock.Type,
wrappedBlock.RawBlock,
)
if err != nil {
return err
var block ledger.Block
if !c.blockUseCallback || c.config.BlockFunc != nil {
var err error
block, err = ledger.NewBlockFromCbor(
wrappedBlock.Type,
wrappedBlock.RawBlock,
)
if err != nil {
return err
}
}
// Check for shutdown
select {
Expand All @@ -269,11 +273,21 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error {
}
// We use the callback when requesting ranges and the internal channel for a single block
if c.blockUseCallback {
if err := c.config.BlockFunc(c.callbackContext, wrappedBlock.Type, blk); err != nil {
return err
if c.config.BlockRawFunc != nil {
if err := c.config.BlockRawFunc(c.callbackContext, wrappedBlock.Type, wrappedBlock.RawBlock); err != nil {
return err
}
} else if c.config.BlockFunc != nil {
if err := c.config.BlockFunc(c.callbackContext, wrappedBlock.Type, block); err != nil {
return err
}
} else {
return fmt.Errorf(
"received block-fetch Block message but no callback function is defined",
)
}
} else {
c.blockChan <- blk
c.blockChan <- block
}
return nil
}
Expand Down
Loading