@@ -5,30 +5,37 @@ import (
55 "encoding/binary"
66 "fmt"
77 "io"
8+ "sync"
9+ )
10+
11+ const (
12+ // Magic number chosen to represent unknown protocols
13+ PROTOCOL_UNKNOWN uint16 = 0xabcd
814)
915
1016type Muxer struct {
1117 conn io.ReadWriteCloser
18+ sendMutex sync.Mutex
1219 ErrorChan chan error
13- protocolSenders map [uint16 ]chan * Message
14- protocolReceivers map [uint16 ]chan * Message
20+ protocolSenders map [uint16 ]chan * Segment
21+ protocolReceivers map [uint16 ]chan * Segment
1522}
1623
1724func New (conn io.ReadWriteCloser ) * Muxer {
1825 m := & Muxer {
1926 conn : conn ,
2027 ErrorChan : make (chan error , 10 ),
21- protocolSenders : make (map [uint16 ]chan * Message ),
22- protocolReceivers : make (map [uint16 ]chan * Message ),
28+ protocolSenders : make (map [uint16 ]chan * Segment ),
29+ protocolReceivers : make (map [uint16 ]chan * Segment ),
2330 }
2431 go m .readLoop ()
2532 return m
2633}
2734
28- func (m * Muxer ) RegisterProtocol (protocolId uint16 ) (chan * Message , chan * Message ) {
35+ func (m * Muxer ) RegisterProtocol (protocolId uint16 ) (chan * Segment , chan * Segment ) {
2936 // Generate channels
30- senderChan := make (chan * Message , 10 )
31- receiverChan := make (chan * Message , 10 )
37+ senderChan := make (chan * Segment , 10 )
38+ receiverChan := make (chan * Segment , 10 )
3239 // Record channels in protocol sender/receiver maps
3340 m .protocolSenders [protocolId ] = senderChan
3441 m .protocolReceivers [protocolId ] = receiverChan
@@ -44,9 +51,12 @@ func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Message, chan *Messag
4451 return senderChan , receiverChan
4552}
4653
47- func (m * Muxer ) Send (msg * Message ) error {
54+ func (m * Muxer ) Send (msg * Segment ) error {
55+ // We use a mutex to make sure only one protocol can send at a time
56+ m .sendMutex .Lock ()
57+ defer m .sendMutex .Unlock ()
4858 buf := & bytes.Buffer {}
49- err := binary .Write (buf , binary .BigEndian , msg .MessageHeader )
59+ err := binary .Write (buf , binary .BigEndian , msg .SegmentHeader )
5060 if err != nil {
5161 return err
5262 }
@@ -60,12 +70,12 @@ func (m *Muxer) Send(msg *Message) error {
6070
6171func (m * Muxer ) readLoop () {
6272 for {
63- header := MessageHeader {}
73+ header := SegmentHeader {}
6474 if err := binary .Read (m .conn , binary .BigEndian , & header ); err != nil {
6575 m .ErrorChan <- err
6676 }
67- msg := & Message {
68- MessageHeader : header ,
77+ msg := & Segment {
78+ SegmentHeader : header ,
6979 Payload : make ([]byte , header .PayloadLength ),
7080 }
7181 // We use ReadFull because it guarantees to read the expected number of bytes or
@@ -76,7 +86,11 @@ func (m *Muxer) readLoop() {
7686 // Send message payload to proper receiver
7787 recvChan := m .protocolReceivers [msg .GetProtocolId ()]
7888 if recvChan == nil {
79- m .ErrorChan <- fmt .Errorf ("received message for unknown protocol ID %d" , msg .GetProtocolId ())
89+ // Try the "unknown protocol" receiver if we didn't find an explicit one
90+ recvChan = m .protocolReceivers [PROTOCOL_UNKNOWN ]
91+ if recvChan == nil {
92+ m .ErrorChan <- fmt .Errorf ("received message for unknown protocol ID %d" , msg .GetProtocolId ())
93+ }
8094 } else {
8195 m .protocolReceivers [msg .GetProtocolId ()] <- msg
8296 }
0 commit comments