Skip to content

Commit 2fa8b17

Browse files
authored
Merge pull request #132 from cloudstruct/feat/chainsync-friendly-interface
feat: make the chainsync interface friendlier
2 parents dfb7ac6 + 9359bca commit 2fa8b17

File tree

6 files changed

+126
-111
lines changed

6 files changed

+126
-111
lines changed

cmd/go-ouroboros-network/chainsync.go

Lines changed: 19 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@ import (
1313
)
1414

1515
type chainSyncState struct {
16-
oConn *ouroboros.Ouroboros
17-
nodeToNode bool
18-
readyForNextBlockChan chan bool
19-
byronEpochBaseSlot uint64
20-
byronEpochSlot uint64
16+
oConn *ouroboros.Ouroboros
17+
nodeToNode bool
18+
byronEpochBaseSlot uint64
19+
byronEpochSlot uint64
2120
}
2221

2322
var syncState chainSyncState
@@ -78,12 +77,8 @@ var eraIntersect = map[int]map[string][]interface{}{
7877

7978
func buildChainSyncConfig() chainsync.Config {
8079
return chainsync.Config{
81-
AwaitReplyFunc: chainSyncAwaitReplyHandler,
82-
RollBackwardFunc: chainSyncRollBackwardHandler,
83-
RollForwardFunc: chainSyncRollForwardHandler,
84-
IntersectFoundFunc: chainSyncIntersectFoundHandler,
85-
IntersectNotFoundFunc: chainSyncIntersectNotFoundHandler,
86-
DoneFunc: chainSyncDoneHandler,
80+
RollBackwardFunc: chainSyncRollBackwardHandler,
81+
RollForwardFunc: chainSyncRollForwardHandler,
8782
}
8883
}
8984

@@ -132,56 +127,36 @@ func testChainSync(f *globalFlags) {
132127
os.Exit(1)
133128
}
134129
o.ChainSync.Client.Start()
135-
o.BlockFetch.Client.Start()
130+
if f.ntnProto {
131+
o.BlockFetch.Client.Start()
132+
}
136133

137134
syncState.oConn = o
138-
syncState.readyForNextBlockChan = make(chan bool)
139135
syncState.nodeToNode = f.ntnProto
140136
var point chainsync.Point
141137
if len(eraIntersect[f.networkMagic][chainSyncFlags.startEra]) > 0 {
142138
// Slot
143-
point.Slot = uint64(eraIntersect[f.networkMagic][chainSyncFlags.startEra][0].(int))
139+
slot := uint64(eraIntersect[f.networkMagic][chainSyncFlags.startEra][0].(int))
144140
// Block hash
145141
hash, _ := hex.DecodeString(eraIntersect[f.networkMagic][chainSyncFlags.startEra][1].(string))
146-
point.Hash = hash
142+
point = chainsync.NewPoint(slot, hash)
143+
} else {
144+
point = chainsync.NewPointOrigin()
147145
}
148-
if err := o.ChainSync.Client.FindIntersect([]chainsync.Point{point}); err != nil {
149-
fmt.Printf("ERROR: FindIntersect: %s\n", err)
146+
if err := o.ChainSync.Client.Sync([]chainsync.Point{point}); err != nil {
147+
fmt.Printf("ERROR: failed to start chain-sync: %s\n", err)
150148
os.Exit(1)
151149
}
152-
// Wait until ready for next block
153-
<-syncState.readyForNextBlockChan
154-
// Pipeline the initial block requests to speed things up a bit
155-
// Using a value higher than 10 seems to cause problems with NtN
156-
for i := 0; i < 10; i++ {
157-
err := o.ChainSync.Client.RequestNext()
158-
if err != nil {
159-
fmt.Printf("ERROR: RequestNext: %s\n", err)
160-
os.Exit(1)
161-
}
162-
}
163-
for {
164-
err := o.ChainSync.Client.RequestNext()
165-
if err != nil {
166-
fmt.Printf("ERROR: RequestNext: %s\n", err)
167-
os.Exit(1)
168-
}
169-
// Wait until ready for next block
170-
<-syncState.readyForNextBlockChan
171-
}
172-
}
173-
174-
func chainSyncAwaitReplyHandler() error {
175-
return nil
150+
// Wait forever...the rest of the sync operations are async
151+
select {}
176152
}
177153

178-
func chainSyncRollBackwardHandler(point interface{}, tip interface{}) error {
154+
func chainSyncRollBackwardHandler(point chainsync.Point, tip chainsync.Tip) error {
179155
fmt.Printf("roll backward: point = %#v, tip = %#v\n", point, tip)
180-
syncState.readyForNextBlockChan <- true
181156
return nil
182157
}
183158

184-
func chainSyncRollForwardHandler(blockType uint, blockData interface{}) error {
159+
func chainSyncRollForwardHandler(blockType uint, blockData interface{}, tip chainsync.Tip) error {
185160
if syncState.nodeToNode {
186161
var blockSlot uint64
187162
var blockHash []byte
@@ -241,23 +216,6 @@ func chainSyncRollForwardHandler(blockType uint, blockData interface{}) error {
241216
fmt.Printf("%s\n", utils.DumpCborStructure(blockData, ""))
242217
}
243218
}
244-
syncState.readyForNextBlockChan <- true
245-
return nil
246-
}
247-
248-
func chainSyncIntersectFoundHandler(point interface{}, tip interface{}) error {
249-
fmt.Printf("found intersect: point = %#v, tip = %#v\n", point, tip)
250-
syncState.readyForNextBlockChan <- true
251-
return nil
252-
}
253-
254-
func chainSyncIntersectNotFoundHandler(tip interface{}) error {
255-
fmt.Printf("ERROR: failed to find intersection\n")
256-
os.Exit(1)
257-
return nil
258-
}
259-
260-
func chainSyncDoneHandler() error {
261219
return nil
262220
}
263221

protocol/chainsync/chainsync.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,13 @@ type ChainSync struct {
9090
}
9191

9292
type Config struct {
93-
AwaitReplyFunc AwaitReplyFunc
94-
RollBackwardFunc RollBackwardFunc
95-
RollForwardFunc RollForwardFunc
96-
IntersectFoundFunc IntersectFoundFunc
97-
IntersectNotFoundFunc IntersectNotFoundFunc
98-
DoneFunc DoneFunc
93+
RollBackwardFunc RollBackwardFunc
94+
RollForwardFunc RollForwardFunc
9995
}
10096

10197
// Callback function types
102-
type AwaitReplyFunc func() error
103-
type RollBackwardFunc func(interface{}, interface{}) error
104-
type RollForwardFunc func(uint, interface{}) error
105-
type IntersectFoundFunc func(interface{}, interface{}) error
106-
type IntersectNotFoundFunc func(interface{}) error
107-
type DoneFunc func() error
98+
type RollBackwardFunc func(Point, Tip) error
99+
type RollForwardFunc func(uint, interface{}, Tip) error
108100

109101
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *ChainSync {
110102
c := &ChainSync{

protocol/chainsync/client.go

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@ import (
44
"fmt"
55
"github.com/cloudstruct/go-cardano-ledger"
66
"github.com/cloudstruct/go-ouroboros-network/protocol"
7+
"sync"
78
)
89

910
type Client struct {
1011
*protocol.Protocol
11-
config *Config
12+
config *Config
13+
busyMutex sync.Mutex
14+
intersectResultChan chan error
15+
readyForNextBlockChan chan bool
1216
}
1317

1418
func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
@@ -21,7 +25,9 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
2125
msgFromCborFunc = NewMsgFromCborNtN
2226
}
2327
c := &Client{
24-
config: cfg,
28+
config: cfg,
29+
intersectResultChan: make(chan error),
30+
readyForNextBlockChan: make(chan bool),
2531
}
2632
protoConfig := protocol.ProtocolConfig{
2733
Name: PROTOCOL_NAME,
@@ -52,36 +58,73 @@ func (c *Client) messageHandler(msg protocol.Message, isResponse bool) error {
5258
err = c.handleIntersectFound(msg)
5359
case MESSAGE_TYPE_INTERSECT_NOT_FOUND:
5460
err = c.handleIntersectNotFound(msg)
55-
case MESSAGE_TYPE_DONE:
56-
err = c.handleDone()
5761
default:
5862
err = fmt.Errorf("%s: received unexpected message type %d", PROTOCOL_NAME, msg.Type())
5963
}
6064
return err
6165
}
6266

63-
func (c *Client) RequestNext() error {
64-
msg := NewMsgRequestNext()
65-
return c.SendMessage(msg)
67+
func (c *Client) Stop() error {
68+
c.busyMutex.Lock()
69+
defer c.busyMutex.Unlock()
70+
msg := NewMsgDone()
71+
if err := c.SendMessage(msg); err != nil {
72+
return err
73+
}
74+
return nil
6675
}
6776

68-
func (c *Client) FindIntersect(points []Point) error {
69-
msg := NewMsgFindIntersect(points)
70-
return c.SendMessage(msg)
77+
func (c *Client) Sync(intersectPoints []Point) error {
78+
c.busyMutex.Lock()
79+
defer c.busyMutex.Unlock()
80+
msg := NewMsgFindIntersect(intersectPoints)
81+
if err := c.SendMessage(msg); err != nil {
82+
return err
83+
}
84+
if err := <-c.intersectResultChan; err != nil {
85+
return err
86+
}
87+
// Pipeline the initial block requests to speed things up a bit
88+
// Using a value higher than 10 seems to cause problems with NtN
89+
for i := 0; i < 10; i++ {
90+
msg := NewMsgRequestNext()
91+
if err := c.SendMessage(msg); err != nil {
92+
return err
93+
}
94+
}
95+
go c.syncLoop()
96+
return nil
7197
}
7298

73-
func (c *Client) handleAwaitReply() error {
74-
if c.config.AwaitReplyFunc == nil {
75-
return fmt.Errorf("received chain-sync AwaitReply message but no callback function is defined")
99+
func (c *Client) syncLoop() {
100+
for {
101+
// Wait for a block to be received
102+
<-c.readyForNextBlockChan
103+
c.busyMutex.Lock()
104+
// Request the next block
105+
// In practice we already have multiple block requests pipelined
106+
// and this just adds another one to the pile
107+
msg := NewMsgRequestNext()
108+
if err := c.SendMessage(msg); err != nil {
109+
c.SendError(err)
110+
return
111+
}
112+
c.busyMutex.Unlock()
76113
}
77-
// Call the user callback function
78-
return c.config.AwaitReplyFunc()
114+
}
115+
116+
func (c *Client) handleAwaitReply() error {
117+
return nil
79118
}
80119

81120
func (c *Client) handleRollForward(msgGeneric protocol.Message) error {
82121
if c.config.RollForwardFunc == nil {
83122
return fmt.Errorf("received chain-sync RollForward message but no callback function is defined")
84123
}
124+
// Signal that we're ready for the next block after we finish handling this one
125+
defer func() {
126+
c.readyForNextBlockChan <- true
127+
}()
85128
if c.Mode() == protocol.ProtocolModeNodeToNode {
86129
msg := msgGeneric.(*MsgRollForwardNtN)
87130
var blockHeader interface{}
@@ -112,15 +155,15 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error {
112155
}
113156
}
114157
// Call the user callback function
115-
return c.config.RollForwardFunc(blockType, blockHeader)
158+
return c.config.RollForwardFunc(blockType, blockHeader, msg.Tip)
116159
} else {
117160
msg := msgGeneric.(*MsgRollForwardNtC)
118161
blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor())
119162
if err != nil {
120163
return err
121164
}
122165
// Call the user callback function
123-
return c.config.RollForwardFunc(msg.BlockType(), blk)
166+
return c.config.RollForwardFunc(msg.BlockType(), blk, msg.Tip)
124167
}
125168
}
126169

@@ -134,27 +177,11 @@ func (c *Client) handleRollBackward(msgGeneric protocol.Message) error {
134177
}
135178

136179
func (c *Client) handleIntersectFound(msgGeneric protocol.Message) error {
137-
if c.config.IntersectFoundFunc == nil {
138-
return fmt.Errorf("received chain-sync IntersectFound message but no callback function is defined")
139-
}
140-
msg := msgGeneric.(*MsgIntersectFound)
141-
// Call the user callback function
142-
return c.config.IntersectFoundFunc(msg.Point, msg.Tip)
180+
c.intersectResultChan <- nil
181+
return nil
143182
}
144183

145184
func (c *Client) handleIntersectNotFound(msgGeneric protocol.Message) error {
146-
if c.config.IntersectNotFoundFunc == nil {
147-
return fmt.Errorf("received chain-sync IntersectNotFound message but no callback function is defined")
148-
}
149-
msg := msgGeneric.(*MsgIntersectNotFound)
150-
// Call the user callback function
151-
return c.config.IntersectNotFoundFunc(msg.Tip)
152-
}
153-
154-
func (c *Client) handleDone() error {
155-
if c.config.DoneFunc == nil {
156-
return fmt.Errorf("received chain-sync Done message but no callback function is defined")
157-
}
158-
// Call the user callback function
159-
return c.config.DoneFunc()
185+
c.intersectResultChan <- IntersectNotFoundError{}
186+
return nil
160187
}

protocol/chainsync/error.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package chainsync
2+
3+
type IntersectNotFoundError struct {
4+
}
5+
6+
func (e IntersectNotFoundError) Error() string {
7+
return "chain intersection not found"
8+
}

protocol/chainsync/messages.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,17 @@ type Point struct {
228228
Hash []byte
229229
}
230230

231+
func NewPoint(slot uint64, blockHash []byte) Point {
232+
return Point{
233+
Slot: slot,
234+
Hash: blockHash,
235+
}
236+
}
237+
238+
func NewPointOrigin() Point {
239+
return Point{}
240+
}
241+
231242
// A "point" can sometimes be empty, but the CBOR library gets grumpy about this
232243
// when doing automatic decoding from an array, so we have to handle this case specially
233244
func (p *Point) UnmarshalCBOR(data []byte) error {

protocol/chainsync/server.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,29 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
4040

4141
func (s *Server) messageHandler(msg protocol.Message, isResponse bool) error {
4242
var err error
43-
// TODO: add cases for messages from client
4443
switch msg.Type() {
44+
case MESSAGE_TYPE_REQUEST_NEXT:
45+
err = s.handleRequestNext(msg)
46+
case MESSAGE_TYPE_FIND_INTERSECT:
47+
err = s.handleFindIntersect(msg)
48+
case MESSAGE_TYPE_DONE:
49+
err = s.handleDone()
4550
default:
4651
err = fmt.Errorf("%s: received unexpected message type %d", PROTOCOL_NAME, msg.Type())
4752
}
4853
return err
4954
}
55+
56+
func (s *Server) handleRequestNext(msg protocol.Message) error {
57+
// TODO
58+
return nil
59+
}
60+
61+
func (s *Server) handleFindIntersect(msg protocol.Message) error {
62+
// TODO
63+
return nil
64+
}
65+
66+
func (s *Server) handleDone() error {
67+
return nil
68+
}

0 commit comments

Comments
 (0)