@@ -32,6 +32,9 @@ import (
32
32
// This is completely arbitrary, but the line had to be drawn somewhere
33
33
const maxMessagesPerSegment = 20
34
34
35
+ // DefaultRecvQueueSize is the default capacity for the recv queue channel
36
+ const DefaultRecvQueueSize = 50
37
+
35
38
// Protocol implements the base functionality of an Ouroboros mini-protocol
36
39
type Protocol struct {
37
40
config ProtocolConfig
@@ -41,6 +44,7 @@ type Protocol struct {
41
44
muxerDoneChan chan bool
42
45
sendQueueChan chan Message
43
46
recvDoneChan chan struct {}
47
+ recvQueueChan chan Message
44
48
recvReadyChan chan bool
45
49
sendDoneChan chan struct {}
46
50
sendReadyChan chan bool
@@ -63,6 +67,7 @@ type ProtocolConfig struct {
63
67
StateMap StateMap
64
68
StateContext interface {}
65
69
InitialState State
70
+ RecvQueueSize int
66
71
}
67
72
68
73
// ProtocolMode is an enum of the protocol modes
@@ -109,6 +114,9 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
109
114
110
115
// New returns a new Protocol object
111
116
func New (config ProtocolConfig ) * Protocol {
117
+ if config .RecvQueueSize == 0 {
118
+ config .RecvQueueSize = DefaultRecvQueueSize
119
+ }
112
120
p := & Protocol {
113
121
config : config ,
114
122
doneChan : make (chan struct {}),
@@ -137,6 +145,7 @@ func (p *Protocol) Start() {
137
145
138
146
// Create channels
139
147
p .sendQueueChan = make (chan Message , 50 )
148
+ p .recvQueueChan = make (chan Message , p .config .RecvQueueSize )
140
149
p .recvReadyChan = make (chan bool , 1 )
141
150
p .sendReadyChan = make (chan bool , 1 )
142
151
@@ -151,6 +160,7 @@ func (p *Protocol) Start() {
151
160
}()
152
161
153
162
go p .stateLoop (stateTransitionChan )
163
+ go p .readLoop ()
154
164
go p .recvLoop ()
155
165
go p .sendLoop ()
156
166
})
@@ -320,16 +330,11 @@ func (p *Protocol) sendLoop() {
320
330
}
321
331
}
322
332
323
- func (p * Protocol ) recvLoop () {
324
- defer func () {
325
- close (p .recvDoneChan )
326
- }()
327
-
333
+ func (p * Protocol ) readLoop () {
328
334
leftoverData := false
329
- recvBuffer := bytes .NewBuffer (nil )
335
+ readBuffer := bytes .NewBuffer (nil )
330
336
331
337
for {
332
- var err error
333
338
// Don't grab the next segment from the muxer if we still have data in the buffer
334
339
if ! leftoverData {
335
340
// Wait for segment
@@ -344,29 +349,19 @@ func (p *Protocol) recvLoop() {
344
349
return
345
350
}
346
351
// Add segment payload to buffer
347
- recvBuffer .Write (segment .Payload )
352
+ readBuffer .Write (segment .Payload )
348
353
}
349
354
}
350
355
leftoverData = false
351
- // Wait until ready to receive based on state map
352
- select {
353
- case <- p .sendDoneChan :
354
- // Break out of receive loop if we're shutting down
355
- return
356
- case <- p .muxerDoneChan :
357
- return
358
- case <- p .recvReadyChan :
359
- }
360
356
// Decode message into generic list until we can determine what type of message it is.
361
357
// This also lets us determine how many bytes the message is. We use RawMessage here to
362
358
// avoid parsing things that we may not be able to parse
363
359
tmpMsg := []cbor.RawMessage {}
364
- numBytesRead , err := cbor .Decode (recvBuffer .Bytes (), & tmpMsg )
360
+ numBytesRead , err := cbor .Decode (readBuffer .Bytes (), & tmpMsg )
365
361
if err != nil {
366
- if errors .Is (err , io .ErrUnexpectedEOF ) && recvBuffer .Len () > 0 {
362
+ if errors .Is (err , io .ErrUnexpectedEOF ) && readBuffer .Len () > 0 {
367
363
// This is probably a multi-part message, so we wait until we get more of the message
368
364
// before trying to process it
369
- p .recvReadyChan <- true
370
365
continue
371
366
}
372
367
p .SendError (fmt .Errorf ("%s: decode error: %w" , p .config .Name , err ))
@@ -378,7 +373,7 @@ func (p *Protocol) recvLoop() {
378
373
p .SendError (fmt .Errorf ("%s: decode error: %w" , p .config .Name , err ))
379
374
}
380
375
// Create Message object from CBOR
381
- msgData := recvBuffer .Bytes ()[:numBytesRead ]
376
+ msgData := readBuffer .Bytes ()[:numBytesRead ]
382
377
msg , err := p .config .MessageFromCborFunc (msgType , msgData )
383
378
if err != nil {
384
379
p .SendError (err )
@@ -394,19 +389,58 @@ func (p *Protocol) recvLoop() {
394
389
)
395
390
return
396
391
}
397
- // Handle message
398
- if err := p .handleMessage (msg ); err != nil {
399
- p .SendError (err )
392
+ // Add message to receive queue
393
+ select {
394
+ case p .recvQueueChan <- msg :
395
+ default :
396
+ p .SendError (
397
+ fmt .Errorf (
398
+ "%s: received message queue limit exceeded" ,
399
+ p .config .Name ,
400
+ ),
401
+ )
400
402
return
401
403
}
402
- if numBytesRead < recvBuffer .Len () {
404
+ if numBytesRead < readBuffer .Len () {
403
405
// There is another message in the same muxer segment, so we reset the buffer with just
404
406
// the remaining data
405
- recvBuffer = bytes .NewBuffer (recvBuffer .Bytes ()[numBytesRead :])
407
+ readBuffer = bytes .NewBuffer (readBuffer .Bytes ()[numBytesRead :])
406
408
leftoverData = true
407
409
} else {
408
410
// Empty out our buffer since we successfully processed the message
409
- recvBuffer .Reset ()
411
+ readBuffer .Reset ()
412
+ }
413
+ }
414
+ }
415
+
416
+ func (p * Protocol ) recvLoop () {
417
+ defer func () {
418
+ close (p .recvDoneChan )
419
+ }()
420
+
421
+ for {
422
+ // Wait until ready to receive based on state map
423
+ select {
424
+ case <- p .sendDoneChan :
425
+ // Break out of receive loop if we're shutting down
426
+ return
427
+ case <- p .muxerDoneChan :
428
+ return
429
+ case <- p .recvReadyChan :
430
+ }
431
+ // Read next message from queue
432
+ select {
433
+ case <- p .sendDoneChan :
434
+ // Break out of receive loop if we're shutting down
435
+ return
436
+ case <- p .muxerDoneChan :
437
+ return
438
+ case msg := <- p .recvQueueChan :
439
+ // Handle message
440
+ if err := p .handleMessage (msg ); err != nil {
441
+ p .SendError (err )
442
+ return
443
+ }
410
444
}
411
445
}
412
446
}
0 commit comments