Skip to content

Commit bcab703

Browse files
authored
Merge pull request #124 from cloudstruct/feature/chainsync-client-server-split
feat: split chainsync protocol into client and server
2 parents 195dff6 + 15aa238 commit bcab703

File tree

4 files changed

+218
-155
lines changed

4 files changed

+218
-155
lines changed

cmd/go-ouroboros-network/chainsync.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func testChainSync(f *globalFlags) {
131131
fmt.Printf("ERROR: %s\n", err)
132132
os.Exit(1)
133133
}
134-
o.ChainSync.Start()
134+
o.ChainSync.Client.Start()
135135
o.BlockFetch.Client.Start()
136136

137137
syncState.oConn = o
@@ -145,7 +145,7 @@ func testChainSync(f *globalFlags) {
145145
hash, _ := hex.DecodeString(eraIntersect[f.networkMagic][chainSyncFlags.startEra][1].(string))
146146
point.Hash = hash
147147
}
148-
if err := o.ChainSync.FindIntersect([]chainsync.Point{point}); err != nil {
148+
if err := o.ChainSync.Client.FindIntersect([]chainsync.Point{point}); err != nil {
149149
fmt.Printf("ERROR: FindIntersect: %s\n", err)
150150
os.Exit(1)
151151
}
@@ -154,14 +154,14 @@ func testChainSync(f *globalFlags) {
154154
// Pipeline the initial block requests to speed things up a bit
155155
// Using a value higher than 10 seems to cause problems with NtN
156156
for i := 0; i < 10; i++ {
157-
err := o.ChainSync.RequestNext()
157+
err := o.ChainSync.Client.RequestNext()
158158
if err != nil {
159159
fmt.Printf("ERROR: RequestNext: %s\n", err)
160160
os.Exit(1)
161161
}
162162
}
163163
for {
164-
err := o.ChainSync.RequestNext()
164+
err := o.ChainSync.Client.RequestNext()
165165
if err != nil {
166166
fmt.Printf("ERROR: RequestNext: %s\n", err)
167167
os.Exit(1)

protocol/chainsync/chainsync.go

Lines changed: 5 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package chainsync
22

33
import (
4-
"fmt"
5-
"github.com/cloudstruct/go-cardano-ledger"
64
"github.com/cloudstruct/go-ouroboros-network/protocol"
75
)
86

@@ -87,8 +85,8 @@ var StateMap = protocol.StateMap{
8785
}
8886

8987
type ChainSync struct {
90-
*protocol.Protocol
91-
config *Config
88+
Client *Client
89+
Server *Server
9290
}
9391

9492
type Config struct {
@@ -108,154 +106,10 @@ type IntersectFoundFunc func(interface{}, interface{}) error
108106
type IntersectNotFoundFunc func(interface{}) error
109107
type DoneFunc func() error
110108

111-
func New(options protocol.ProtocolOptions, cfg *Config) *ChainSync {
112-
// Use node-to-client protocol ID
113-
protocolId := PROTOCOL_ID_NTC
114-
msgFromCborFunc := NewMsgFromCborNtC
115-
if options.Mode == protocol.ProtocolModeNodeToNode {
116-
// Use node-to-node protocol ID
117-
protocolId = PROTOCOL_ID_NTN
118-
msgFromCborFunc = NewMsgFromCborNtN
119-
}
109+
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *ChainSync {
120110
c := &ChainSync{
121-
config: cfg,
122-
}
123-
protoConfig := protocol.ProtocolConfig{
124-
Name: PROTOCOL_NAME,
125-
ProtocolId: protocolId,
126-
Muxer: options.Muxer,
127-
ErrorChan: options.ErrorChan,
128-
Mode: options.Mode,
129-
Role: options.Role,
130-
MessageHandlerFunc: c.messageHandler,
131-
MessageFromCborFunc: msgFromCborFunc,
132-
StateMap: StateMap,
133-
InitialState: STATE_IDLE,
111+
Client: NewClient(protoOptions, cfg),
112+
Server: NewServer(protoOptions, cfg),
134113
}
135-
c.Protocol = protocol.New(protoConfig)
136114
return c
137115
}
138-
139-
func (c *ChainSync) Start() {
140-
c.Protocol.Start()
141-
}
142-
143-
func (c *ChainSync) messageHandler(msg protocol.Message, isResponse bool) error {
144-
var err error
145-
switch msg.Type() {
146-
case MESSAGE_TYPE_AWAIT_REPLY:
147-
err = c.handleAwaitReply()
148-
case MESSAGE_TYPE_ROLL_FORWARD:
149-
err = c.handleRollForward(msg)
150-
case MESSAGE_TYPE_ROLL_BACKWARD:
151-
err = c.handleRollBackward(msg)
152-
case MESSAGE_TYPE_INTERSECT_FOUND:
153-
err = c.handleIntersectFound(msg)
154-
case MESSAGE_TYPE_INTERSECT_NOT_FOUND:
155-
err = c.handleIntersectNotFound(msg)
156-
case MESSAGE_TYPE_DONE:
157-
err = c.handleDone()
158-
default:
159-
err = fmt.Errorf("%s: received unexpected message type %d", PROTOCOL_NAME, msg.Type())
160-
}
161-
return err
162-
}
163-
164-
func (c *ChainSync) RequestNext() error {
165-
msg := NewMsgRequestNext()
166-
return c.SendMessage(msg)
167-
}
168-
169-
func (c *ChainSync) FindIntersect(points []Point) error {
170-
msg := NewMsgFindIntersect(points)
171-
return c.SendMessage(msg)
172-
}
173-
174-
func (c *ChainSync) handleAwaitReply() error {
175-
if c.config.AwaitReplyFunc == nil {
176-
return fmt.Errorf("received chain-sync AwaitReply message but no callback function is defined")
177-
}
178-
// Call the user callback function
179-
return c.config.AwaitReplyFunc()
180-
}
181-
182-
func (c *ChainSync) handleRollForward(msgGeneric protocol.Message) error {
183-
if c.config.RollForwardFunc == nil {
184-
return fmt.Errorf("received chain-sync RollForward message but no callback function is defined")
185-
}
186-
if c.Mode() == protocol.ProtocolModeNodeToNode {
187-
msg := msgGeneric.(*MsgRollForwardNtN)
188-
var blockHeader interface{}
189-
var blockType uint
190-
blockEra := msg.WrappedHeader.Era
191-
switch blockEra {
192-
case ledger.BLOCK_HEADER_TYPE_BYRON:
193-
blockType = msg.WrappedHeader.ByronType()
194-
var err error
195-
blockHeader, err = ledger.NewBlockHeaderFromCbor(blockType, msg.WrappedHeader.HeaderCbor())
196-
if err != nil {
197-
return err
198-
}
199-
default:
200-
// Map block header types to block types
201-
blockTypeMap := map[uint]uint{
202-
ledger.BLOCK_HEADER_TYPE_SHELLEY: ledger.BLOCK_TYPE_SHELLEY,
203-
ledger.BLOCK_HEADER_TYPE_ALLEGRA: ledger.BLOCK_TYPE_ALLEGRA,
204-
ledger.BLOCK_HEADER_TYPE_MARY: ledger.BLOCK_TYPE_MARY,
205-
ledger.BLOCK_HEADER_TYPE_ALONZO: ledger.BLOCK_TYPE_ALONZO,
206-
ledger.BLOCK_HEADER_TYPE_BABBAGE: ledger.BLOCK_TYPE_BABBAGE,
207-
}
208-
blockType = blockTypeMap[blockEra]
209-
var err error
210-
blockHeader, err = ledger.NewBlockHeaderFromCbor(blockType, msg.WrappedHeader.HeaderCbor())
211-
if err != nil {
212-
return err
213-
}
214-
}
215-
// Call the user callback function
216-
return c.config.RollForwardFunc(blockType, blockHeader)
217-
} else {
218-
msg := msgGeneric.(*MsgRollForwardNtC)
219-
blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor())
220-
if err != nil {
221-
return err
222-
}
223-
// Call the user callback function
224-
return c.config.RollForwardFunc(msg.BlockType(), blk)
225-
}
226-
}
227-
228-
func (c *ChainSync) handleRollBackward(msgGeneric protocol.Message) error {
229-
if c.config.RollBackwardFunc == nil {
230-
return fmt.Errorf("received chain-sync RollBackward message but no callback function is defined")
231-
}
232-
msg := msgGeneric.(*MsgRollBackward)
233-
// Call the user callback function
234-
return c.config.RollBackwardFunc(msg.Point, msg.Tip)
235-
}
236-
237-
func (c *ChainSync) handleIntersectFound(msgGeneric protocol.Message) error {
238-
if c.config.IntersectFoundFunc == nil {
239-
return fmt.Errorf("received chain-sync IntersectFound message but no callback function is defined")
240-
}
241-
msg := msgGeneric.(*MsgIntersectFound)
242-
// Call the user callback function
243-
return c.config.IntersectFoundFunc(msg.Point, msg.Tip)
244-
}
245-
246-
func (c *ChainSync) handleIntersectNotFound(msgGeneric protocol.Message) error {
247-
if c.config.IntersectNotFoundFunc == nil {
248-
return fmt.Errorf("received chain-sync IntersectNotFound message but no callback function is defined")
249-
}
250-
msg := msgGeneric.(*MsgIntersectNotFound)
251-
// Call the user callback function
252-
return c.config.IntersectNotFoundFunc(msg.Tip)
253-
}
254-
255-
func (c *ChainSync) handleDone() error {
256-
if c.config.DoneFunc == nil {
257-
return fmt.Errorf("received chain-sync Done message but no callback function is defined")
258-
}
259-
// Call the user callback function
260-
return c.config.DoneFunc()
261-
}

protocol/chainsync/client.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package chainsync
2+
3+
import (
4+
"fmt"
5+
"github.com/cloudstruct/go-cardano-ledger"
6+
"github.com/cloudstruct/go-ouroboros-network/protocol"
7+
)
8+
9+
type Client struct {
10+
*protocol.Protocol
11+
config *Config
12+
}
13+
14+
func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
15+
// Use node-to-client protocol ID
16+
protocolId := PROTOCOL_ID_NTC
17+
msgFromCborFunc := NewMsgFromCborNtC
18+
if protoOptions.Mode == protocol.ProtocolModeNodeToNode {
19+
// Use node-to-node protocol ID
20+
protocolId = PROTOCOL_ID_NTN
21+
msgFromCborFunc = NewMsgFromCborNtN
22+
}
23+
c := &Client{
24+
config: cfg,
25+
}
26+
protoConfig := protocol.ProtocolConfig{
27+
Name: PROTOCOL_NAME,
28+
ProtocolId: protocolId,
29+
Muxer: protoOptions.Muxer,
30+
ErrorChan: protoOptions.ErrorChan,
31+
Mode: protoOptions.Mode,
32+
Role: protocol.ProtocolRoleClient,
33+
MessageHandlerFunc: c.messageHandler,
34+
MessageFromCborFunc: msgFromCborFunc,
35+
StateMap: StateMap,
36+
InitialState: STATE_IDLE,
37+
}
38+
c.Protocol = protocol.New(protoConfig)
39+
return c
40+
}
41+
42+
func (c *Client) messageHandler(msg protocol.Message, isResponse bool) error {
43+
var err error
44+
switch msg.Type() {
45+
case MESSAGE_TYPE_AWAIT_REPLY:
46+
err = c.handleAwaitReply()
47+
case MESSAGE_TYPE_ROLL_FORWARD:
48+
err = c.handleRollForward(msg)
49+
case MESSAGE_TYPE_ROLL_BACKWARD:
50+
err = c.handleRollBackward(msg)
51+
case MESSAGE_TYPE_INTERSECT_FOUND:
52+
err = c.handleIntersectFound(msg)
53+
case MESSAGE_TYPE_INTERSECT_NOT_FOUND:
54+
err = c.handleIntersectNotFound(msg)
55+
case MESSAGE_TYPE_DONE:
56+
err = c.handleDone()
57+
default:
58+
err = fmt.Errorf("%s: received unexpected message type %d", PROTOCOL_NAME, msg.Type())
59+
}
60+
return err
61+
}
62+
63+
func (c *Client) RequestNext() error {
64+
msg := NewMsgRequestNext()
65+
return c.SendMessage(msg)
66+
}
67+
68+
func (c *Client) FindIntersect(points []Point) error {
69+
msg := NewMsgFindIntersect(points)
70+
return c.SendMessage(msg)
71+
}
72+
73+
func (c *Client) handleAwaitReply() error {
74+
if c.config.AwaitReplyFunc == nil {
75+
return fmt.Errorf("received chain-sync AwaitReply message but no callback function is defined")
76+
}
77+
// Call the user callback function
78+
return c.config.AwaitReplyFunc()
79+
}
80+
81+
func (c *Client) handleRollForward(msgGeneric protocol.Message) error {
82+
if c.config.RollForwardFunc == nil {
83+
return fmt.Errorf("received chain-sync RollForward message but no callback function is defined")
84+
}
85+
if c.Mode() == protocol.ProtocolModeNodeToNode {
86+
msg := msgGeneric.(*MsgRollForwardNtN)
87+
var blockHeader interface{}
88+
var blockType uint
89+
blockEra := msg.WrappedHeader.Era
90+
switch blockEra {
91+
case ledger.BLOCK_HEADER_TYPE_BYRON:
92+
blockType = msg.WrappedHeader.ByronType()
93+
var err error
94+
blockHeader, err = ledger.NewBlockHeaderFromCbor(blockType, msg.WrappedHeader.HeaderCbor())
95+
if err != nil {
96+
return err
97+
}
98+
default:
99+
// Map block header types to block types
100+
blockTypeMap := map[uint]uint{
101+
ledger.BLOCK_HEADER_TYPE_SHELLEY: ledger.BLOCK_TYPE_SHELLEY,
102+
ledger.BLOCK_HEADER_TYPE_ALLEGRA: ledger.BLOCK_TYPE_ALLEGRA,
103+
ledger.BLOCK_HEADER_TYPE_MARY: ledger.BLOCK_TYPE_MARY,
104+
ledger.BLOCK_HEADER_TYPE_ALONZO: ledger.BLOCK_TYPE_ALONZO,
105+
ledger.BLOCK_HEADER_TYPE_BABBAGE: ledger.BLOCK_TYPE_BABBAGE,
106+
}
107+
blockType = blockTypeMap[blockEra]
108+
var err error
109+
blockHeader, err = ledger.NewBlockHeaderFromCbor(blockType, msg.WrappedHeader.HeaderCbor())
110+
if err != nil {
111+
return err
112+
}
113+
}
114+
// Call the user callback function
115+
return c.config.RollForwardFunc(blockType, blockHeader)
116+
} else {
117+
msg := msgGeneric.(*MsgRollForwardNtC)
118+
blk, err := ledger.NewBlockFromCbor(msg.BlockType(), msg.BlockCbor())
119+
if err != nil {
120+
return err
121+
}
122+
// Call the user callback function
123+
return c.config.RollForwardFunc(msg.BlockType(), blk)
124+
}
125+
}
126+
127+
func (c *Client) handleRollBackward(msgGeneric protocol.Message) error {
128+
if c.config.RollBackwardFunc == nil {
129+
return fmt.Errorf("received chain-sync RollBackward message but no callback function is defined")
130+
}
131+
msg := msgGeneric.(*MsgRollBackward)
132+
// Call the user callback function
133+
return c.config.RollBackwardFunc(msg.Point, msg.Tip)
134+
}
135+
136+
func (c *Client) handleIntersectFound(msgGeneric protocol.Message) error {
137+
if c.config.IntersectFoundFunc == nil {
138+
return fmt.Errorf("received chain-sync IntersectFound message but no callback function is defined")
139+
}
140+
msg := msgGeneric.(*MsgIntersectFound)
141+
// Call the user callback function
142+
return c.config.IntersectFoundFunc(msg.Point, msg.Tip)
143+
}
144+
145+
func (c *Client) handleIntersectNotFound(msgGeneric protocol.Message) error {
146+
if c.config.IntersectNotFoundFunc == nil {
147+
return fmt.Errorf("received chain-sync IntersectNotFound message but no callback function is defined")
148+
}
149+
msg := msgGeneric.(*MsgIntersectNotFound)
150+
// Call the user callback function
151+
return c.config.IntersectNotFoundFunc(msg.Tip)
152+
}
153+
154+
func (c *Client) handleDone() error {
155+
if c.config.DoneFunc == nil {
156+
return fmt.Errorf("received chain-sync Done message but no callback function is defined")
157+
}
158+
// Call the user callback function
159+
return c.config.DoneFunc()
160+
}

0 commit comments

Comments
 (0)