Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 61 additions & 27 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ import (
// This is completely arbitrary, but the line had to be drawn somewhere
const maxMessagesPerSegment = 20

// DefaultRecvQueueSize is the default capacity for the recv queue channel
const DefaultRecvQueueSize = 50

// Protocol implements the base functionality of an Ouroboros mini-protocol
type Protocol struct {
config ProtocolConfig
Expand All @@ -41,6 +44,7 @@ type Protocol struct {
muxerDoneChan chan bool
sendQueueChan chan Message
recvDoneChan chan struct{}
recvQueueChan chan Message
recvReadyChan chan bool
sendDoneChan chan struct{}
sendReadyChan chan bool
Expand All @@ -63,6 +67,7 @@ type ProtocolConfig struct {
StateMap StateMap
StateContext interface{}
InitialState State
RecvQueueSize int
}

// ProtocolMode is an enum of the protocol modes
Expand Down Expand Up @@ -109,6 +114,9 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)

// New returns a new Protocol object
func New(config ProtocolConfig) *Protocol {
if config.RecvQueueSize == 0 {
config.RecvQueueSize = DefaultRecvQueueSize
}
p := &Protocol{
config: config,
doneChan: make(chan struct{}),
Expand Down Expand Up @@ -137,6 +145,7 @@ func (p *Protocol) Start() {

// Create channels
p.sendQueueChan = make(chan Message, 50)
p.recvQueueChan = make(chan Message, p.config.RecvQueueSize)
p.recvReadyChan = make(chan bool, 1)
p.sendReadyChan = make(chan bool, 1)

Expand All @@ -151,6 +160,7 @@ func (p *Protocol) Start() {
}()

go p.stateLoop(stateTransitionChan)
go p.readLoop()
go p.recvLoop()
go p.sendLoop()
})
Expand Down Expand Up @@ -320,16 +330,11 @@ func (p *Protocol) sendLoop() {
}
}

func (p *Protocol) recvLoop() {
defer func() {
close(p.recvDoneChan)
}()

func (p *Protocol) readLoop() {
leftoverData := false
recvBuffer := bytes.NewBuffer(nil)
readBuffer := bytes.NewBuffer(nil)

for {
var err error
// Don't grab the next segment from the muxer if we still have data in the buffer
if !leftoverData {
// Wait for segment
Expand All @@ -344,29 +349,19 @@ func (p *Protocol) recvLoop() {
return
}
// Add segment payload to buffer
recvBuffer.Write(segment.Payload)
readBuffer.Write(segment.Payload)
}
}
leftoverData = false
// Wait until ready to receive based on state map
select {
case <-p.sendDoneChan:
// Break out of receive loop if we're shutting down
return
case <-p.muxerDoneChan:
return
case <-p.recvReadyChan:
}
// Decode message into generic list until we can determine what type of message it is.
// This also lets us determine how many bytes the message is. We use RawMessage here to
// avoid parsing things that we may not be able to parse
tmpMsg := []cbor.RawMessage{}
numBytesRead, err := cbor.Decode(recvBuffer.Bytes(), &tmpMsg)
numBytesRead, err := cbor.Decode(readBuffer.Bytes(), &tmpMsg)
if err != nil {
if errors.Is(err, io.ErrUnexpectedEOF) && recvBuffer.Len() > 0 {
if errors.Is(err, io.ErrUnexpectedEOF) && readBuffer.Len() > 0 {
// This is probably a multi-part message, so we wait until we get more of the message
// before trying to process it
p.recvReadyChan <- true
continue
}
p.SendError(fmt.Errorf("%s: decode error: %w", p.config.Name, err))
Expand All @@ -378,7 +373,7 @@ func (p *Protocol) recvLoop() {
p.SendError(fmt.Errorf("%s: decode error: %w", p.config.Name, err))
}
// Create Message object from CBOR
msgData := recvBuffer.Bytes()[:numBytesRead]
msgData := readBuffer.Bytes()[:numBytesRead]
msg, err := p.config.MessageFromCborFunc(msgType, msgData)
if err != nil {
p.SendError(err)
Expand All @@ -394,19 +389,58 @@ func (p *Protocol) recvLoop() {
)
return
}
// Handle message
if err := p.handleMessage(msg); err != nil {
p.SendError(err)
// Add message to receive queue
select {
case p.recvQueueChan <- msg:
default:
p.SendError(
fmt.Errorf(
"%s: received message queue limit exceeded",
p.config.Name,
),
)
return
}
if numBytesRead < recvBuffer.Len() {
if numBytesRead < readBuffer.Len() {
// There is another message in the same muxer segment, so we reset the buffer with just
// the remaining data
recvBuffer = bytes.NewBuffer(recvBuffer.Bytes()[numBytesRead:])
readBuffer = bytes.NewBuffer(readBuffer.Bytes()[numBytesRead:])
leftoverData = true
} else {
// Empty out our buffer since we successfully processed the message
recvBuffer.Reset()
readBuffer.Reset()
}
}
}

func (p *Protocol) recvLoop() {
defer func() {
close(p.recvDoneChan)
}()

for {
// Wait until ready to receive based on state map
select {
case <-p.sendDoneChan:
// Break out of receive loop if we're shutting down
return
case <-p.muxerDoneChan:
return
case <-p.recvReadyChan:
}
// Read next message from queue
select {
case <-p.sendDoneChan:
// Break out of receive loop if we're shutting down
return
case <-p.muxerDoneChan:
return
case msg := <-p.recvQueueChan:
// Handle message
if err := p.handleMessage(msg); err != nil {
p.SendError(err)
return
}
}
}
}
Expand Down
Loading