11package protofsm
22
33import (
4+ "context"
45 "fmt"
56 "sync"
67 "time"
@@ -135,12 +136,11 @@ type StateMachine[Event any, Env Environment] struct {
135136 // query the internal state machine state.
136137 stateQuery chan stateQuery [Event , Env ]
137138
139+ wg fn.GoroutineManager
140+ quit chan struct {}
141+
138142 startOnce sync.Once
139143 stopOnce sync.Once
140-
141- // TODO(roasbeef): also use that context guard here?
142- quit chan struct {}
143- wg sync.WaitGroup
144144}
145145
146146// ErrorReporter is an interface that's used to report errors that occur during
@@ -194,17 +194,19 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
194194 cfg : cfg ,
195195 events : make (chan Event , 1 ),
196196 stateQuery : make (chan stateQuery [Event , Env ]),
197- quit : make ( chan struct {} ),
197+ wg : * fn . NewGoroutineManager ( context . Background () ),
198198 newStateEvents : fn .NewEventDistributor [State [Event , Env ]](),
199+ quit : make (chan struct {}),
199200 }
200201}
201202
202203// Start starts the state machine. This will spawn a goroutine that will drive
203204// the state machine to completion.
204205func (s * StateMachine [Event , Env ]) Start () {
205206 s .startOnce .Do (func () {
206- s .wg .Add (1 )
207- go s .driveMachine ()
207+ _ = s .wg .Go (func (ctx context.Context ) {
208+ s .driveMachine ()
209+ })
208210 })
209211}
210212
@@ -213,7 +215,7 @@ func (s *StateMachine[Event, Env]) Start() {
213215func (s * StateMachine [Event , Env ]) Stop () {
214216 s .stopOnce .Do (func () {
215217 close (s .quit )
216- s .wg .Wait ()
218+ s .wg .Stop ()
217219 })
218220}
219221
@@ -320,7 +322,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
320322// executeDaemonEvent executes a daemon event, which is a special type of event
321323// that can be emitted as part of the state transition function of the state
322324// machine. An error is returned if the type of event is unknown.
323- func (s * StateMachine [Event , Env ]) executeDaemonEvent ( //nolint:funlen
325+ func (s * StateMachine [Event , Env ]) executeDaemonEvent (
324326 event DaemonEvent ) error {
325327
326328 switch daemonEvent := event .(type ) {
@@ -342,25 +344,19 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
342344 err )
343345 }
344346
345- // If a post-send event was specified, then we'll
346- // funnel that back into the main state machine now as
347- // well.
348- daemonEvent .PostSendEvent .WhenSome (func (event Event ) {
349- s .wg .Add (1 )
350- go func () {
351- defer s .wg .Done ()
352-
347+ // If a post-send event was specified, then we'll funnel
348+ // that back into the main state machine now as well.
349+ return fn .MapOptionZ (daemonEvent .PostSendEvent , func (event Event ) error { //nolint:lll
350+ return s .wg .Go (func (ctx context.Context ) {
353351 log .Debugf ("FSM(%v): sending " +
354352 "post-send event: %v" ,
355353 s .cfg .Env .Name (),
356354 lnutils .SpewLogClosure (event ),
357355 )
358356
359357 s .SendEvent (event )
360- }( )
358+ })
361359 })
362-
363- return nil
364360 }
365361
366362 // If this doesn't have a SendWhen predicate, then we can just
@@ -372,10 +368,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
372368 // Otherwise, this has a SendWhen predicate, so we'll need
373369 // launch a goroutine to poll the SendWhen, then send only once
374370 // the predicate is true.
375- s .wg .Add (1 )
376- go func () {
377- defer s .wg .Done ()
378-
371+ return s .wg .Go (func (ctx context.Context ) {
379372 predicateTicker := time .NewTicker (
380373 s .cfg .CustomPollInterval .UnwrapOr (pollInterval ),
381374 )
@@ -408,13 +401,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
408401 return
409402 }
410403
411- case <- s . quit :
404+ case <- ctx . Done () :
412405 return
413406 }
414407 }
415- }()
416-
417- return nil
408+ })
418409
419410 // If this is a broadcast transaction event, then we'll broadcast with
420411 // the label attached.
@@ -445,9 +436,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
445436 return fmt .Errorf ("unable to register spend: %w" , err )
446437 }
447438
448- s .wg .Add (1 )
449- go func () {
450- defer s .wg .Done ()
439+ return s .wg .Go (func (ctx context.Context ) {
451440 for {
452441 select {
453442 case spend , ok := <- spendEvent .Spend :
@@ -466,13 +455,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
466455
467456 return
468457
469- case <- s . quit :
458+ case <- ctx . Done () :
470459 return
471460 }
472461 }
473- }()
474-
475- return nil
462+ })
476463
477464 // The state machine has requested a new event to be sent once a
478465 // specified txid+pkScript pair has confirmed.
@@ -489,9 +476,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
489476 return fmt .Errorf ("unable to register conf: %w" , err )
490477 }
491478
492- s .wg .Add (1 )
493- go func () {
494- defer s .wg .Done ()
479+ return s .wg .Go (func (ctx context.Context ) {
495480 for {
496481 select {
497482 case <- confEvent .Confirmed :
@@ -508,11 +493,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
508493
509494 return
510495
511- case <- s . quit :
496+ case <- ctx . Done () :
512497 return
513498 }
514499 }
515- }( )
500+ })
516501 }
517502
518503 return fmt .Errorf ("unknown daemon event: %T" , event )
@@ -632,8 +617,6 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
632617// incoming events, and then drives the state machine forward until it reaches
633618// a terminal state.
634619func (s * StateMachine [Event , Env ]) driveMachine () {
635- defer s .wg .Done ()
636-
637620 log .Debugf ("FSM(%v): starting state machine" , s .cfg .Env .Name ())
638621
639622 currentState := s .cfg .InitialState
@@ -676,16 +659,11 @@ func (s *StateMachine[Event, Env]) driveMachine() {
676659 // An outside caller is querying our state, so we'll return the
677660 // latest state.
678661 case stateQuery := <- s .stateQuery :
679- if ! fn .SendOrQuit (
680- stateQuery .CurrentState , currentState , s .quit ,
681- ) {
682-
662+ if ! fn .SendOrQuit (stateQuery .CurrentState , currentState , s .quit ) { //nolint:lll
683663 return
684664 }
685665
686- case <- s .quit :
687- // TODO(roasbeef): logs, etc
688- // * something in env?
666+ case <- s .wg .Done ():
689667 return
690668 }
691669 }
0 commit comments