Skip to content

Commit 6aadb95

Browse files
authored
refactor: prefetch messages from muxer in protocol (#955)
Fixes #954 Signed-off-by: Aurora Gaffney <[email protected]>
1 parent c9ffdea commit 6aadb95

File tree

1 file changed

+61
-27
lines changed

1 file changed

+61
-27
lines changed

protocol/protocol.go

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ import (
3232
// This is completely arbitrary, but the line had to be drawn somewhere
3333
const maxMessagesPerSegment = 20
3434

35+
// DefaultRecvQueueSize is the default capacity for the recv queue channel
36+
const DefaultRecvQueueSize = 50
37+
3538
// Protocol implements the base functionality of an Ouroboros mini-protocol
3639
type Protocol struct {
3740
config ProtocolConfig
@@ -41,6 +44,7 @@ type Protocol struct {
4144
muxerDoneChan chan bool
4245
sendQueueChan chan Message
4346
recvDoneChan chan struct{}
47+
recvQueueChan chan Message
4448
recvReadyChan chan bool
4549
sendDoneChan chan struct{}
4650
sendReadyChan chan bool
@@ -63,6 +67,7 @@ type ProtocolConfig struct {
6367
StateMap StateMap
6468
StateContext interface{}
6569
InitialState State
70+
RecvQueueSize int
6671
}
6772

6873
// ProtocolMode is an enum of the protocol modes
@@ -109,6 +114,9 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
109114

110115
// New returns a new Protocol object
111116
func New(config ProtocolConfig) *Protocol {
117+
if config.RecvQueueSize == 0 {
118+
config.RecvQueueSize = DefaultRecvQueueSize
119+
}
112120
p := &Protocol{
113121
config: config,
114122
doneChan: make(chan struct{}),
@@ -137,6 +145,7 @@ func (p *Protocol) Start() {
137145

138146
// Create channels
139147
p.sendQueueChan = make(chan Message, 50)
148+
p.recvQueueChan = make(chan Message, p.config.RecvQueueSize)
140149
p.recvReadyChan = make(chan bool, 1)
141150
p.sendReadyChan = make(chan bool, 1)
142151

@@ -151,6 +160,7 @@ func (p *Protocol) Start() {
151160
}()
152161

153162
go p.stateLoop(stateTransitionChan)
163+
go p.readLoop()
154164
go p.recvLoop()
155165
go p.sendLoop()
156166
})
@@ -320,16 +330,11 @@ func (p *Protocol) sendLoop() {
320330
}
321331
}
322332

323-
func (p *Protocol) recvLoop() {
324-
defer func() {
325-
close(p.recvDoneChan)
326-
}()
327-
333+
func (p *Protocol) readLoop() {
328334
leftoverData := false
329-
recvBuffer := bytes.NewBuffer(nil)
335+
readBuffer := bytes.NewBuffer(nil)
330336

331337
for {
332-
var err error
333338
// Don't grab the next segment from the muxer if we still have data in the buffer
334339
if !leftoverData {
335340
// Wait for segment
@@ -344,29 +349,19 @@ func (p *Protocol) recvLoop() {
344349
return
345350
}
346351
// Add segment payload to buffer
347-
recvBuffer.Write(segment.Payload)
352+
readBuffer.Write(segment.Payload)
348353
}
349354
}
350355
leftoverData = false
351-
// Wait until ready to receive based on state map
352-
select {
353-
case <-p.sendDoneChan:
354-
// Break out of receive loop if we're shutting down
355-
return
356-
case <-p.muxerDoneChan:
357-
return
358-
case <-p.recvReadyChan:
359-
}
360356
// Decode message into generic list until we can determine what type of message it is.
361357
// This also lets us determine how many bytes the message is. We use RawMessage here to
362358
// avoid parsing things that we may not be able to parse
363359
tmpMsg := []cbor.RawMessage{}
364-
numBytesRead, err := cbor.Decode(recvBuffer.Bytes(), &tmpMsg)
360+
numBytesRead, err := cbor.Decode(readBuffer.Bytes(), &tmpMsg)
365361
if err != nil {
366-
if errors.Is(err, io.ErrUnexpectedEOF) && recvBuffer.Len() > 0 {
362+
if errors.Is(err, io.ErrUnexpectedEOF) && readBuffer.Len() > 0 {
367363
// This is probably a multi-part message, so we wait until we get more of the message
368364
// before trying to process it
369-
p.recvReadyChan <- true
370365
continue
371366
}
372367
p.SendError(fmt.Errorf("%s: decode error: %w", p.config.Name, err))
@@ -378,7 +373,7 @@ func (p *Protocol) recvLoop() {
378373
p.SendError(fmt.Errorf("%s: decode error: %w", p.config.Name, err))
379374
}
380375
// Create Message object from CBOR
381-
msgData := recvBuffer.Bytes()[:numBytesRead]
376+
msgData := readBuffer.Bytes()[:numBytesRead]
382377
msg, err := p.config.MessageFromCborFunc(msgType, msgData)
383378
if err != nil {
384379
p.SendError(err)
@@ -394,19 +389,58 @@ func (p *Protocol) recvLoop() {
394389
)
395390
return
396391
}
397-
// Handle message
398-
if err := p.handleMessage(msg); err != nil {
399-
p.SendError(err)
392+
// Add message to receive queue
393+
select {
394+
case p.recvQueueChan <- msg:
395+
default:
396+
p.SendError(
397+
fmt.Errorf(
398+
"%s: received message queue limit exceeded",
399+
p.config.Name,
400+
),
401+
)
400402
return
401403
}
402-
if numBytesRead < recvBuffer.Len() {
404+
if numBytesRead < readBuffer.Len() {
403405
// There is another message in the same muxer segment, so we reset the buffer with just
404406
// the remaining data
405-
recvBuffer = bytes.NewBuffer(recvBuffer.Bytes()[numBytesRead:])
407+
readBuffer = bytes.NewBuffer(readBuffer.Bytes()[numBytesRead:])
406408
leftoverData = true
407409
} else {
408410
// Empty out our buffer since we successfully processed the message
409-
recvBuffer.Reset()
411+
readBuffer.Reset()
412+
}
413+
}
414+
}
415+
416+
func (p *Protocol) recvLoop() {
417+
defer func() {
418+
close(p.recvDoneChan)
419+
}()
420+
421+
for {
422+
// Wait until ready to receive based on state map
423+
select {
424+
case <-p.sendDoneChan:
425+
// Break out of receive loop if we're shutting down
426+
return
427+
case <-p.muxerDoneChan:
428+
return
429+
case <-p.recvReadyChan:
430+
}
431+
// Read next message from queue
432+
select {
433+
case <-p.sendDoneChan:
434+
// Break out of receive loop if we're shutting down
435+
return
436+
case <-p.muxerDoneChan:
437+
return
438+
case msg := <-p.recvQueueChan:
439+
// Handle message
440+
if err := p.handleMessage(msg); err != nil {
441+
p.SendError(err)
442+
return
443+
}
410444
}
411445
}
412446
}

0 commit comments

Comments
 (0)