Skip to content

Commit 2e3c0b2

Browse files
committed
protofsm: use new fn.GoroutineManager to manage goroutines
This fixes an isuse that can occur when we have concurrent calls to `Stop` while the state machine is driving forward.
1 parent 6de0615 commit 2e3c0b2

File tree

1 file changed

+27
-49
lines changed

1 file changed

+27
-49
lines changed

protofsm/state_machine.go

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package protofsm
22

33
import (
4+
"context"
45
"fmt"
56
"sync"
67
"time"
@@ -135,12 +136,11 @@ type StateMachine[Event any, Env Environment] struct {
135136
// query the internal state machine state.
136137
stateQuery chan stateQuery[Event, Env]
137138

139+
wg fn.GoroutineManager
140+
quit chan struct{}
141+
138142
startOnce sync.Once
139143
stopOnce sync.Once
140-
141-
// TODO(roasbeef): also use that context guard here?
142-
quit chan struct{}
143-
wg sync.WaitGroup
144144
}
145145

146146
// ErrorReporter is an interface that's used to report errors that occur during
@@ -194,17 +194,19 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
194194
cfg: cfg,
195195
events: make(chan Event, 1),
196196
stateQuery: make(chan stateQuery[Event, Env]),
197-
quit: make(chan struct{}),
197+
wg: *fn.NewGoroutineManager(context.Background()),
198198
newStateEvents: fn.NewEventDistributor[State[Event, Env]](),
199+
quit: make(chan struct{}),
199200
}
200201
}
201202

202203
// Start starts the state machine. This will spawn a goroutine that will drive
203204
// the state machine to completion.
204205
func (s *StateMachine[Event, Env]) Start() {
205206
s.startOnce.Do(func() {
206-
s.wg.Add(1)
207-
go s.driveMachine()
207+
_ = s.wg.Go(func(ctx context.Context) {
208+
s.driveMachine()
209+
})
208210
})
209211
}
210212

@@ -213,7 +215,7 @@ func (s *StateMachine[Event, Env]) Start() {
213215
func (s *StateMachine[Event, Env]) Stop() {
214216
s.stopOnce.Do(func() {
215217
close(s.quit)
216-
s.wg.Wait()
218+
s.wg.Stop()
217219
})
218220
}
219221

@@ -320,7 +322,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
320322
// executeDaemonEvent executes a daemon event, which is a special type of event
321323
// that can be emitted as part of the state transition function of the state
322324
// machine. An error is returned if the type of event is unknown.
323-
func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
325+
func (s *StateMachine[Event, Env]) executeDaemonEvent(
324326
event DaemonEvent) error {
325327

326328
switch daemonEvent := event.(type) {
@@ -342,25 +344,19 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
342344
err)
343345
}
344346

345-
// If a post-send event was specified, then we'll
346-
// funnel that back into the main state machine now as
347-
// well.
348-
daemonEvent.PostSendEvent.WhenSome(func(event Event) {
349-
s.wg.Add(1)
350-
go func() {
351-
defer s.wg.Done()
352-
347+
// If a post-send event was specified, then we'll funnel
348+
// that back into the main state machine now as well.
349+
return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:lll
350+
return s.wg.Go(func(ctx context.Context) {
353351
log.Debugf("FSM(%v): sending "+
354352
"post-send event: %v",
355353
s.cfg.Env.Name(),
356354
lnutils.SpewLogClosure(event),
357355
)
358356

359357
s.SendEvent(event)
360-
}()
358+
})
361359
})
362-
363-
return nil
364360
}
365361

366362
// If this doesn't have a SendWhen predicate, then we can just
@@ -372,10 +368,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
372368
// Otherwise, this has a SendWhen predicate, so we'll need
373369
// launch a goroutine to poll the SendWhen, then send only once
374370
// the predicate is true.
375-
s.wg.Add(1)
376-
go func() {
377-
defer s.wg.Done()
378-
371+
return s.wg.Go(func(ctx context.Context) {
379372
predicateTicker := time.NewTicker(
380373
s.cfg.CustomPollInterval.UnwrapOr(pollInterval),
381374
)
@@ -408,13 +401,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
408401
return
409402
}
410403

411-
case <-s.quit:
404+
case <-ctx.Done():
412405
return
413406
}
414407
}
415-
}()
416-
417-
return nil
408+
})
418409

419410
// If this is a broadcast transaction event, then we'll broadcast with
420411
// the label attached.
@@ -445,9 +436,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
445436
return fmt.Errorf("unable to register spend: %w", err)
446437
}
447438

448-
s.wg.Add(1)
449-
go func() {
450-
defer s.wg.Done()
439+
return s.wg.Go(func(ctx context.Context) {
451440
for {
452441
select {
453442
case spend, ok := <-spendEvent.Spend:
@@ -466,13 +455,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
466455

467456
return
468457

469-
case <-s.quit:
458+
case <-ctx.Done():
470459
return
471460
}
472461
}
473-
}()
474-
475-
return nil
462+
})
476463

477464
// The state machine has requested a new event to be sent once a
478465
// specified txid+pkScript pair has confirmed.
@@ -489,9 +476,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
489476
return fmt.Errorf("unable to register conf: %w", err)
490477
}
491478

492-
s.wg.Add(1)
493-
go func() {
494-
defer s.wg.Done()
479+
return s.wg.Go(func(ctx context.Context) {
495480
for {
496481
select {
497482
case <-confEvent.Confirmed:
@@ -508,11 +493,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
508493

509494
return
510495

511-
case <-s.quit:
496+
case <-ctx.Done():
512497
return
513498
}
514499
}
515-
}()
500+
})
516501
}
517502

518503
return fmt.Errorf("unknown daemon event: %T", event)
@@ -632,8 +617,6 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
632617
// incoming events, and then drives the state machine forward until it reaches
633618
// a terminal state.
634619
func (s *StateMachine[Event, Env]) driveMachine() {
635-
defer s.wg.Done()
636-
637620
log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name())
638621

639622
currentState := s.cfg.InitialState
@@ -676,16 +659,11 @@ func (s *StateMachine[Event, Env]) driveMachine() {
676659
// An outside caller is querying our state, so we'll return the
677660
// latest state.
678661
case stateQuery := <-s.stateQuery:
679-
if !fn.SendOrQuit(
680-
stateQuery.CurrentState, currentState, s.quit,
681-
) {
682-
662+
if !fn.SendOrQuit(stateQuery.CurrentState, currentState, s.quit) { //nolint:lll
683663
return
684664
}
685665

686-
case <-s.quit:
687-
// TODO(roasbeef): logs, etc
688-
// * something in env?
666+
case <-s.wg.Done():
689667
return
690668
}
691669
}

0 commit comments

Comments
 (0)