Skip to content

Commit 8dc0b05

Browse files
authored
fix: remove use of sync.Once to avoid deadlocks (#523)
Fixes #522
1 parent 9f9589f commit 8dc0b05

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

muxer/muxer.go

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ package muxer
2222
import (
2323
"bytes"
2424
"encoding/binary"
25+
"errors"
2526
"fmt"
2627
"io"
2728
"net"
@@ -64,7 +65,6 @@ type Muxer struct {
6465
protocolReceivers map[uint16]map[ProtocolRole]chan *Segment
6566
protocolReceiversMutex sync.Mutex
6667
diffusionMode DiffusionMode
67-
onceStart sync.Once
6868
onceStop sync.Once
6969
}
7070

@@ -78,8 +78,21 @@ func New(conn net.Conn) *Muxer {
7878
protocolSenders: make(map[uint16]map[ProtocolRole]chan *Segment),
7979
protocolReceivers: make(map[uint16]map[ProtocolRole]chan *Segment),
8080
}
81+
// Start read goroutine
8182
m.waitGroup.Add(1)
8283
go m.readLoop()
84+
// Start cleanup routine
85+
go func() {
86+
// Wait for done signal
87+
<-m.doneChan
88+
// Close underlying connection
89+
// We must do this to break out of pending Read() calls to shut down cleanly
90+
_ = m.conn.Close()
91+
// Wait for other goroutines to shutdown
92+
m.waitGroup.Wait()
93+
// Close ErrorChan to signify to consumer that we're shutting down
94+
close(m.errorChan)
95+
}()
8396
return m
8497
}
8598

@@ -89,23 +102,17 @@ func (m *Muxer) ErrorChan() chan error {
89102

90103
// Start unblocks the read loop after the initial handshake to allow it to start processing messages
91104
func (m *Muxer) Start() {
92-
m.onceStart.Do(func() {
93-
m.startChan <- true
94-
})
105+
select {
106+
case m.startChan <- true:
107+
default:
108+
}
95109
}
96110

97111
// Stop shuts down the muxer
98112
func (m *Muxer) Stop() {
99113
m.onceStop.Do(func() {
100114
// Close doneChan to signify that we're shutting down
101115
close(m.doneChan)
102-
// Close underlying connection
103-
// We must do this to break out of pending Read() calls to shut down cleanly
104-
_ = m.conn.Close()
105-
// Wait for other goroutines to shutdown
106-
m.waitGroup.Wait()
107-
// Close ErrorChan to signify to consumer that we're shutting down
108-
close(m.errorChan)
109116
})
110117
}
111118

@@ -220,6 +227,9 @@ func (m *Muxer) readLoop() {
220227
}
221228
header := SegmentHeader{}
222229
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil {
230+
if errors.Is(err, io.ErrClosedPipe) {
231+
err = io.EOF
232+
}
223233
m.sendError(err)
224234
return
225235
}
@@ -230,6 +240,9 @@ func (m *Muxer) readLoop() {
230240
// We use ReadFull because it guarantees to read the expected number of bytes or
231241
// return an error
232242
if _, err := io.ReadFull(m.conn, msg.Payload); err != nil {
243+
if errors.Is(err, io.ErrClosedPipe) {
244+
err = io.EOF
245+
}
233246
m.sendError(err)
234247
return
235248
}

protocol/handshake/server_test.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,15 @@ func TestServerHandshakeRefuseVersionMismatch(t *testing.T) {
106106
InputMessageType: handshake.MessageTypeRefuse,
107107
InputMessage: handshake.NewMsgRefuse(
108108
[]any{
109-
handshake.RefuseReasonVersionMismatch,
110-
protocol.GetProtocolVersionsNtC(),
109+
uint64(handshake.RefuseReasonVersionMismatch),
110+
// Convert []uint16 to []any
111+
func(in []uint16) []any {
112+
var ret []any
113+
for _, item := range in {
114+
ret = append(ret, item)
115+
}
116+
return ret
117+
}(protocol.GetProtocolVersionsNtC()),
111118
},
112119
),
113120
},

0 commit comments

Comments
 (0)