diff --git a/protocol/protocol.go b/protocol/protocol.go index 4f7753b4..75705657 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -32,6 +32,9 @@ import ( // This is completely arbitrary, but the line had to be drawn somewhere const maxMessagesPerSegment = 20 +// DefaultRecvQueueSize is the default capacity for the recv queue channel +const DefaultRecvQueueSize = 50 + // Protocol implements the base functionality of an Ouroboros mini-protocol type Protocol struct { config ProtocolConfig @@ -41,6 +44,7 @@ type Protocol struct { muxerDoneChan chan bool sendQueueChan chan Message recvDoneChan chan struct{} + recvQueueChan chan Message recvReadyChan chan bool sendDoneChan chan struct{} sendReadyChan chan bool @@ -63,6 +67,7 @@ type ProtocolConfig struct { StateMap StateMap StateContext interface{} InitialState State + RecvQueueSize int } // ProtocolMode is an enum of the protocol modes @@ -109,6 +114,9 @@ type MessageFromCborFunc func(uint, []byte) (Message, error) // New returns a new Protocol object func New(config ProtocolConfig) *Protocol { + if config.RecvQueueSize == 0 { + config.RecvQueueSize = DefaultRecvQueueSize + } p := &Protocol{ config: config, doneChan: make(chan struct{}), @@ -137,6 +145,7 @@ func (p *Protocol) Start() { // Create channels p.sendQueueChan = make(chan Message, 50) + p.recvQueueChan = make(chan Message, p.config.RecvQueueSize) p.recvReadyChan = make(chan bool, 1) p.sendReadyChan = make(chan bool, 1) @@ -151,6 +160,7 @@ func (p *Protocol) Start() { }() go p.stateLoop(stateTransitionChan) + go p.readLoop() go p.recvLoop() go p.sendLoop() }) @@ -320,16 +330,11 @@ func (p *Protocol) sendLoop() { } } -func (p *Protocol) recvLoop() { - defer func() { - close(p.recvDoneChan) - }() - +func (p *Protocol) readLoop() { leftoverData := false - recvBuffer := bytes.NewBuffer(nil) + readBuffer := bytes.NewBuffer(nil) for { - var err error // Don't grab the next segment from the muxer if we still have data in the buffer if !leftoverData { // Wait for segment @@ -344,29 +349,19 @@ func (p *Protocol) recvLoop() { return } // Add segment payload to buffer - recvBuffer.Write(segment.Payload) + readBuffer.Write(segment.Payload) } } leftoverData = false - // Wait until ready to receive based on state map - select { - case <-p.sendDoneChan: - // Break out of receive loop if we're shutting down - return - case <-p.muxerDoneChan: - return - case <-p.recvReadyChan: - } // Decode message into generic list until we can determine what type of message it is. // This also lets us determine how many bytes the message is. We use RawMessage here to // avoid parsing things that we may not be able to parse tmpMsg := []cbor.RawMessage{} - numBytesRead, err := cbor.Decode(recvBuffer.Bytes(), &tmpMsg) + numBytesRead, err := cbor.Decode(readBuffer.Bytes(), &tmpMsg) if err != nil { - if errors.Is(err, io.ErrUnexpectedEOF) && recvBuffer.Len() > 0 { + if errors.Is(err, io.ErrUnexpectedEOF) && readBuffer.Len() > 0 { // This is probably a multi-part message, so we wait until we get more of the message // before trying to process it - p.recvReadyChan <- true continue } p.SendError(fmt.Errorf("%s: decode error: %w", p.config.Name, err)) @@ -378,7 +373,7 @@ func (p *Protocol) recvLoop() { p.SendError(fmt.Errorf("%s: decode error: %w", p.config.Name, err)) } // Create Message object from CBOR - msgData := recvBuffer.Bytes()[:numBytesRead] + msgData := readBuffer.Bytes()[:numBytesRead] msg, err := p.config.MessageFromCborFunc(msgType, msgData) if err != nil { p.SendError(err) @@ -394,19 +389,58 @@ func (p *Protocol) recvLoop() { ) return } - // Handle message - if err := p.handleMessage(msg); err != nil { - p.SendError(err) + // Add message to receive queue + select { + case p.recvQueueChan <- msg: + default: + p.SendError( + fmt.Errorf( + "%s: received message queue limit exceeded", + p.config.Name, + ), + ) return } - if numBytesRead < recvBuffer.Len() { + if numBytesRead < readBuffer.Len() { // There is another message in the same muxer segment, so we reset the buffer with just // the remaining data - recvBuffer = bytes.NewBuffer(recvBuffer.Bytes()[numBytesRead:]) + readBuffer = bytes.NewBuffer(readBuffer.Bytes()[numBytesRead:]) leftoverData = true } else { // Empty out our buffer since we successfully processed the message - recvBuffer.Reset() + readBuffer.Reset() + } + } +} + +func (p *Protocol) recvLoop() { + defer func() { + close(p.recvDoneChan) + }() + + for { + // Wait until ready to receive based on state map + select { + case <-p.sendDoneChan: + // Break out of receive loop if we're shutting down + return + case <-p.muxerDoneChan: + return + case <-p.recvReadyChan: + } + // Read next message from queue + select { + case <-p.sendDoneChan: + // Break out of receive loop if we're shutting down + return + case <-p.muxerDoneChan: + return + case msg := <-p.recvQueueChan: + // Handle message + if err := p.handleMessage(msg); err != nil { + p.SendError(err) + return + } } } }