@@ -42,6 +42,16 @@ const (
4242 DiffusionModeInitiatorAndResponder DiffusionMode = 3 // Initiator and responder (full duplex) mode
4343)
4444
45+ // ProtocolRole is an enum of the protocol roles
46+ type ProtocolRole uint
47+
48+ // Protocol roles
49+ const (
50+ ProtocolRoleNone ProtocolRole = 0 // Default (invalid) protocol role
51+ ProtocolRoleInitiator ProtocolRole = 1 // Initiator (client) protocol role
52+ ProtocolRoleResponder ProtocolRole = 2 // Responder (server) protocol role
53+ )
54+
4555// Muxer wraps a connection to allow running multiple mini-protocols over a single connection
4656type Muxer struct {
4757 errorChan chan error
@@ -50,8 +60,8 @@ type Muxer struct {
5060 startChan chan bool
5161 doneChan chan bool
5262 waitGroup sync.WaitGroup
53- protocolSenders map [uint16 ]chan * Segment
54- protocolReceivers map [uint16 ]chan * Segment
63+ protocolSenders map [uint16 ]map [ ProtocolRole ] chan * Segment
64+ protocolReceivers map [uint16 ]map [ ProtocolRole ] chan * Segment
5565 diffusionMode DiffusionMode
5666 onceStart sync.Once
5767 onceStop sync.Once
@@ -64,8 +74,8 @@ func New(conn net.Conn) *Muxer {
6474 startChan : make (chan bool , 1 ),
6575 doneChan : make (chan bool ),
6676 errorChan : make (chan error , 10 ),
67- protocolSenders : make (map [uint16 ]chan * Segment ),
68- protocolReceivers : make (map [uint16 ]chan * Segment ),
77+ protocolSenders : make (map [uint16 ]map [ ProtocolRole ] chan * Segment ),
78+ protocolReceivers : make (map [uint16 ]map [ ProtocolRole ] chan * Segment ),
6979 }
7080 m .waitGroup .Add (1 )
7181 go m .readLoop ()
@@ -95,8 +105,10 @@ func (m *Muxer) Stop() {
95105 m .waitGroup .Wait ()
96106 // Close protocol receive channels
97107 // We rely on the individual mini-protocols to close the sender channel
98- for _ , recvChan := range m .protocolReceivers {
99- close (recvChan )
108+ for _ , protocolRoles := range m .protocolReceivers {
109+ for _ , recvChan := range protocolRoles {
110+ close (recvChan )
111+ }
100112 }
101113 // Close ErrorChan to signify to consumer that we're shutting down
102114 close (m .errorChan )
@@ -124,13 +136,17 @@ func (m *Muxer) sendError(err error) {
124136
125137// RegisterProtocol registers the provided protocol ID with the muxer. It returns a channel for sending,
126138// a channel for receiving, and a channel to know when the muxer is shutting down
127- func (m * Muxer ) RegisterProtocol (protocolId uint16 ) (chan * Segment , chan * Segment , chan bool ) {
139+ func (m * Muxer ) RegisterProtocol (protocolId uint16 , protocolRole ProtocolRole ) (chan * Segment , chan * Segment , chan bool ) {
128140 // Generate channels
129141 senderChan := make (chan * Segment , 10 )
130142 receiverChan := make (chan * Segment , 10 )
131143 // Record channels in protocol sender/receiver maps
132- m .protocolSenders [protocolId ] = senderChan
133- m .protocolReceivers [protocolId ] = receiverChan
144+ if _ , ok := m .protocolSenders [protocolId ]; ! ok {
145+ m .protocolSenders [protocolId ] = make (map [ProtocolRole ]chan * Segment )
146+ m .protocolReceivers [protocolId ] = make (map [ProtocolRole ]chan * Segment )
147+ }
148+ m.protocolSenders [protocolId ][protocolRole ] = senderChan
149+ m.protocolReceivers [protocolId ][protocolRole ] = receiverChan
134150 // Start Goroutine to handle outbound messages
135151 m .waitGroup .Add (1 )
136152 go func () {
@@ -216,15 +232,24 @@ func (m *Muxer) readLoop() {
216232 return
217233 }
218234 // Send message payload to proper receiver
219- recvChan := m .protocolReceivers [msg .GetProtocolId ()]
220- if recvChan == nil {
235+ protocolRole := ProtocolRoleResponder
236+ if msg .IsResponse () {
237+ protocolRole = ProtocolRoleInitiator
238+ }
239+ protocolRoles , ok := m .protocolReceivers [msg .GetProtocolId ()]
240+ if ! ok {
221241 // Try the "unknown protocol" receiver if we didn't find an explicit one
222- recvChan = m .protocolReceivers [ProtocolUnknown ]
223- if recvChan == nil {
242+ protocolRoles , ok = m .protocolReceivers [ProtocolUnknown ]
243+ if ! ok {
224244 m .sendError (fmt .Errorf ("received message for unknown protocol ID %d" , msg .GetProtocolId ()))
225245 return
226246 }
227247 }
248+ recvChan := protocolRoles [protocolRole ]
249+ if recvChan == nil {
250+ m .sendError (fmt .Errorf ("received message for unknown protocol ID %d" , msg .GetProtocolId ()))
251+ return
252+ }
228253 if recvChan != nil {
229254 recvChan <- msg
230255 }
0 commit comments