Skip to content

Commit 7f7e628

Browse files
authored
fix: Goroutine leak in handshake server on version mismatch refusal #535 (#542)
1 parent 1f146ee commit 7f7e628

File tree

3 files changed

+47
-20
lines changed

3 files changed

+47
-20
lines changed

protocol/handshake/server_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ func TestServerBasicHandshake(t *testing.T) {
8888
}
8989

9090
func TestServerHandshakeRefuseVersionMismatch(t *testing.T) {
91-
// TODO: fix leaking goroutines
92-
//defer goleak.VerifyNone(t)
91+
defer func() {
92+
goleak.VerifyNone(t)
93+
}()
9394
expectedErr := fmt.Errorf("handshake failed: refused due to version mismatch")
9495
mockConn := ouroboros_mock.NewConnection(
9596
ouroboros_mock.ProtocolRoleServer,

protocol/protocol.go

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
"github.com/blinklabs-io/gouroboros/cbor"
2727
"github.com/blinklabs-io/gouroboros/muxer"
28+
"github.com/blinklabs-io/gouroboros/utils"
2829
)
2930

3031
// This is completely arbitrary, but the line had to be drawn somewhere
@@ -40,7 +41,7 @@ type Protocol struct {
4041
recvReadyChan chan bool
4142
sendReadyChan chan bool
4243
stateTransitionChan chan<- protocolStateTransition
43-
doneChan chan bool
44+
doneSignal *utils.DoneSignal
4445
waitGroup sync.WaitGroup
4546
onceStart sync.Once
4647
}
@@ -102,8 +103,8 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
102103
// New returns a new Protocol object
103104
func New(config ProtocolConfig) *Protocol {
104105
p := &Protocol{
105-
config: config,
106-
doneChan: make(chan bool),
106+
config: config,
107+
doneSignal: utils.NewDoneSignal(),
107108
}
108109
return p
109110
}
@@ -149,8 +150,8 @@ func (p *Protocol) Role() ProtocolRole {
149150
}
150151

151152
// DoneChan returns the channel used to signal protocol shutdown
152-
func (p *Protocol) DoneChan() chan bool {
153-
return p.doneChan
153+
func (p *Protocol) DoneChan() <-chan struct{} {
154+
return p.doneSignal.GetCh()
154155
}
155156

156157
// SendMessage appends a message to the send queue
@@ -178,11 +179,12 @@ func (p *Protocol) sendLoop() {
178179
// We are responsible for closing this channel as the sender, even through it
179180
// was created by the muxer
180181
close(p.muxerSendChan)
182+
p.doneSignal.Close()
181183
}()
182184

183185
for {
184186
select {
185-
case <-p.doneChan:
187+
case <-p.doneSignal.GetCh():
186188
// Break out of send loop if we're shutting down
187189
return
188190
case <-p.sendReadyChan:
@@ -196,7 +198,7 @@ func (p *Protocol) sendLoop() {
196198
for {
197199
// Get next message from send queue
198200
select {
199-
case <-p.doneChan:
201+
case <-p.doneSignal.GetCh():
200202
// Break out of send loop if we're shutting down
201203
return
202204
case msg, ok := <-p.sendQueueChan:
@@ -260,10 +262,7 @@ func (p *Protocol) sendLoop() {
260262
}
261263
// Send current segment
262264
segmentPayload := payloadBuf.Bytes()[:segmentPayloadLength]
263-
isResponse := false
264-
if p.Role() == ProtocolRoleServer {
265-
isResponse = true
266-
}
265+
isResponse := p.Role() == ProtocolRoleServer
267266
segment := muxer.NewSegment(
268267
p.config.ProtocolId,
269268
segmentPayload,
@@ -283,7 +282,11 @@ func (p *Protocol) sendLoop() {
283282
}
284283

285284
func (p *Protocol) recvLoop() {
286-
defer p.waitGroup.Done()
285+
defer func() {
286+
p.waitGroup.Done()
287+
p.doneSignal.Close()
288+
}()
289+
287290
leftoverData := false
288291
recvBuffer := bytes.NewBuffer(nil)
289292

@@ -293,15 +296,13 @@ func (p *Protocol) recvLoop() {
293296
if !leftoverData {
294297
// Wait for segment
295298
select {
296-
case <-p.doneChan:
299+
case <-p.doneSignal.GetCh():
297300
// Break out of receive loop if we're shutting down
298301
return
299302
case <-p.muxerDoneChan:
300-
close(p.doneChan)
301303
return
302304
case segment, ok := <-p.muxerRecvChan:
303305
if !ok {
304-
close(p.doneChan)
305306
return
306307
}
307308
// Add segment payload to buffer
@@ -311,11 +312,10 @@ func (p *Protocol) recvLoop() {
311312
leftoverData = false
312313
// Wait until ready to receive based on state map
313314
select {
314-
case <-p.doneChan:
315+
case <-p.doneSignal.GetCh():
315316
// Break out of receive loop if we're shutting down
316317
return
317318
case <-p.muxerDoneChan:
318-
close(p.doneChan)
319319
return
320320
case <-p.recvReadyChan:
321321
}
@@ -429,7 +429,7 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
429429
return transitionTimer.C
430430
}
431431

432-
protocolDoneChan := p.doneChan
432+
protocolDoneChan := p.doneSignal.GetCh()
433433
stateDoneChan := make(chan struct{})
434434

435435
setState(p.config.InitialState)

utils/utils.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,29 @@
1414

1515
// Package utils provides random utility functions
1616
package utils
17+
18+
import (
19+
"sync"
20+
)
21+
22+
// DoneSignal provides a thread-safe way to close a channel and allows other routines to listen to the channel
23+
type DoneSignal struct {
24+
closeCh chan struct{}
25+
once sync.Once
26+
}
27+
28+
func NewDoneSignal() *DoneSignal {
29+
return &DoneSignal{
30+
closeCh: make(chan struct{}),
31+
}
32+
}
33+
34+
func (cn *DoneSignal) Close() {
35+
cn.once.Do(func() {
36+
close(cn.closeCh)
37+
})
38+
}
39+
40+
func (cn *DoneSignal) GetCh() <-chan struct{} {
41+
return cn.closeCh
42+
}

0 commit comments

Comments
 (0)