@@ -25,6 +25,7 @@ import (
2525
2626 "github.com/blinklabs-io/gouroboros/cbor"
2727 "github.com/blinklabs-io/gouroboros/muxer"
28+ "github.com/blinklabs-io/gouroboros/utils"
2829)
2930
3031// This is completely arbitrary, but the line had to be drawn somewhere
@@ -40,7 +41,7 @@ type Protocol struct {
4041 recvReadyChan chan bool
4142 sendReadyChan chan bool
4243 stateTransitionChan chan <- protocolStateTransition
43- doneChan chan bool
44+ doneSignal * utils. DoneSignal
4445 waitGroup sync.WaitGroup
4546 onceStart sync.Once
4647}
@@ -102,8 +103,8 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
102103// New returns a new Protocol object
103104func New (config ProtocolConfig ) * Protocol {
104105 p := & Protocol {
105- config : config ,
106- doneChan : make ( chan bool ),
106+ config : config ,
107+ doneSignal : utils . NewDoneSignal ( ),
107108 }
108109 return p
109110}
@@ -149,8 +150,8 @@ func (p *Protocol) Role() ProtocolRole {
149150}
150151
151152// DoneChan returns the channel used to signal protocol shutdown
152- func (p * Protocol ) DoneChan () chan bool {
153- return p .doneChan
153+ func (p * Protocol ) DoneChan () <- chan struct {} {
154+ return p .doneSignal . GetCh ()
154155}
155156
156157// SendMessage appends a message to the send queue
@@ -178,11 +179,12 @@ func (p *Protocol) sendLoop() {
178179 // We are responsible for closing this channel as the sender, even through it
179180 // was created by the muxer
180181 close (p .muxerSendChan )
182+ p .doneSignal .Close ()
181183 }()
182184
183185 for {
184186 select {
185- case <- p .doneChan :
187+ case <- p .doneSignal . GetCh () :
186188 // Break out of send loop if we're shutting down
187189 return
188190 case <- p .sendReadyChan :
@@ -196,7 +198,7 @@ func (p *Protocol) sendLoop() {
196198 for {
197199 // Get next message from send queue
198200 select {
199- case <- p .doneChan :
201+ case <- p .doneSignal . GetCh () :
200202 // Break out of send loop if we're shutting down
201203 return
202204 case msg , ok := <- p .sendQueueChan :
@@ -260,10 +262,7 @@ func (p *Protocol) sendLoop() {
260262 }
261263 // Send current segment
262264 segmentPayload := payloadBuf .Bytes ()[:segmentPayloadLength ]
263- isResponse := false
264- if p .Role () == ProtocolRoleServer {
265- isResponse = true
266- }
265+ isResponse := p .Role () == ProtocolRoleServer
267266 segment := muxer .NewSegment (
268267 p .config .ProtocolId ,
269268 segmentPayload ,
@@ -283,7 +282,11 @@ func (p *Protocol) sendLoop() {
283282}
284283
285284func (p * Protocol ) recvLoop () {
286- defer p .waitGroup .Done ()
285+ defer func () {
286+ p .waitGroup .Done ()
287+ p .doneSignal .Close ()
288+ }()
289+
287290 leftoverData := false
288291 recvBuffer := bytes .NewBuffer (nil )
289292
@@ -293,15 +296,13 @@ func (p *Protocol) recvLoop() {
293296 if ! leftoverData {
294297 // Wait for segment
295298 select {
296- case <- p .doneChan :
299+ case <- p .doneSignal . GetCh () :
297300 // Break out of receive loop if we're shutting down
298301 return
299302 case <- p .muxerDoneChan :
300- close (p .doneChan )
301303 return
302304 case segment , ok := <- p .muxerRecvChan :
303305 if ! ok {
304- close (p .doneChan )
305306 return
306307 }
307308 // Add segment payload to buffer
@@ -311,11 +312,10 @@ func (p *Protocol) recvLoop() {
311312 leftoverData = false
312313 // Wait until ready to receive based on state map
313314 select {
314- case <- p .doneChan :
315+ case <- p .doneSignal . GetCh () :
315316 // Break out of receive loop if we're shutting down
316317 return
317318 case <- p .muxerDoneChan :
318- close (p .doneChan )
319319 return
320320 case <- p .recvReadyChan :
321321 }
@@ -429,7 +429,7 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
429429 return transitionTimer .C
430430 }
431431
432- protocolDoneChan := p .doneChan
432+ protocolDoneChan := p .doneSignal . GetCh ()
433433 stateDoneChan := make (chan struct {})
434434
435435 setState (p .config .InitialState )
0 commit comments