Skip to content

Commit 75a95e7

Browse files
authored
Merge pull request #36 from cloudstruct/feature/muxer-handshake-race-condition
Fix race condition around handshake and muxer protocol registration
2 parents d838151 + d193a75 commit 75a95e7

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

muxer/muxer.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,31 @@ import (
55
"encoding/binary"
66
"fmt"
77
"io"
8+
"net"
89
"sync"
910
)
1011

1112
const (
1213
// Magic number chosen to represent unknown protocols
1314
PROTOCOL_UNKNOWN uint16 = 0xabcd
15+
16+
// Handshake protocol ID
17+
PROTOCOL_HANDSHAKE = 0
1418
)
1519

1620
type Muxer struct {
17-
conn io.ReadWriteCloser
21+
conn net.Conn
1822
sendMutex sync.Mutex
23+
startChan chan bool
1924
ErrorChan chan error
2025
protocolSenders map[uint16]chan *Segment
2126
protocolReceivers map[uint16]chan *Segment
2227
}
2328

24-
func New(conn io.ReadWriteCloser) *Muxer {
29+
func New(conn net.Conn) *Muxer {
2530
m := &Muxer{
2631
conn: conn,
32+
startChan: make(chan bool, 1),
2733
ErrorChan: make(chan error, 10),
2834
protocolSenders: make(map[uint16]chan *Segment),
2935
protocolReceivers: make(map[uint16]chan *Segment),
@@ -32,6 +38,10 @@ func New(conn io.ReadWriteCloser) *Muxer {
3238
return m
3339
}
3440

41+
func (m *Muxer) Start() {
42+
m.startChan <- true
43+
}
44+
3545
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment) {
3646
// Generate channels
3747
senderChan := make(chan *Segment, 10)
@@ -69,6 +79,7 @@ func (m *Muxer) Send(msg *Segment) error {
6979
}
7080

7181
func (m *Muxer) readLoop() {
82+
started := false
7283
for {
7384
header := SegmentHeader{}
7485
if err := binary.Read(m.conn, binary.BigEndian, &header); err != nil {
@@ -83,6 +94,11 @@ func (m *Muxer) readLoop() {
8394
if _, err := io.ReadFull(m.conn, msg.Payload); err != nil {
8495
m.ErrorChan <- err
8596
}
97+
// Wait until the muxer is started to process anything other than handshake messages
98+
if !started && msg.GetProtocolId() != PROTOCOL_HANDSHAKE {
99+
<-m.startChan
100+
started = true
101+
}
86102
// Send message payload to proper receiver
87103
recvChan := m.protocolReceivers[msg.GetProtocolId()]
88104
if recvChan == nil {

ouroboros.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ type Ouroboros struct {
2020
muxer *muxer.Muxer
2121
ErrorChan chan error
2222
sendKeepAlives bool
23+
delayMuxerStart bool
2324
// Mini-protocols
2425
Handshake *handshake.Handshake
2526
ChainSync *chainsync.ChainSync
@@ -39,6 +40,7 @@ type OuroborosOptions struct {
3940
Server bool
4041
UseNodeToNodeProtocol bool
4142
SendKeepAlives bool
43+
DelayMuxerStart bool
4244
ChainSyncCallbackConfig *chainsync.ChainSyncCallbackConfig
4345
BlockFetchCallbackConfig *blockfetch.BlockFetchCallbackConfig
4446
KeepAliveCallbackConfig *keepalive.KeepAliveCallbackConfig
@@ -57,6 +59,7 @@ func New(options *OuroborosOptions) (*Ouroboros, error) {
5759
localTxSubmissionCallbackConfig: options.LocalTxSubmissionCallbackConfig,
5860
ErrorChan: options.ErrorChan,
5961
sendKeepAlives: options.SendKeepAlives,
62+
delayMuxerStart: options.DelayMuxerStart,
6063
}
6164
if o.ErrorChan == nil {
6265
o.ErrorChan = make(chan error, 10)
@@ -69,6 +72,10 @@ func New(options *OuroborosOptions) (*Ouroboros, error) {
6972
return o, nil
7073
}
7174

75+
func (o *Ouroboros) Muxer() *muxer.Muxer {
76+
return o.muxer
77+
}
78+
7279
// Convenience function for creating a connection if you didn't provide one when
7380
// calling New()
7481
func (o *Ouroboros) Dial(proto string, address string) error {
@@ -134,5 +141,8 @@ func (o *Ouroboros) setupConnection() error {
134141
o.ChainSync = chainsync.New(protoOptions, o.chainSyncCallbackConfig)
135142
o.LocalTxSubmission = localtxsubmission.New(protoOptions, o.localTxSubmissionCallbackConfig)
136143
}
144+
if !o.delayMuxerStart {
145+
o.muxer.Start()
146+
}
137147
return nil
138148
}

0 commit comments

Comments
 (0)