Skip to content

Commit 91d9678

Browse files
committed
feat: method to stop chainsync process
Fixes #236
1 parent 0700b1a commit 91d9678

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

protocol/chainsync/client.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ package chainsync
22

33
import (
44
"fmt"
5+
"sync"
6+
57
"github.com/blinklabs-io/gouroboros/ledger"
68
"github.com/blinklabs-io/gouroboros/protocol"
79
"github.com/blinklabs-io/gouroboros/protocol/common"
8-
"sync"
910
)
1011

1112
// Client implements the ChainSync client
@@ -146,9 +147,12 @@ func (c *Client) Sync(intersectPoints []common.Point) error {
146147
func (c *Client) syncLoop() {
147148
for {
148149
// Wait for a block to be received
149-
if _, ok := <-c.readyForNextBlockChan; !ok {
150+
if ready, ok := <-c.readyForNextBlockChan; !ok {
150151
// Channel is closed, which means we're shutting down
151152
return
153+
} else if !ready {
154+
// Sync was cancelled
155+
return
152156
}
153157
c.busyMutex.Lock()
154158
// Request the next block
@@ -171,10 +175,7 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error {
171175
if c.config.RollForwardFunc == nil {
172176
return fmt.Errorf("received chain-sync RollForward message but no callback function is defined")
173177
}
174-
// Signal that we're ready for the next block after we finish handling this one
175-
defer func() {
176-
c.readyForNextBlockChan <- true
177-
}()
178+
var callbackErr error
178179
if c.Mode() == protocol.ProtocolModeNodeToNode {
179180
msg := msgGeneric.(*MsgRollForwardNtN)
180181
var blockHeader interface{}
@@ -205,16 +206,23 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error {
205206
}
206207
}
207208
// Call the user callback function
208-
return c.config.RollForwardFunc(blockType, blockHeader, msg.Tip)
209+
callbackErr = c.config.RollForwardFunc(blockType, blockHeader, msg.Tip)
209210
} else {
210211
msg := msgGeneric.(*MsgRollForwardNtC)
211212
blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor())
212213
if err != nil {
213214
return err
214215
}
215216
// Call the user callback function
216-
return c.config.RollForwardFunc(msg.BlockType(), blk, msg.Tip)
217+
callbackErr = c.config.RollForwardFunc(msg.BlockType(), blk, msg.Tip)
217218
}
219+
if callbackErr == StopSyncProcessError {
220+
// Signal that we're cancelling the sync
221+
c.readyForNextBlockChan <- false
222+
}
223+
// Signal that we're ready for the next block
224+
c.readyForNextBlockChan <- true
225+
return nil
218226
}
219227

220228
func (c *Client) handleRollBackward(msgGeneric protocol.Message) error {
@@ -223,7 +231,14 @@ func (c *Client) handleRollBackward(msgGeneric protocol.Message) error {
223231
}
224232
msg := msgGeneric.(*MsgRollBackward)
225233
// Call the user callback function
226-
return c.config.RollBackwardFunc(msg.Point, msg.Tip)
234+
callbackErr := c.config.RollBackwardFunc(msg.Point, msg.Tip)
235+
if callbackErr == StopSyncProcessError {
236+
// Signal that we're cancelling the sync
237+
c.readyForNextBlockChan <- false
238+
}
239+
// Signal that we're ready for the next block
240+
c.readyForNextBlockChan <- true
241+
return nil
227242
}
228243

229244
func (c *Client) handleIntersectFound(msgGeneric protocol.Message) error {

protocol/chainsync/error.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
package chainsync
22

3+
import (
4+
"fmt"
5+
)
6+
37
// IntersectNotFoundError represents a failure to find a chain intersection
48
type IntersectNotFoundError struct {
59
}
610

711
func (e IntersectNotFoundError) Error() string {
812
return "chain intersection not found"
913
}
14+
15+
// StopChainSync is used as a special return value from a RollForward or RollBackward handler function
16+
// to signify that the sync process should be stopped
17+
var StopSyncProcessError = fmt.Errorf("stop sync process")

0 commit comments

Comments
 (0)