Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions protocol/chainsync/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,14 @@ type ChainSync struct {

// Config is used to configure the ChainSync protocol instance
type Config struct {
RollBackwardFunc RollBackwardFunc
RollForwardFunc RollForwardFunc
FindIntersectFunc FindIntersectFunc
RequestNextFunc RequestNextFunc
IntersectTimeout time.Duration
BlockTimeout time.Duration
PipelineLimit int
RollBackwardFunc RollBackwardFunc
RollForwardFunc RollForwardFunc
RollForwardRawFunc RollForwardRawFunc
FindIntersectFunc FindIntersectFunc
RequestNextFunc RequestNextFunc
IntersectTimeout time.Duration
BlockTimeout time.Duration
PipelineLimit int
}

// Callback context
Expand All @@ -217,6 +218,7 @@ type CallbackContext struct {
// Callback function types
type RollBackwardFunc func(CallbackContext, common.Point, Tip) error
type RollForwardFunc func(CallbackContext, uint, interface{}, Tip) error
type RollForwardRawFunc func(CallbackContext, uint, []byte, Tip) error

type FindIntersectFunc func(CallbackContext, []common.Point) (common.Point, Tip, error)
type RequestNextFunc func(CallbackContext) error
Expand Down Expand Up @@ -262,13 +264,20 @@ func WithRollBackwardFunc(
}
}

// WithRollForwardFunc specifies the RollForward callback function
// WithRollForwardFunc specifies the RollForward callback function. This will provided a parsed header or block
func WithRollForwardFunc(rollForwardFunc RollForwardFunc) ChainSyncOptionFunc {
return func(c *Config) {
c.RollForwardFunc = rollForwardFunc
}
}

// WithRollForwardRawFunc specifies the RollForwardRaw callback function. This will provide the raw header or block
func WithRollForwardRawFunc(rollForwardRawFunc RollForwardRawFunc) ChainSyncOptionFunc {
return func(c *Config) {
c.RollForwardRawFunc = rollForwardRawFunc
}
}

// WithFindIntersectFunc specifies the FindIntersect callback function
func WithFindIntersectFunc(
findIntersectFunc FindIntersectFunc,
Expand Down
75 changes: 49 additions & 26 deletions protocol/chainsync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error {
}
}()
if firstBlockChan == nil &&
(c.config == nil || c.config.RollForwardFunc == nil) {
(c.config == nil || (c.config.RollForwardFunc == nil && c.config.RollForwardRawFunc == nil)) {
return fmt.Errorf(
"received chain-sync RollForward message but no callback function is defined",
)
Expand All @@ -607,26 +607,21 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error {
c.sendCurrentTip(msg.Tip)

var blockHeader ledger.BlockHeader
var blockHeaderBytes []byte
var blockType uint
blockEra := msg.WrappedHeader.Era

switch blockEra {
case ledger.BlockHeaderTypeByron:
blockType = msg.WrappedHeader.ByronType()
var err error
blockHeader, err = ledger.NewBlockHeaderFromCbor(
blockType,
msg.WrappedHeader.HeaderCbor(),
)
if err != nil {
if firstBlockChan != nil {
firstBlockChan <- clientPointResult{error: err}
}
return err
}
blockHeaderBytes = msg.WrappedHeader.HeaderCbor()
default:
// Map block header type to block type
blockType = ledger.BlockHeaderToBlockTypeMap[blockEra]
blockHeaderBytes = msg.WrappedHeader.HeaderCbor()
}
if firstBlockChan != nil || c.config.RollForwardFunc != nil {
// Decode header
var err error
blockHeader, err = ledger.NewBlockHeaderFromCbor(
blockType,
Expand All @@ -650,35 +645,63 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error {
return nil
}
// Call the user callback function
callbackErr = c.config.RollForwardFunc(
c.callbackContext,
blockType,
blockHeader,
msg.Tip,
)
if c.config.RollForwardRawFunc != nil {
callbackErr = c.config.RollForwardRawFunc(
c.callbackContext,
blockType,
blockHeaderBytes,
msg.Tip,
)
} else {
callbackErr = c.config.RollForwardFunc(
c.callbackContext,
blockType,
blockHeader,
msg.Tip,
)
}
} else {
msg := msgGeneric.(*MsgRollForwardNtC)
c.sendCurrentTip(msg.Tip)

blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor())
if err != nil {
if firstBlockChan != nil {
firstBlockChan <- clientPointResult{error: err}
var block ledger.Block

if firstBlockChan != nil || c.config.RollForwardFunc != nil {
var err error
block, err = ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor())
if err != nil {
if firstBlockChan != nil {
firstBlockChan <- clientPointResult{error: err}
}
return err
}
return err
}
if firstBlockChan != nil {
blockHash, err := hex.DecodeString(blk.Hash())
blockHash, err := hex.DecodeString(block.Hash())
if err != nil {
firstBlockChan <- clientPointResult{error: err}
return err
}
point := common.NewPoint(blk.SlotNumber(), blockHash)
point := common.NewPoint(block.SlotNumber(), blockHash)
firstBlockChan <- clientPointResult{tip: msg.Tip, point: point}
return nil
}
// Call the user callback function
callbackErr = c.config.RollForwardFunc(c.callbackContext, msg.BlockType(), blk, msg.Tip)
if c.config.RollForwardRawFunc != nil {
callbackErr = c.config.RollForwardRawFunc(
c.callbackContext,
msg.BlockType(),
msg.BlockCbor(),
msg.Tip,
)
} else {
callbackErr = c.config.RollForwardFunc(
c.callbackContext,
msg.BlockType(),
block,
msg.Tip,
)
}
}
if callbackErr != nil {
if callbackErr == StopSyncProcessError {
Expand Down
Loading