Skip to content

Commit f6b4e39

Browse files
authored
Merge pull request #479 from blinklabs-io/fix/protocol-shutdown-deadlocks
fix: protocol shutdown deadlocks
2 parents acb24fa + 50c3fce commit f6b4e39

File tree

1 file changed

+88
-66
lines changed

1 file changed

+88
-66
lines changed

protocol/protocol.go

Lines changed: 88 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,7 @@ func (p *Protocol) Start() {
138138
close(p.sendReadyChan)
139139
// Cancel any timer
140140
if p.stateTransitionTimer != nil {
141-
// Stop timer and drain channel
142-
if !p.stateTransitionTimer.Stop() {
143-
<-p.stateTransitionTimer.C
144-
}
141+
p.stateTransitionTimer.Stop()
145142
p.stateTransitionTimer = nil
146143
}
147144
}()
@@ -201,54 +198,15 @@ func (p *Protocol) sendLoop() {
201198
// Check for queued state changes from previous pipelined sends
202199
setNewState = false
203200
if len(p.sendStateQueueChan) > 0 {
204-
msg := <-p.sendStateQueueChan
205-
newState, err = p.getNewState(msg)
206-
if err != nil {
207-
p.SendError(
208-
fmt.Errorf(
209-
"%s: error sending message: %s",
210-
p.config.Name,
211-
err,
212-
),
213-
)
214-
return
215-
}
216-
setNewState = true
217-
// If there are no queued messages, set the new state now
218-
if len(p.sendQueueChan) == 0 {
219-
p.setState(newState)
220-
p.stateMutex.Unlock()
221-
continue
222-
}
223-
}
224-
// Read queued messages and write into buffer
225-
payloadBuf := bytes.NewBuffer(nil)
226-
msgCount := 0
227-
for {
228-
// Get next message from send queue
229-
msg, ok := <-p.sendQueueChan
230-
if !ok {
231-
// We're shutting down
201+
select {
202+
case <-p.doneChan:
203+
// Break out of send loop if we're shutting down
232204
return
233-
}
234-
msgCount = msgCount + 1
235-
// Write the message into the send state queue if we already have a new state
236-
if setNewState {
237-
p.sendStateQueueChan <- msg
238-
}
239-
// Get raw CBOR from message
240-
data := msg.Cbor()
241-
// If message has no raw CBOR, encode the message
242-
if data == nil {
243-
var err error
244-
data, err = cbor.Encode(msg)
245-
if err != nil {
246-
p.SendError(err)
205+
case msg, ok := <-p.sendStateQueueChan:
206+
if !ok {
207+
// We're shutting down
247208
return
248209
}
249-
}
250-
payloadBuf.Write(data)
251-
if !setNewState {
252210
newState, err = p.getNewState(msg)
253211
if err != nil {
254212
p.SendError(
@@ -261,21 +219,82 @@ func (p *Protocol) sendLoop() {
261219
return
262220
}
263221
setNewState = true
222+
// If there are no queued messages, set the new state now
223+
if len(p.sendQueueChan) == 0 {
224+
p.setState(newState)
225+
p.stateMutex.Unlock()
226+
continue
227+
}
264228
}
265-
// We don't want more than maxMessagesPerSegment messages in a segment
266-
if msgCount >= maxMessagesPerSegment {
267-
break
268-
}
269-
// We don't want to add more messages once we spill over into a second segment
270-
if payloadBuf.Len() > muxer.SegmentMaxPayloadLength {
271-
break
272-
}
273-
// Check if there are any more queued messages
274-
if len(p.sendQueueChan) == 0 {
275-
break
229+
}
230+
// Read queued messages and write into buffer
231+
payloadBuf := bytes.NewBuffer(nil)
232+
msgCount := 0
233+
breakLoop := false
234+
for {
235+
// Get next message from send queue
236+
select {
237+
case <-p.doneChan:
238+
// Break out of send loop if we're shutting down
239+
return
240+
case msg, ok := <-p.sendQueueChan:
241+
if !ok {
242+
// We're shutting down
243+
return
244+
}
245+
msgCount = msgCount + 1
246+
// Write the message into the send state queue if we already have a new state
247+
if setNewState {
248+
p.sendStateQueueChan <- msg
249+
}
250+
// Get raw CBOR from message
251+
data := msg.Cbor()
252+
// If message has no raw CBOR, encode the message
253+
if data == nil {
254+
var err error
255+
data, err = cbor.Encode(msg)
256+
if err != nil {
257+
p.SendError(err)
258+
return
259+
}
260+
}
261+
payloadBuf.Write(data)
262+
if !setNewState {
263+
newState, err = p.getNewState(msg)
264+
if err != nil {
265+
p.SendError(
266+
fmt.Errorf(
267+
"%s: error sending message: %s",
268+
p.config.Name,
269+
err,
270+
),
271+
)
272+
return
273+
}
274+
setNewState = true
275+
}
276+
// We don't want more than maxMessagesPerSegment messages in a segment
277+
if msgCount >= maxMessagesPerSegment {
278+
breakLoop = true
279+
break
280+
}
281+
// We don't want to add more messages once we spill over into a second segment
282+
if payloadBuf.Len() > muxer.SegmentMaxPayloadLength {
283+
breakLoop = true
284+
break
285+
}
286+
// Check if there are any more queued messages
287+
if len(p.sendQueueChan) == 0 {
288+
breakLoop = true
289+
break
290+
}
291+
// We don't want to block on writes to the send state queue
292+
if len(p.sendStateQueueChan) == cap(p.sendStateQueueChan) {
293+
breakLoop = true
294+
break
295+
}
276296
}
277-
// We don't want to block on writes to the send state queue
278-
if len(p.sendStateQueueChan) == cap(p.sendStateQueueChan) {
297+
if breakLoop {
279298
break
280299
}
281300
}
@@ -322,6 +341,9 @@ func (p *Protocol) recvLoop() {
322341
if !leftoverData {
323342
// Wait for segment
324343
select {
344+
case <-p.doneChan:
345+
// Break out of receive loop if we're shutting down
346+
return
325347
case <-p.muxerDoneChan:
326348
close(p.doneChan)
327349
return
@@ -337,6 +359,9 @@ func (p *Protocol) recvLoop() {
337359
leftoverData = false
338360
// Wait until ready to receive based on state map
339361
select {
362+
case <-p.doneChan:
363+
// Break out of receive loop if we're shutting down
364+
return
340365
case <-p.muxerDoneChan:
341366
close(p.doneChan)
342367
return
@@ -425,10 +450,7 @@ func (p *Protocol) getNewState(msg Message) (State, error) {
425450
func (p *Protocol) setState(state State) {
426451
// Disable any previous state transition timer
427452
if p.stateTransitionTimer != nil {
428-
// Stop timer and drain channel
429-
if !p.stateTransitionTimer.Stop() {
430-
<-p.stateTransitionTimer.C
431-
}
453+
p.stateTransitionTimer.Stop()
432454
p.stateTransitionTimer = nil
433455
}
434456
// Set the new state

0 commit comments

Comments
 (0)