Skip to content

Commit 11df25c

Browse files
authored
Merge pull request #16 from cloudstruct/feature/protocol-refactor
Refactor mini-protocols
2 parents c63495e + 78dee7a commit 11df25c

File tree

12 files changed

+487
-421
lines changed

12 files changed

+487
-421
lines changed

cmd/go-ouroboros-network/chainsync.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ func chainSyncIntersectFoundHandler(point interface{}, tip interface{}) error {
173173
return nil
174174
}
175175

176-
func chainSyncIntersectNotFoundHandler() error {
176+
func chainSyncIntersectNotFoundHandler(tip interface{}) error {
177177
fmt.Printf("ERROR: failed to find intersection\n")
178178
os.Exit(1)
179179
return nil

muxer/message.go

Lines changed: 0 additions & 51 deletions
This file was deleted.

muxer/muxer.go

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,37 @@ import (
55
"encoding/binary"
66
"fmt"
77
"io"
8+
"sync"
9+
)
10+
11+
const (
12+
// Magic number chosen to represent unknown protocols
13+
PROTOCOL_UNKNOWN uint16 = 0xabcd
814
)
915

1016
type Muxer struct {
1117
conn io.ReadWriteCloser
18+
sendMutex sync.Mutex
1219
ErrorChan chan error
13-
protocolSenders map[uint16]chan *Message
14-
protocolReceivers map[uint16]chan *Message
20+
protocolSenders map[uint16]chan *Segment
21+
protocolReceivers map[uint16]chan *Segment
1522
}
1623

1724
func New(conn io.ReadWriteCloser) *Muxer {
1825
m := &Muxer{
1926
conn: conn,
2027
ErrorChan: make(chan error, 10),
21-
protocolSenders: make(map[uint16]chan *Message),
22-
protocolReceivers: make(map[uint16]chan *Message),
28+
protocolSenders: make(map[uint16]chan *Segment),
29+
protocolReceivers: make(map[uint16]chan *Segment),
2330
}
2431
go m.readLoop()
2532
return m
2633
}
2734

28-
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Message, chan *Message) {
35+
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment) {
2936
// Generate channels
30-
senderChan := make(chan *Message, 10)
31-
receiverChan := make(chan *Message, 10)
37+
senderChan := make(chan *Segment, 10)
38+
receiverChan := make(chan *Segment, 10)
3239
// Record channels in protocol sender/receiver maps
3340
m.protocolSenders[protocolId] = senderChan
3441
m.protocolReceivers[protocolId] = receiverChan
@@ -44,9 +51,12 @@ func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Message, chan *Messag
4451
return senderChan, receiverChan
4552
}
4653

47-
func (m *Muxer) Send(msg *Message) error {
54+
func (m *Muxer) Send(msg *Segment) error {
55+
// We use a mutex to make sure only one protocol can send at a time
56+
m.sendMutex.Lock()
57+
defer m.sendMutex.Unlock()
4858
buf := &bytes.Buffer{}
49-
err := binary.Write(buf, binary.BigEndian, msg.MessageHeader)
59+
err := binary.Write(buf, binary.BigEndian, msg.SegmentHeader)
5060
if err != nil {
5161
return err
5262
}
@@ -60,12 +70,12 @@ func (m *Muxer) Send(msg *Message) error {
6070

6171
func (m *Muxer) readLoop() {
6272
for {
63-
header := MessageHeader{}
73+
header := SegmentHeader{}
6474
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil {
6575
m.ErrorChan <- err
6676
}
67-
msg := &Message{
68-
MessageHeader: header,
77+
msg := &Segment{
78+
SegmentHeader: header,
6979
Payload: make([]byte, header.PayloadLength),
7080
}
7181
// We use ReadFull because it guarantees to read the expected number of bytes or
@@ -76,7 +86,11 @@ func (m *Muxer) readLoop() {
7686
// Send message payload to proper receiver
7787
recvChan := m.protocolReceivers[msg.GetProtocolId()]
7888
if recvChan == nil {
79-
m.ErrorChan <- fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId())
89+
// Try the "unknown protocol" receiver if we didn't find an explicit one
90+
recvChan = m.protocolReceivers[PROTOCOL_UNKNOWN]
91+
if recvChan == nil {
92+
m.ErrorChan <- fmt.Errorf("received message for unknown protocol ID %d", msg.GetProtocolId())
93+
}
8094
} else {
8195
m.protocolReceivers[msg.GetProtocolId()] <- msg
8296
}

muxer/segment.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package muxer
2+
3+
import (
4+
"time"
5+
)
6+
7+
const (
8+
SEGMENT_PROTOCOL_ID_RESPONSE_FLAG = 0x8000
9+
)
10+
11+
type SegmentHeader struct {
12+
Timestamp uint32
13+
ProtocolId uint16
14+
PayloadLength uint16
15+
}
16+
17+
type Segment struct {
18+
SegmentHeader
19+
Payload []byte
20+
}
21+
22+
func NewSegment(protocolId uint16, payload []byte, isResponse bool) *Segment {
23+
header := SegmentHeader{
24+
Timestamp: uint32(time.Now().UnixNano() & 0xffffffff),
25+
ProtocolId: protocolId,
26+
}
27+
if isResponse {
28+
header.ProtocolId = header.ProtocolId + SEGMENT_PROTOCOL_ID_RESPONSE_FLAG
29+
}
30+
header.PayloadLength = uint16(len(payload))
31+
segment := &Segment{
32+
SegmentHeader: header,
33+
Payload: payload,
34+
}
35+
return segment
36+
}
37+
38+
func (s *SegmentHeader) IsRequest() bool {
39+
return (s.ProtocolId & SEGMENT_PROTOCOL_ID_RESPONSE_FLAG) == 0
40+
}
41+
42+
func (s *SegmentHeader) IsResponse() bool {
43+
return (s.ProtocolId & SEGMENT_PROTOCOL_ID_RESPONSE_FLAG) > 0
44+
}
45+
46+
func (s *SegmentHeader) GetProtocolId() uint16 {
47+
if s.ProtocolId >= SEGMENT_PROTOCOL_ID_RESPONSE_FLAG {
48+
return s.ProtocolId - SEGMENT_PROTOCOL_ID_RESPONSE_FLAG
49+
}
50+
return s.ProtocolId
51+
}

0 commit comments

Comments
 (0)