@@ -25,7 +25,6 @@ import (
2525 "github.com/blinklabs-io/gouroboros/cbor"
2626 "github.com/blinklabs-io/gouroboros/connection"
2727 "github.com/blinklabs-io/gouroboros/muxer"
28- "github.com/blinklabs-io/gouroboros/utils"
2928)
3029
3130// This is completely arbitrary, but the line had to be drawn somewhere
@@ -34,15 +33,16 @@ const maxMessagesPerSegment = 20
3433// Protocol implements the base functionality of an Ouroboros mini-protocol
3534type Protocol struct {
3635 config ProtocolConfig
36+ doneChan chan struct {}
3737 muxerSendChan chan * muxer.Segment
3838 muxerRecvChan chan * muxer.Segment
3939 muxerDoneChan chan bool
4040 sendQueueChan chan Message
41+ recvDoneChan chan struct {}
4142 recvReadyChan chan bool
43+ sendDoneChan chan struct {}
4244 sendReadyChan chan bool
4345 stateTransitionChan chan <- protocolStateTransition
44- doneSignal * utils.DoneSignal
45- waitGroup sync.WaitGroup
4646 onceStart sync.Once
4747}
4848
@@ -105,8 +105,10 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)
105105// New returns a new Protocol object
106106func New (config ProtocolConfig ) * Protocol {
107107 p := & Protocol {
108- config : config ,
109- doneSignal : utils .NewDoneSignal (),
108+ config : config ,
109+ doneChan : make (chan struct {}),
110+ recvDoneChan : make (chan struct {}),
111+ sendDoneChan : make (chan struct {}),
110112 }
111113 return p
112114}
@@ -133,7 +135,11 @@ func (p *Protocol) Start() {
133135 p .stateTransitionChan = stateTransitionChan
134136
135137 // Start our send and receive Goroutines
136- p .waitGroup .Add (2 )
138+ go func () {
139+ <- p .recvDoneChan
140+ <- p .sendDoneChan
141+ close (p .doneChan )
142+ }()
137143
138144 go p .stateLoop (stateTransitionChan )
139145 go p .recvLoop ()
@@ -153,7 +159,7 @@ func (p *Protocol) Role() ProtocolRole {
153159
154160// DoneChan returns the channel used to signal protocol shutdown
155161func (p * Protocol ) DoneChan () <- chan struct {} {
156- return p .doneSignal . GetCh ()
162+ return p .doneChan
157163}
158164
159165// SendMessage appends a message to the send queue
@@ -176,17 +182,16 @@ func (p *Protocol) SendError(err error) {
176182
177183func (p * Protocol ) sendLoop () {
178184 defer func () {
179- p .waitGroup .Done ()
180185 // Close muxer send channel
181186 // We are responsible for closing this channel as the sender, even through it
182187 // was created by the muxer
183188 close (p .muxerSendChan )
184- p . doneSignal . Close ( )
189+ close ( p . sendDoneChan )
185190 }()
186191
187192 for {
188193 select {
189- case <- p .doneSignal . GetCh () :
194+ case <- p .recvDoneChan :
190195 // Break out of send loop if we're shutting down
191196 return
192197 case <- p .sendReadyChan :
@@ -200,7 +205,7 @@ func (p *Protocol) sendLoop() {
200205 for {
201206 // Get next message from send queue
202207 select {
203- case <- p .doneSignal . GetCh () :
208+ case <- p .recvDoneChan :
204209 // Break out of send loop if we're shutting down
205210 return
206211 case msg , ok := <- p .sendQueueChan :
@@ -285,8 +290,7 @@ func (p *Protocol) sendLoop() {
285290
286291func (p * Protocol ) recvLoop () {
287292 defer func () {
288- p .waitGroup .Done ()
289- p .doneSignal .Close ()
293+ close (p .recvDoneChan )
290294 }()
291295
292296 leftoverData := false
@@ -298,7 +302,7 @@ func (p *Protocol) recvLoop() {
298302 if ! leftoverData {
299303 // Wait for segment
300304 select {
301- case <- p .doneSignal . GetCh () :
305+ case <- p .sendDoneChan :
302306 // Break out of receive loop if we're shutting down
303307 return
304308 case <- p .muxerDoneChan :
@@ -314,7 +318,7 @@ func (p *Protocol) recvLoop() {
314318 leftoverData = false
315319 // Wait until ready to receive based on state map
316320 select {
317- case <- p .doneSignal . GetCh () :
321+ case <- p .sendDoneChan :
318322 // Break out of receive loop if we're shutting down
319323 return
320324 case <- p .muxerDoneChan :
@@ -431,9 +435,6 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
431435 return transitionTimer .C
432436 }
433437
434- protocolDoneChan := p .doneSignal .GetCh ()
435- stateDoneChan := make (chan struct {})
436-
437438 setState (p .config .InitialState )
438439
439440 for {
@@ -467,24 +468,11 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) {
467468 ),
468469 )
469470
470- case <- protocolDoneChan :
471- // Disable this case so it doesn't block
472- protocolDoneChan = nil
473-
474- // Wait for all other goroutines to finish before shutting down the state handler
475- go func () {
476- p .waitGroup .Wait ()
477-
478- close (stateDoneChan )
479- }()
480-
481- case <- stateDoneChan :
482- // All other goroutines have finished, so we can stop the timer and return
471+ case <- p .doneChan :
472+ // Disable any previous state transition timer, as they are no longer needed
483473 if transitionTimer != nil && ! transitionTimer .Stop () {
484474 <- transitionTimer .C
485475 }
486- transitionTimer = nil
487-
488476 return
489477 }
490478 }
0 commit comments