Skip to content

Commit c5a0f54

Browse files
authored
Merge pull request #489 from blinklabs-io/fix/muxer-data-race
fix: adding locking around muxer protocol receivers map to prevent data race
2 parents 5c8e061 + 0cc24f2 commit c5a0f54

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

muxer/muxer.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,18 @@ const (
5454

5555
// Muxer wraps a connection to allow running multiple mini-protocols over a single connection
5656
type Muxer struct {
57-
errorChan chan error
58-
conn net.Conn
59-
sendMutex sync.Mutex
60-
startChan chan bool
61-
doneChan chan bool
62-
waitGroup sync.WaitGroup
63-
protocolSenders map[uint16]map[ProtocolRole]chan *Segment
64-
protocolReceivers map[uint16]map[ProtocolRole]chan *Segment
65-
diffusionMode DiffusionMode
66-
onceStart sync.Once
67-
onceStop sync.Once
57+
errorChan chan error
58+
conn net.Conn
59+
sendMutex sync.Mutex
60+
startChan chan bool
61+
doneChan chan bool
62+
waitGroup sync.WaitGroup
63+
protocolSenders map[uint16]map[ProtocolRole]chan *Segment
64+
protocolReceivers map[uint16]map[ProtocolRole]chan *Segment
65+
protocolReceiversMutex sync.Mutex
66+
diffusionMode DiffusionMode
67+
onceStart sync.Once
68+
onceStop sync.Once
6869
}
6970

7071
// New creates a new Muxer object and starts the read loop
@@ -137,12 +138,14 @@ func (m *Muxer) RegisterProtocol(
137138
senderChan := make(chan *Segment, 10)
138139
receiverChan := make(chan *Segment, 10)
139140
// Record channels in protocol sender/receiver maps
141+
m.protocolReceiversMutex.Lock()
140142
if _, ok := m.protocolSenders[protocolId]; !ok {
141143
m.protocolSenders[protocolId] = make(map[ProtocolRole]chan *Segment)
142144
m.protocolReceivers[protocolId] = make(map[ProtocolRole]chan *Segment)
143145
}
144146
m.protocolSenders[protocolId][protocolRole] = senderChan
145147
m.protocolReceivers[protocolId][protocolRole] = receiverChan
148+
m.protocolReceiversMutex.Unlock()
146149
// Start Goroutine to handle outbound messages
147150
m.waitGroup.Add(1)
148151
go func() {
@@ -199,11 +202,13 @@ func (m *Muxer) readLoop() {
199202
defer func() {
200203
m.waitGroup.Done()
201204
// Close receiver channels
205+
m.protocolReceiversMutex.Lock()
202206
for _, protocolRoles := range m.protocolReceivers {
203207
for _, recvChan := range protocolRoles {
204208
close(recvChan)
205209
}
206210
}
211+
m.protocolReceiversMutex.Unlock()
207212
}()
208213
started := false
209214
for {
@@ -251,11 +256,13 @@ func (m *Muxer) readLoop() {
251256
if msg.IsResponse() {
252257
protocolRole = ProtocolRoleInitiator
253258
}
259+
m.protocolReceiversMutex.Lock()
254260
protocolRoles, ok := m.protocolReceivers[msg.GetProtocolId()]
255261
if !ok {
256262
// Try the "unknown protocol" receiver if we didn't find an explicit one
257263
protocolRoles, ok = m.protocolReceivers[ProtocolUnknown]
258264
if !ok {
265+
m.protocolReceiversMutex.Unlock()
259266
m.sendError(
260267
fmt.Errorf(
261268
"received message for unknown protocol ID %d",
@@ -266,6 +273,7 @@ func (m *Muxer) readLoop() {
266273
}
267274
}
268275
recvChan := protocolRoles[protocolRole]
276+
m.protocolReceiversMutex.Unlock()
269277
if recvChan == nil {
270278
m.sendError(
271279
fmt.Errorf(

0 commit comments

Comments
 (0)