@@ -5,25 +5,31 @@ import (
55 "encoding/binary"
66 "fmt"
77 "io"
8+ "net"
89 "sync"
910)
1011
1112const (
1213 // Magic number chosen to represent unknown protocols
1314 PROTOCOL_UNKNOWN uint16 = 0xabcd
15+
16+ // Handshake protocol ID
17+ PROTOCOL_HANDSHAKE = 0
1418)
1519
1620type Muxer struct {
17- conn io. ReadWriteCloser
21+ conn net. Conn
1822 sendMutex sync.Mutex
23+ startChan chan bool
1924 ErrorChan chan error
2025 protocolSenders map [uint16 ]chan * Segment
2126 protocolReceivers map [uint16 ]chan * Segment
2227}
2328
24- func New (conn io. ReadWriteCloser ) * Muxer {
29+ func New (conn net. Conn ) * Muxer {
2530 m := & Muxer {
2631 conn : conn ,
32+ startChan : make (chan bool , 1 ),
2733 ErrorChan : make (chan error , 10 ),
2834 protocolSenders : make (map [uint16 ]chan * Segment ),
2935 protocolReceivers : make (map [uint16 ]chan * Segment ),
@@ -32,6 +38,10 @@ func New(conn io.ReadWriteCloser) *Muxer {
3238 return m
3339}
3440
41+ func (m * Muxer ) Start () {
42+ m .startChan <- true
43+ }
44+
3545func (m * Muxer ) RegisterProtocol (protocolId uint16 ) (chan * Segment , chan * Segment ) {
3646 // Generate channels
3747 senderChan := make (chan * Segment , 10 )
@@ -69,6 +79,7 @@ func (m *Muxer) Send(msg *Segment) error {
6979}
7080
7181func (m * Muxer ) readLoop () {
82+ started := false
7283 for {
7384 header := SegmentHeader {}
7485 if err := binary .Read (m .conn , binary .BigEndian , & header ); err != nil {
@@ -83,6 +94,11 @@ func (m *Muxer) readLoop() {
8394 if _ , err := io .ReadFull (m .conn , msg .Payload ); err != nil {
8495 m .ErrorChan <- err
8596 }
97+ // Wait until the muxer is started to process anything other than handshake messages
98+ if ! started && msg .GetProtocolId () != PROTOCOL_HANDSHAKE {
99+ <- m .startChan
100+ started = true
101+ }
86102 // Send message payload to proper receiver
87103 recvChan := m .protocolReceivers [msg .GetProtocolId ()]
88104 if recvChan == nil {
0 commit comments