Skip to content

Commit 0506667

Browse files
authored
Merge pull request #17 from cloudstruct/feature/protocol-refactor
Refactor mini-protocols (part 2)
2 parents 11df25c + 3873d14 commit 0506667

File tree

4 files changed

+140
-65
lines changed

4 files changed

+140
-65
lines changed

protocol/blockfetch/blockfetch.go

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ import (
1111
const (
1212
PROTOCOL_NAME = "block-fetch"
1313
PROTOCOL_ID uint16 = 3
14+
)
1415

15-
STATE_IDLE = iota
16-
STATE_BUSY
17-
STATE_STREAMING
18-
STATE_DONE
16+
var (
17+
STATE_IDLE = protocol.NewState(1, "Idle")
18+
STATE_BUSY = protocol.NewState(2, "Busy")
19+
STATE_STREAMING = protocol.NewState(3, "Streaming")
20+
STATE_DONE = protocol.NewState(4, "Done")
1921
)
2022

2123
type BlockFetch struct {
@@ -41,6 +43,7 @@ func New(m *muxer.Muxer, errorChan chan error, callbackConfig *BlockFetchCallbac
4143
callbackConfig: callbackConfig,
4244
}
4345
b.proto = protocol.New(PROTOCOL_NAME, PROTOCOL_ID, m, errorChan, b.messageHandler, NewMsgFromCbor)
46+
// Set initial state
4447
b.proto.SetState(STATE_IDLE)
4548
return b
4649
}
@@ -63,49 +66,55 @@ func (b *BlockFetch) messageHandler(msg protocol.Message) error {
6366
}
6467

6568
func (b *BlockFetch) RequestRange(start []interface{}, end []interface{}) error {
66-
if b.proto.GetState() != STATE_IDLE {
67-
return fmt.Errorf("block-fetch: RequestRange: protocol not in expected state")
69+
if err := b.proto.LockState([]protocol.State{STATE_IDLE}); err != nil {
70+
return fmt.Errorf("%s: RequestRange: protocol not in expected state", PROTOCOL_NAME)
6871
}
6972
msg := newMsgRequestRange(start, end)
70-
b.proto.SetState(STATE_BUSY)
73+
// Unlock and change state when we're done
74+
defer b.proto.UnlockState(STATE_BUSY)
7175
// Send request
7276
return b.proto.SendMessage(msg, false)
7377
}
7478

7579
func (b *BlockFetch) ClientDone() error {
76-
if b.proto.GetState() != STATE_IDLE {
77-
return fmt.Errorf("block-fetch: ClientDone: protocol not in expected state")
80+
if err := b.proto.LockState([]protocol.State{STATE_IDLE}); err != nil {
81+
return fmt.Errorf("%s: ClientDone: protocol not in expected state", PROTOCOL_NAME)
7882
}
7983
msg := newMsgClientDone()
80-
b.proto.SetState(STATE_BUSY)
84+
// Unlock and change state when we're done
85+
defer b.proto.UnlockState(STATE_BUSY)
8186
// Send request
8287
return b.proto.SendMessage(msg, false)
8388
}
8489

8590
func (b *BlockFetch) handleStartBatch() error {
86-
if b.proto.GetState() != STATE_BUSY {
91+
if err := b.proto.LockState([]protocol.State{STATE_BUSY}); err != nil {
8792
return fmt.Errorf("received block-fetch StartBatch message when protocol not in expected state")
8893
}
8994
if b.callbackConfig.StartBatchFunc == nil {
9095
return fmt.Errorf("received block-fetch StartBatch message but no callback function is defined")
9196
}
92-
b.proto.SetState(STATE_STREAMING)
97+
// Unlock and change state when we're done
98+
defer b.proto.UnlockState(STATE_STREAMING)
99+
// Call the user callback function
93100
return b.callbackConfig.StartBatchFunc()
94101
}
95102

96103
func (b *BlockFetch) handleNoBlocks() error {
97-
if b.proto.GetState() != STATE_BUSY {
104+
if err := b.proto.LockState([]protocol.State{STATE_BUSY}); err != nil {
98105
return fmt.Errorf("received block-fetch NoBlocks message when protocol not in expected state")
99106
}
100107
if b.callbackConfig.NoBlocksFunc == nil {
101108
return fmt.Errorf("received block-fetch NoBlocks message but no callback function is defined")
102109
}
103-
b.proto.SetState(STATE_IDLE)
110+
// Unlock and change state when we're done
111+
defer b.proto.UnlockState(STATE_IDLE)
112+
// Call the user callback function
104113
return b.callbackConfig.NoBlocksFunc()
105114
}
106115

107116
func (b *BlockFetch) handleBlock(msgGeneric protocol.Message) error {
108-
if b.proto.GetState() != STATE_STREAMING {
117+
if err := b.proto.LockState([]protocol.State{STATE_STREAMING}); err != nil {
109118
return fmt.Errorf("received block-fetch Block message when protocol not in expected state")
110119
}
111120
if b.callbackConfig.BlockFunc == nil {
@@ -115,24 +124,27 @@ func (b *BlockFetch) handleBlock(msgGeneric protocol.Message) error {
115124
// Decode only enough to get the block type value
116125
var wrapBlock wrappedBlock
117126
if _, err := utils.CborDecode(msg.WrappedBlock, &wrapBlock); err != nil {
118-
return fmt.Errorf("block-fetch: decode error: %s", err)
127+
return fmt.Errorf("%s: decode error: %s", PROTOCOL_NAME, err)
119128
}
120129
blk, err := block.NewBlockFromCbor(wrapBlock.Type, wrapBlock.RawBlock)
121130
if err != nil {
122131
return err
123132
}
124-
// We don't actually need this since it's the state we started in, but it's good to be explicit
125-
b.proto.SetState(STATE_STREAMING)
133+
// Unlock and change state when we're done
134+
defer b.proto.UnlockState(STATE_STREAMING)
135+
// Call the user callback function
126136
return b.callbackConfig.BlockFunc(wrapBlock.Type, blk)
127137
}
128138

129139
func (b *BlockFetch) handleBatchDone() error {
130-
if b.proto.GetState() != STATE_STREAMING {
140+
if err := b.proto.LockState([]protocol.State{STATE_STREAMING}); err != nil {
131141
return fmt.Errorf("received block-fetch BatchDone message when protocol not in expected state")
132142
}
133143
if b.callbackConfig.BatchDoneFunc == nil {
134144
return fmt.Errorf("received block-fetch BatchDone message but no callback function is defined")
135145
}
136-
b.proto.SetState(STATE_IDLE)
146+
// Unlock and change state when we're done
147+
defer b.proto.UnlockState(STATE_IDLE)
148+
// Call the user callback function
137149
return b.callbackConfig.BatchDoneFunc()
138150
}

protocol/chainsync/chainsync.go

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@ const (
1212
PROTOCOL_NAME = "chain-sync"
1313
PROTOCOL_ID_NTN uint16 = 2
1414
PROTOCOL_ID_NTC uint16 = 5
15+
)
1516

16-
STATE_IDLE = iota
17-
STATE_CAN_AWAIT
18-
STATE_MUST_REPLY
19-
STATE_INTERSECT
20-
STATE_DONE
17+
var (
18+
STATE_IDLE = protocol.NewState(1, "Idle")
19+
STATE_CAN_AWAIT = protocol.NewState(2, "CanAwait")
20+
STATE_MUST_REPLY = protocol.NewState(3, "MustReply")
21+
STATE_INTERSECT = protocol.NewState(4, "Intersect")
22+
STATE_DONE = protocol.NewState(5, "Done")
2123
)
2224

2325
type ChainSync struct {
@@ -55,6 +57,7 @@ func New(m *muxer.Muxer, errorChan chan error, nodeToNode bool, callbackConfig *
5557
callbackConfig: callbackConfig,
5658
}
5759
c.proto = protocol.New(PROTOCOL_NAME, protocolId, m, errorChan, c.messageHandler, c.NewMsgFromCbor)
60+
// Set initial state
5861
c.proto.SetState(STATE_IDLE)
5962
return c
6063
}
@@ -81,39 +84,43 @@ func (c *ChainSync) messageHandler(msg protocol.Message) error {
8184
}
8285

8386
func (c *ChainSync) RequestNext() error {
84-
if c.proto.GetState() != STATE_IDLE {
85-
return fmt.Errorf("chain-sync: RequestNext: protocol not in expected state")
87+
if err := c.proto.LockState([]protocol.State{STATE_IDLE}); err != nil {
88+
return fmt.Errorf("%s: RequestNext: protocol not in expected state", PROTOCOL_NAME)
8689
}
8790
// Create our request
8891
msg := newMsgRequestNext()
89-
c.proto.SetState(STATE_CAN_AWAIT)
92+
// Unlock and change state when we're done
93+
defer c.proto.UnlockState(STATE_CAN_AWAIT)
9094
// Send request
9195
return c.proto.SendMessage(msg, false)
9296
}
9397

9498
func (c *ChainSync) FindIntersect(points []interface{}) error {
95-
if c.proto.GetState() != STATE_IDLE {
96-
return fmt.Errorf("chain-sync: FindIntersect: protocol not in expected state")
99+
if err := c.proto.LockState([]protocol.State{STATE_IDLE}); err != nil {
100+
return fmt.Errorf("%s: FindIntersect: protocol not in expected state", PROTOCOL_NAME)
97101
}
98102
msg := newMsgFindIntersect(points)
99-
c.proto.SetState(STATE_INTERSECT)
103+
// Unlock and change state when we're done
104+
defer c.proto.UnlockState(STATE_INTERSECT)
100105
// Send request
101106
return c.proto.SendMessage(msg, false)
102107
}
103108

104109
func (c *ChainSync) handleAwaitReply() error {
105-
if c.proto.GetState() != STATE_CAN_AWAIT {
110+
if err := c.proto.LockState([]protocol.State{STATE_CAN_AWAIT}); err != nil {
106111
return fmt.Errorf("received chain-sync AwaitReply message when protocol not in expected state")
107112
}
108113
if c.callbackConfig.AwaitReplyFunc == nil {
109114
return fmt.Errorf("received chain-sync AwaitReply message but no callback function is defined")
110115
}
111-
c.proto.SetState(STATE_MUST_REPLY)
116+
// Unlock and change state when we're done
117+
defer c.proto.UnlockState(STATE_MUST_REPLY)
118+
// Call the user callback function
112119
return c.callbackConfig.AwaitReplyFunc()
113120
}
114121

115122
func (c *ChainSync) handleRollForward(msgGeneric protocol.Message) error {
116-
if c.proto.GetState() != STATE_CAN_AWAIT && c.proto.GetState() != STATE_MUST_REPLY {
123+
if err := c.proto.LockState([]protocol.State{STATE_CAN_AWAIT, STATE_MUST_REPLY}); err != nil {
117124
return fmt.Errorf("received chain-sync RollForward message when protocol not in expected state")
118125
}
119126
if c.callbackConfig.RollForwardFunc == nil {
@@ -128,7 +135,7 @@ func (c *ChainSync) handleRollForward(msgGeneric protocol.Message) error {
128135
case block.BLOCK_HEADER_TYPE_BYRON:
129136
var wrapHeaderByron wrappedHeaderByron
130137
if _, err := utils.CborDecode(msg.WrappedHeader.RawData, &wrapHeaderByron); err != nil {
131-
return fmt.Errorf("chain-sync: decode error: %s", err)
138+
return fmt.Errorf("%s: decode error: %s", PROTOCOL_NAME, err)
132139
}
133140
blockType = wrapHeaderByron.Unknown.Type
134141
var err error
@@ -148,75 +155,87 @@ func (c *ChainSync) handleRollForward(msgGeneric protocol.Message) error {
148155
// We decode into a byte array to implicitly unwrap the CBOR tag object
149156
var payload []byte
150157
if _, err := utils.CborDecode(msg.WrappedHeader.RawData, &payload); err != nil {
151-
return fmt.Errorf("failed fallback decode: %s", err)
158+
return fmt.Errorf("%s: decode error: %s", PROTOCOL_NAME, err)
152159
}
153160
var err error
154161
blockHeader, err = block.NewBlockHeaderFromCbor(blockType, payload)
155162
if err != nil {
156163
return err
157164
}
158165
}
159-
c.proto.SetState(STATE_IDLE)
166+
// Unlock and change state when we're done
167+
defer c.proto.UnlockState(STATE_IDLE)
168+
// Call the user callback function
160169
return c.callbackConfig.RollForwardFunc(blockType, blockHeader)
161170
} else {
162171
msg := msgGeneric.(*msgRollForwardNtC)
163172
// Decode only enough to get the block type value
164173
var wrapBlock wrappedBlock
165174
if _, err := utils.CborDecode(msg.WrappedData, &wrapBlock); err != nil {
166-
return fmt.Errorf("chain-sync: decode error: %s", err)
175+
return fmt.Errorf("%s: decode error: %s", PROTOCOL_NAME, err)
167176
}
168177
blk, err := block.NewBlockFromCbor(wrapBlock.Type, wrapBlock.RawBlock)
169178
if err != nil {
170179
return err
171180
}
172-
c.proto.SetState(STATE_IDLE)
181+
// Unlock and change state when we're done
182+
defer c.proto.UnlockState(STATE_IDLE)
183+
// Call the user callback function
173184
return c.callbackConfig.RollForwardFunc(wrapBlock.Type, blk)
174185
}
175186
}
176187

177188
func (c *ChainSync) handleRollBackward(msgGeneric protocol.Message) error {
178-
if c.proto.GetState() != STATE_CAN_AWAIT && c.proto.GetState() != STATE_MUST_REPLY {
189+
if err := c.proto.LockState([]protocol.State{STATE_CAN_AWAIT, STATE_MUST_REPLY}); err != nil {
179190
return fmt.Errorf("received chain-sync RollBackward message when protocol not in expected state")
180191
}
181192
if c.callbackConfig.RollBackwardFunc == nil {
182193
return fmt.Errorf("received chain-sync RollBackward message but no callback function is defined")
183194
}
184195
msg := msgGeneric.(*msgRollBackward)
185-
c.proto.SetState(STATE_IDLE)
196+
// Unlock and change state when we're done
197+
defer c.proto.UnlockState(STATE_IDLE)
198+
// Call the user callback function
186199
return c.callbackConfig.RollBackwardFunc(msg.Point, msg.Tip)
187200
}
188201

189202
func (c *ChainSync) handleIntersectFound(msgGeneric protocol.Message) error {
190-
if c.proto.GetState() != STATE_INTERSECT {
203+
if err := c.proto.LockState([]protocol.State{STATE_INTERSECT}); err != nil {
191204
return fmt.Errorf("received chain-sync IntersectFound message when protocol not in expected state")
192205
}
193206
if c.callbackConfig.IntersectFoundFunc == nil {
194207
return fmt.Errorf("received chain-sync IntersectFound message but no callback function is defined")
195208
}
196209
msg := msgGeneric.(*msgIntersectFound)
197-
c.proto.SetState(STATE_IDLE)
210+
// Unlock and change state when we're done
211+
defer c.proto.UnlockState(STATE_IDLE)
212+
// Call the user callback function
198213
return c.callbackConfig.IntersectFoundFunc(msg.Point, msg.Tip)
199214
}
200215

201216
func (c *ChainSync) handleIntersectNotFound(msgGeneric protocol.Message) error {
202-
if c.proto.GetState() != STATE_INTERSECT {
217+
if err := c.proto.LockState([]protocol.State{STATE_INTERSECT}); err != nil {
203218
return fmt.Errorf("received chain-sync IntersectNotFound message when protocol not in expected state")
204219
}
205220
if c.callbackConfig.IntersectNotFoundFunc == nil {
206221
return fmt.Errorf("received chain-sync IntersectNotFound message but no callback function is defined")
207222
}
208223
msg := msgGeneric.(*msgIntersectNotFound)
209-
c.proto.SetState(STATE_IDLE)
224+
// Unlock and change state when we're done
225+
defer c.proto.UnlockState(STATE_IDLE)
226+
// Call the user callback function
210227
return c.callbackConfig.IntersectNotFoundFunc(msg.Tip)
211228
}
212229

213230
func (c *ChainSync) handleDone() error {
214-
if c.proto.GetState() != STATE_IDLE {
231+
if err := c.proto.LockState([]protocol.State{STATE_IDLE}); err != nil {
215232
return fmt.Errorf("received chain-sync Done message when protocol not in expected state")
216233
}
217234
if c.callbackConfig.DoneFunc == nil {
218235
return fmt.Errorf("received chain-sync Done message but no callback function is defined")
219236
}
220-
c.proto.SetState(STATE_DONE)
237+
// Unlock and change state when we're done
238+
defer c.proto.UnlockState(STATE_DONE)
239+
// Call the user callback function
221240
return c.callbackConfig.DoneFunc()
222241
}

0 commit comments

Comments
 (0)