@@ -22,6 +22,7 @@ package muxer
2222import (
2323 "bytes"
2424 "encoding/binary"
25+ "errors"
2526 "fmt"
2627 "io"
2728 "net"
@@ -64,7 +65,6 @@ type Muxer struct {
6465 protocolReceivers map [uint16 ]map [ProtocolRole ]chan * Segment
6566 protocolReceiversMutex sync.Mutex
6667 diffusionMode DiffusionMode
67- onceStart sync.Once
6868 onceStop sync.Once
6969}
7070
@@ -78,8 +78,21 @@ func New(conn net.Conn) *Muxer {
7878 protocolSenders : make (map [uint16 ]map [ProtocolRole ]chan * Segment ),
7979 protocolReceivers : make (map [uint16 ]map [ProtocolRole ]chan * Segment ),
8080 }
81+ // Start read goroutine
8182 m .waitGroup .Add (1 )
8283 go m .readLoop ()
84+ // Start cleanup routine
85+ go func () {
86+ // Wait for done signal
87+ <- m .doneChan
88+ // Close underlying connection
89+ // We must do this to break out of pending Read() calls to shut down cleanly
90+ _ = m .conn .Close ()
91+ // Wait for other goroutines to shutdown
92+ m .waitGroup .Wait ()
93+ // Close ErrorChan to signify to consumer that we're shutting down
94+ close (m .errorChan )
95+ }()
8396 return m
8497}
8598
@@ -89,23 +102,17 @@ func (m *Muxer) ErrorChan() chan error {
89102
90103// Start unblocks the read loop after the initial handshake to allow it to start processing messages
91104func (m * Muxer ) Start () {
92- m .onceStart .Do (func () {
93- m .startChan <- true
94- })
105+ select {
106+ case m .startChan <- true :
107+ default :
108+ }
95109}
96110
97111// Stop shuts down the muxer
98112func (m * Muxer ) Stop () {
99113 m .onceStop .Do (func () {
100114 // Close doneChan to signify that we're shutting down
101115 close (m .doneChan )
102- // Close underlying connection
103- // We must do this to break out of pending Read() calls to shut down cleanly
104- _ = m .conn .Close ()
105- // Wait for other goroutines to shutdown
106- m .waitGroup .Wait ()
107- // Close ErrorChan to signify to consumer that we're shutting down
108- close (m .errorChan )
109116 })
110117}
111118
@@ -220,6 +227,9 @@ func (m *Muxer) readLoop() {
220227 }
221228 header := SegmentHeader {}
222229 if err := binary .Read (m .conn , binary .BigEndian , & header ); err != nil {
230+ if errors .Is (err , io .ErrClosedPipe ) {
231+ err = io .EOF
232+ }
223233 m .sendError (err )
224234 return
225235 }
@@ -230,6 +240,9 @@ func (m *Muxer) readLoop() {
230240 // We use ReadFull because it guarantees to read the expected number of bytes or
231241 // return an error
232242 if _ , err := io .ReadFull (m .conn , msg .Payload ); err != nil {
243+ if errors .Is (err , io .ErrClosedPipe ) {
244+ err = io .EOF
245+ }
233246 m .sendError (err )
234247 return
235248 }
0 commit comments