@@ -22,6 +22,7 @@ type Protocol struct {
2222 config ProtocolConfig
2323 muxerSendChan chan * muxer.Segment
2424 muxerRecvChan chan * muxer.Segment
25+ muxerDoneChan chan bool
2526 state State
2627 stateMutex sync.Mutex
2728 recvBuffer * bytes.Buffer
@@ -77,21 +78,29 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
7778
7879func New (config ProtocolConfig ) * Protocol {
7980 p := & Protocol {
80- config : config ,
81+ config : config ,
82+ doneChan : make (chan bool ),
8183 }
8284 return p
8385}
8486
8587func (p * Protocol ) Start () {
8688 // Register protocol with muxer
87- p .muxerSendChan , p .muxerRecvChan = p .config .Muxer .RegisterProtocol (p .config .ProtocolId )
89+ p .muxerSendChan , p .muxerRecvChan , p . muxerDoneChan = p .config .Muxer .RegisterProtocol (p .config .ProtocolId )
8890 // Create buffers and channels
8991 p .recvBuffer = bytes .NewBuffer (nil )
9092 p .sendQueueChan = make (chan Message , 50 )
9193 p .sendStateQueueChan = make (chan Message , 50 )
9294 p .recvReadyChan = make (chan bool , 1 )
9395 p .sendReadyChan = make (chan bool , 1 )
94- p .doneChan = make (chan bool )
96+ // Start goroutine to cleanup when shutting down
97+ go func () {
98+ <- p .doneChan
99+ close (p .sendQueueChan )
100+ close (p .sendStateQueueChan )
101+ close (p .recvReadyChan )
102+ close (p .sendReadyChan )
103+ }()
95104 // Set initial state
96105 p .setState (p .config .InitialState )
97106 // Start our send and receive Goroutines
@@ -107,6 +116,10 @@ func (p *Protocol) Role() ProtocolRole {
107116 return p .config .Role
108117}
109118
119+ func (p * Protocol ) DoneChan () chan bool {
120+ return p .doneChan
121+ }
122+
110123func (p * Protocol ) SendMessage (msg Message ) error {
111124 p .sendQueueChan <- msg
112125 return nil
@@ -122,14 +135,14 @@ func (p *Protocol) sendLoop() {
122135 var err error
123136 for {
124137 select {
125- case <- p .sendReadyChan :
126- // We are ready to send based on state map
127138 case <- p .doneChan :
128139 // We are responsible for closing this channel as the sender, even through it
129140 // was created by the muxer
130141 close (p .muxerSendChan )
131142 // Break out of send loop if we're shutting down
132143 return
144+ case <- p .sendReadyChan :
145+ // We are ready to send based on state map
133146 }
134147 // Lock the state to prevent collisions
135148 p .stateMutex .Lock ()
@@ -155,7 +168,11 @@ func (p *Protocol) sendLoop() {
155168 msgCount := 0
156169 for {
157170 // Get next message from send queue
158- msg := <- p .sendQueueChan
171+ msg , ok := <- p .sendQueueChan
172+ if ! ok {
173+ // We're shutting down
174+ return
175+ }
159176 msgCount = msgCount + 1
160177 // Write the message into the send state queue if we already have a new state
161178 if setNewState {
@@ -234,20 +251,29 @@ func (p *Protocol) recvLoop() {
234251 // Don't grab the next segment from the muxer if we still have data in the buffer
235252 if ! leftoverData {
236253 // Wait for segment
237- segment , ok := <- p .muxerRecvChan
238- // Break out of receive loop if channel is closed
239- if ! ok {
254+ select {
255+ case <- p .muxerDoneChan :
240256 close (p .doneChan )
241257 return
258+ case segment , ok := <- p .muxerRecvChan :
259+ if ! ok {
260+ close (p .doneChan )
261+ return
262+ }
263+ // Add segment payload to buffer
264+ p .recvBuffer .Write (segment .Payload )
265+ // Save whether it's a response
266+ isResponse = segment .IsResponse ()
242267 }
243- // Add segment payload to buffer
244- p .recvBuffer .Write (segment .Payload )
245- // Save whether it's a response
246- isResponse = segment .IsResponse ()
247268 }
248269 leftoverData = false
249270 // Wait until ready to receive based on state map
250- <- p .recvReadyChan
271+ select {
272+ case <- p .muxerDoneChan :
273+ close (p .doneChan )
274+ return
275+ case <- p .recvReadyChan :
276+ }
251277 // Decode message into generic list until we can determine what type of message it is.
252278 // This also lets us determine how many bytes the message is. We use RawMessage here to
253279 // avoid parsing things that we may not be able to parse
@@ -321,6 +347,7 @@ func (p *Protocol) getNewState(msg Message) (State, error) {
321347func (p * Protocol ) setState (state State ) {
322348 // Disable any previous state transition timer
323349 if p .stateTransitionTimer != nil {
350+ // Stop timer and drain channel
324351 if ! p .stateTransitionTimer .Stop () {
325352 <- p .stateTransitionTimer .C
326353 }
0 commit comments