diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index cb5538ac..7f375ae5 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -198,13 +198,14 @@ type ChainSync struct { // Config is used to configure the ChainSync protocol instance type Config struct { - RollBackwardFunc RollBackwardFunc - RollForwardFunc RollForwardFunc - FindIntersectFunc FindIntersectFunc - RequestNextFunc RequestNextFunc - IntersectTimeout time.Duration - BlockTimeout time.Duration - PipelineLimit int + RollBackwardFunc RollBackwardFunc + RollForwardFunc RollForwardFunc + RollForwardRawFunc RollForwardRawFunc + FindIntersectFunc FindIntersectFunc + RequestNextFunc RequestNextFunc + IntersectTimeout time.Duration + BlockTimeout time.Duration + PipelineLimit int } // Callback context @@ -217,6 +218,7 @@ type CallbackContext struct { // Callback function types type RollBackwardFunc func(CallbackContext, common.Point, Tip) error type RollForwardFunc func(CallbackContext, uint, interface{}, Tip) error +type RollForwardRawFunc func(CallbackContext, uint, []byte, Tip) error type FindIntersectFunc func(CallbackContext, []common.Point) (common.Point, Tip, error) type RequestNextFunc func(CallbackContext) error @@ -262,13 +264,20 @@ func WithRollBackwardFunc( } } -// WithRollForwardFunc specifies the RollForward callback function +// WithRollForwardFunc specifies the RollForward callback function. This will provided a parsed header or block func WithRollForwardFunc(rollForwardFunc RollForwardFunc) ChainSyncOptionFunc { return func(c *Config) { c.RollForwardFunc = rollForwardFunc } } +// WithRollForwardRawFunc specifies the RollForwardRaw callback function. This will provide the raw header or block +func WithRollForwardRawFunc(rollForwardRawFunc RollForwardRawFunc) ChainSyncOptionFunc { + return func(c *Config) { + c.RollForwardRawFunc = rollForwardRawFunc + } +} + // WithFindIntersectFunc specifies the FindIntersect callback function func WithFindIntersectFunc( findIntersectFunc FindIntersectFunc, diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index e48beec2..3f2a344e 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -596,7 +596,7 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { } }() if firstBlockChan == nil && - (c.config == nil || c.config.RollForwardFunc == nil) { + (c.config == nil || (c.config.RollForwardFunc == nil && c.config.RollForwardRawFunc == nil)) { return fmt.Errorf( "received chain-sync RollForward message but no callback function is defined", ) @@ -607,26 +607,21 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { c.sendCurrentTip(msg.Tip) var blockHeader ledger.BlockHeader + var blockHeaderBytes []byte var blockType uint blockEra := msg.WrappedHeader.Era switch blockEra { case ledger.BlockHeaderTypeByron: blockType = msg.WrappedHeader.ByronType() - var err error - blockHeader, err = ledger.NewBlockHeaderFromCbor( - blockType, - msg.WrappedHeader.HeaderCbor(), - ) - if err != nil { - if firstBlockChan != nil { - firstBlockChan <- clientPointResult{error: err} - } - return err - } + blockHeaderBytes = msg.WrappedHeader.HeaderCbor() default: // Map block header type to block type blockType = ledger.BlockHeaderToBlockTypeMap[blockEra] + blockHeaderBytes = msg.WrappedHeader.HeaderCbor() + } + if firstBlockChan != nil || c.config.RollForwardFunc != nil { + // Decode header var err error blockHeader, err = ledger.NewBlockHeaderFromCbor( blockType, @@ -650,35 +645,63 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { return nil } // Call the user callback function - callbackErr = c.config.RollForwardFunc( - c.callbackContext, - blockType, - blockHeader, - msg.Tip, - ) + if c.config.RollForwardRawFunc != nil { + callbackErr = c.config.RollForwardRawFunc( + c.callbackContext, + blockType, + blockHeaderBytes, + msg.Tip, + ) + } else { + callbackErr = c.config.RollForwardFunc( + c.callbackContext, + blockType, + blockHeader, + msg.Tip, + ) + } } else { msg := msgGeneric.(*MsgRollForwardNtC) c.sendCurrentTip(msg.Tip) - blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor()) - if err != nil { - if firstBlockChan != nil { - firstBlockChan <- clientPointResult{error: err} + var block ledger.Block + + if firstBlockChan != nil || c.config.RollForwardFunc != nil { + var err error + block, err = ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor()) + if err != nil { + if firstBlockChan != nil { + firstBlockChan <- clientPointResult{error: err} + } + return err } - return err } if firstBlockChan != nil { - blockHash, err := hex.DecodeString(blk.Hash()) + blockHash, err := hex.DecodeString(block.Hash()) if err != nil { firstBlockChan <- clientPointResult{error: err} return err } - point := common.NewPoint(blk.SlotNumber(), blockHash) + point := common.NewPoint(block.SlotNumber(), blockHash) firstBlockChan <- clientPointResult{tip: msg.Tip, point: point} return nil } // Call the user callback function - callbackErr = c.config.RollForwardFunc(c.callbackContext, msg.BlockType(), blk, msg.Tip) + if c.config.RollForwardRawFunc != nil { + callbackErr = c.config.RollForwardRawFunc( + c.callbackContext, + msg.BlockType(), + msg.BlockCbor(), + msg.Tip, + ) + } else { + callbackErr = c.config.RollForwardFunc( + c.callbackContext, + msg.BlockType(), + block, + msg.Tip, + ) + } } if callbackErr != nil { if callbackErr == StopSyncProcessError {