Skip to content

Commit 2316c97

Browse files
committed
Muxer improvements and handshake protocol state machine
1 parent 9c18ec9 commit 2316c97

File tree

7 files changed

+325
-222
lines changed

7 files changed

+325
-222
lines changed

handshake/handshake.go

Lines changed: 0 additions & 100 deletions
This file was deleted.

muxer.go

Lines changed: 0 additions & 107 deletions
This file was deleted.

muxer/message.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package muxer
2+
3+
import (
4+
"time"
5+
)
6+
7+
const (
8+
MESSAGE_PROTOCOL_ID_RESPONSE_FLAG = 0x8000
9+
)
10+
11+
type MessageHeader struct {
12+
Timestamp uint32
13+
ProtocolId uint16
14+
PayloadLength uint16
15+
}
16+
17+
type Message struct {
18+
MessageHeader
19+
Payload []byte
20+
}
21+
22+
func NewMessage(protocolId uint16, payload []byte, isResponse bool) *Message {
23+
header := MessageHeader{
24+
Timestamp: uint32(time.Now().UnixNano() & 0xffffffff),
25+
ProtocolId: protocolId,
26+
}
27+
if isResponse {
28+
header.ProtocolId = header.ProtocolId + MESSAGE_PROTOCOL_ID_RESPONSE_FLAG
29+
}
30+
header.PayloadLength = uint16(len(payload))
31+
msg := &Message{
32+
MessageHeader: header,
33+
Payload: payload,
34+
}
35+
return msg
36+
}
37+
38+
func (s *MessageHeader) IsRequest() bool {
39+
return (s.ProtocolId & MESSAGE_PROTOCOL_ID_RESPONSE_FLAG) == 0
40+
}
41+
42+
func (s *MessageHeader) IsResponse() bool {
43+
return (s.ProtocolId & MESSAGE_PROTOCOL_ID_RESPONSE_FLAG) > 0
44+
}
45+
46+
func (s *MessageHeader) GetProtocolId() uint16 {
47+
if s.ProtocolId >= MESSAGE_PROTOCOL_ID_RESPONSE_FLAG {
48+
return s.ProtocolId - MESSAGE_PROTOCOL_ID_RESPONSE_FLAG
49+
}
50+
return s.ProtocolId
51+
}

muxer/muxer.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package muxer
2+
3+
import (
4+
"bytes"
5+
"encoding/binary"
6+
"fmt"
7+
"io"
8+
)
9+
10+
type Muxer struct {
11+
conn io.ReadWriteCloser
12+
ErrorChan chan error
13+
protocolSenders map[uint16]chan *Message
14+
protocolReceivers map[uint16]chan *Message
15+
}
16+
17+
func New(conn io.ReadWriteCloser) *Muxer {
18+
m := &Muxer{
19+
conn: conn,
20+
ErrorChan: make(chan error, 10),
21+
protocolSenders: make(map[uint16]chan *Message),
22+
protocolReceivers: make(map[uint16]chan *Message),
23+
}
24+
go m.readLoop()
25+
return m
26+
}
27+
28+
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Message, chan *Message) {
29+
// Generate channels
30+
senderChan := make(chan *Message, 10)
31+
receiverChan := make(chan *Message, 10)
32+
// Record channels in protocol sender/receiver maps
33+
m.protocolSenders[protocolId] = senderChan
34+
m.protocolReceivers[protocolId] = receiverChan
35+
// Start Goroutine to handle outbound messages
36+
go func() {
37+
for {
38+
msg := <-senderChan
39+
m.Send(msg)
40+
}
41+
}()
42+
return senderChan, receiverChan
43+
}
44+
45+
func (m *Muxer) Send(msg *Message) error {
46+
buf := &bytes.Buffer{}
47+
err := binary.Write(buf, binary.BigEndian, msg.MessageHeader)
48+
if err != nil {
49+
return err
50+
}
51+
buf.Write(msg.Payload)
52+
_, err = m.conn.Write(buf.Bytes())
53+
if err != nil {
54+
return err
55+
}
56+
return nil
57+
}
58+
59+
func (m *Muxer) readLoop() {
60+
for {
61+
header := MessageHeader{}
62+
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil {
63+
m.ErrorChan <- err
64+
}
65+
msg := &Message{
66+
MessageHeader: header,
67+
Payload: make([]byte, header.PayloadLength),
68+
}
69+
// We use ReadFull because it guarantees to read the expected number of bytes or
70+
// return an error
71+
if _, err := io.ReadFull(m.conn, msg.Payload); err != nil {
72+
m.ErrorChan <- err
73+
}
74+
// Send message payload to proper receiver
75+
recvChan := m.protocolReceivers[msg.GetProtocolId()]
76+
if recvChan == nil {
77+
m.ErrorChan <- fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId())
78+
} else {
79+
m.protocolReceivers[msg.GetProtocolId()] <- msg
80+
}
81+
}
82+
}

0 commit comments

Comments
 (0)