@@ -21,6 +21,7 @@ type Muxer struct {
2121 conn net.Conn
2222 sendMutex sync.Mutex
2323 startChan chan bool
24+ doneChan chan bool
2425 ErrorChan chan error
2526 protocolSenders map [uint16 ]chan * Segment
2627 protocolReceivers map [uint16 ]chan * Segment
@@ -30,6 +31,7 @@ func New(conn net.Conn) *Muxer {
3031 m := & Muxer {
3132 conn : conn ,
3233 startChan : make (chan bool , 1 ),
34+ doneChan : make (chan bool ),
3335 ErrorChan : make (chan error , 10 ),
3436 protocolSenders : make (map [uint16 ]chan * Segment ),
3537 protocolReceivers : make (map [uint16 ]chan * Segment ),
@@ -42,6 +44,37 @@ func (m *Muxer) Start() {
4244 m .startChan <- true
4345}
4446
47+ func (m * Muxer ) Stop () {
48+ // Immediately return if we're already shutting down
49+ select {
50+ case <- m .doneChan :
51+ return
52+ default :
53+ }
54+ // Close protocol receive channels
55+ // We rely on the individual mini-protocols to close the sender channel
56+ for _ , recvChan := range m .protocolReceivers {
57+ close (recvChan )
58+ }
59+ // Close ErrorChan to signify to consumer that we're shutting down
60+ close (m .ErrorChan )
61+ // Close doneChan to signify that we're shutting down
62+ close (m .doneChan )
63+ }
64+
65+ func (m * Muxer ) sendError (err error ) {
66+ // Immediately return if we're already shutting down
67+ select {
68+ case <- m .doneChan :
69+ return
70+ default :
71+ }
72+ // Send error to consumer
73+ m .ErrorChan <- err
74+ // Stop the muxer on any error
75+ m .Stop ()
76+ }
77+
4578func (m * Muxer ) RegisterProtocol (protocolId uint16 ) (chan * Segment , chan * Segment ) {
4679 // Generate channels
4780 senderChan := make (chan * Segment , 10 )
@@ -52,9 +85,17 @@ func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segmen
5285 // Start Goroutine to handle outbound messages
5386 go func () {
5487 for {
55- msg := <- senderChan
56- if err := m .Send (msg ); err != nil {
57- m .ErrorChan <- err
88+ select {
89+ case _ , ok := <- m .doneChan :
90+ // doneChan has been closed, which means we're shutting down
91+ if ! ok {
92+ return
93+ }
94+ case msg := <- senderChan :
95+ if err := m .Send (msg ); err != nil {
96+ m .sendError (err )
97+ return
98+ }
5899 }
59100 }
60101 }()
@@ -81,9 +122,16 @@ func (m *Muxer) Send(msg *Segment) error {
81122func (m * Muxer ) readLoop () {
82123 started := false
83124 for {
125+ // Break out of read loop if we're shutting down
126+ select {
127+ case <- m .doneChan :
128+ return
129+ default :
130+ }
84131 header := SegmentHeader {}
85132 if err := binary .Read (m .conn , binary .BigEndian , & header ); err != nil {
86- m .ErrorChan <- err
133+ m .sendError (err )
134+ return
87135 }
88136 msg := & Segment {
89137 SegmentHeader : header ,
@@ -92,24 +140,32 @@ func (m *Muxer) readLoop() {
92140 // We use ReadFull because it guarantees to read the expected number of bytes or
93141 // return an error
94142 if _ , err := io .ReadFull (m .conn , msg .Payload ); err != nil {
95- m .ErrorChan <- err
96- }
97- // Wait until the muxer is started to process anything other than handshake messages
98- if ! started && msg .GetProtocolId () != PROTOCOL_HANDSHAKE {
99- <- m .startChan
100- started = true
143+ m .sendError (err )
144+ return
101145 }
102146 // Send message payload to proper receiver
103147 recvChan := m .protocolReceivers [msg .GetProtocolId ()]
104148 if recvChan == nil {
105149 // Try the "unknown protocol" receiver if we didn't find an explicit one
106150 recvChan = m .protocolReceivers [PROTOCOL_UNKNOWN ]
107151 if recvChan == nil {
108- m .ErrorChan <- fmt .Errorf ("received message for unknown protocol ID %d" , msg .GetProtocolId ())
152+ m .sendError (fmt .Errorf ("received message for unknown protocol ID %d" , msg .GetProtocolId ()))
153+ return
109154 }
110155 }
111156 if recvChan != nil {
112157 recvChan <- msg
113158 }
159+ // Wait until the muxer is started to continue
160+ // We don't want to read more than one segment until the handshake is complete
161+ if ! started {
162+ select {
163+ case <- m .doneChan :
164+ // Break out of read loop if we're shutting down
165+ return
166+ case <- m .startChan :
167+ started = true
168+ }
169+ }
114170 }
115171}
0 commit comments