Skip to content

Commit 8e3a8f3

Browse files
committed
protofsm: add ability for state machine to consume wire msgs
In this commit, we add the ability for the state machine to consume wire messages. This'll allow the creation of a new generic message router that takes the place of the current peer `readHandler` in an upcoming commit.
1 parent 2716388 commit 8e3a8f3

File tree

4 files changed

+170
-17
lines changed

4 files changed

+170
-17
lines changed

protofsm/daemon_events.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"github.com/lightningnetwork/lnd/lnwire"
99
)
1010

11-
// DaemonEvent is a special event that can be emmitted by a state transition
11+
// DaemonEvent is a special event that can be emitted by a state transition
1212
// function. A state machine can use this to perform side effects, such as
1313
// sending a message to a peer, or broadcasting a transaction.
1414
type DaemonEvent interface {

protofsm/msg_mapper.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package protofsm
2+
3+
import (
4+
"github.com/lightningnetwork/lnd/fn"
5+
"github.com/lightningnetwork/lnd/lnwire"
6+
)
7+
8+
// MsgMapper is used to map incoming wire messages into a FSM event. This is
9+
// useful to decouple the translation of an outside or wire message into an
10+
// event type that can be understood by the FSM.
11+
type MsgMapper[Event any] interface {
12+
// MapMsg maps a wire message into a FSM event. If the message is not
13+
// mappable, then an error is returned.
14+
MapMsg(msg lnwire.Message) fn.Option[Event]
15+
}

protofsm/state_machine.go

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ type State[Event any, Env Environment] interface {
6767
// emitted.
6868
ProcessEvent(event Event, env Env) (*StateTransition[Event, Env], error)
6969

70-
// IsTerminal returns true if this state is terminal, and false otherwise.
70+
// IsTerminal returns true if this state is terminal, and false
71+
// otherwise.
7172
IsTerminal() bool
7273

7374
// TODO(roasbeef): also add state serialization?
@@ -165,13 +166,17 @@ type StateMachineCfg[Event any, Env Environment] struct {
165166
// can be used to set up tracking state such as a txid confirmation
166167
// event.
167168
InitEvent fn.Option[DaemonEvent]
169+
170+
// MsgMapper is an optional message mapper that can be used to map
171+
// normal wire messages into FSM events.
172+
MsgMapper fn.Option[MsgMapper[Event]]
168173
}
169174

170175
// NewStateMachine creates a new state machine given a set of daemon adapters,
171176
// an initial state, an environment, and an event to process as if emitted at
172177
// the onset of the state machine. Such an event can be used to set up tracking
173178
// state such as a txid confirmation event.
174-
func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env],
179+
func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env], //nolint:lll
175180
) StateMachine[Event, Env] {
176181

177182
return StateMachine[Event, Env]{
@@ -212,6 +217,43 @@ func (s *StateMachine[Event, Env]) SendEvent(event Event) {
212217
}
213218
}
214219

220+
// CanHandle returns true if the target message can be routed to the state
221+
// machine.
222+
func (s *StateMachine[Event, Env]) CanHandle(msg lnwire.Message) bool {
223+
cfgMapper := s.cfg.MsgMapper
224+
return fn.MapOptionZ(cfgMapper, func(mapper MsgMapper[Event]) bool {
225+
return mapper.MapMsg(msg).IsSome()
226+
})
227+
}
228+
229+
// SendMessage attempts to send a wire message to the state machine. If the
230+
// message can be mapped using the default message mapper, then true is
231+
// returned indicating that the message was processed. Otherwise, false is
232+
// returned.
233+
func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool {
234+
// If we have no message mapper, then return false as we can't process
235+
// this message.
236+
if !s.cfg.MsgMapper.IsSome() {
237+
return false
238+
}
239+
240+
// Otherwise, try to map the message using the default message mapper.
241+
// If we can't extract an event, then we'll return false to indicate
242+
// that the message wasn't processed.
243+
var processed bool
244+
s.cfg.MsgMapper.WhenSome(func(mapper MsgMapper[Event]) {
245+
event := mapper.MapMsg(msg)
246+
247+
event.WhenSome(func(event Event) {
248+
s.SendEvent(event)
249+
250+
processed = true
251+
})
252+
})
253+
254+
return processed
255+
}
256+
215257
// CurrentState returns the current state of the state machine.
216258
func (s *StateMachine[Event, Env]) CurrentState() (State[Event, Env], error) {
217259
query := stateQuery[Event, Env]{
@@ -231,7 +273,9 @@ type StateSubscriber[E any, F Environment] *fn.EventReceiver[State[E, F]]
231273

232274
// RegisterStateEvents registers a new event listener that will be notified of
233275
// new state transitions.
234-
func (s *StateMachine[Event, Env]) RegisterStateEvents() StateSubscriber[Event, Env] {
276+
func (s *StateMachine[Event, Env]) RegisterStateEvents() StateSubscriber[
277+
Event, Env] {
278+
235279
subscriber := fn.NewEventReceiver[State[Event, Env]](10)
236280

237281
// TODO(roasbeef): instead give the state and the input event?
@@ -243,16 +287,17 @@ func (s *StateMachine[Event, Env]) RegisterStateEvents() StateSubscriber[Event,
243287

244288
// RemoveStateSub removes the target state subscriber from the set of active
245289
// subscribers.
246-
func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[Event, Env]) {
247-
s.newStateEvents.RemoveSubscriber(sub)
290+
func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
291+
Event, Env]) {
292+
293+
_ = s.newStateEvents.RemoveSubscriber(sub)
248294
}
249295

250296
// executeDaemonEvent executes a daemon event, which is a special type of event
251297
// that can be emitted as part of the state transition function of the state
252298
// machine. An error is returned if the type of event is unknown.
253299
func (s *StateMachine[Event, Env]) executeDaemonEvent(event DaemonEvent) error {
254300
switch daemonEvent := event.(type) {
255-
256301
// This is a send message event, so we'll send the event, and also mind
257302
// any preconditions as well as post-send events.
258303
case *SendMsgEvent[Event]:
@@ -261,7 +306,8 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(event DaemonEvent) error {
261306
daemonEvent.TargetPeer, daemonEvent.Msgs,
262307
)
263308
if err != nil {
264-
return fmt.Errorf("unable to send msgs: %w", err)
309+
return fmt.Errorf("unable to send msgs: %w",
310+
err)
265311
}
266312

267313
// If a post-send event was specified, then we'll
@@ -306,7 +352,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(event DaemonEvent) error {
306352
)
307353

308354
if canSend {
309-
sendAndCleanUp()
355+
err := sendAndCleanUp()
356+
if err != nil {
357+
//nolint:lll
358+
log.Errorf("FSM(%v): unable to send message: %v", err)
359+
}
360+
310361
return
311362
}
312363

@@ -335,8 +386,6 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(event DaemonEvent) error {
335386
daemonEvent.Tx, daemonEvent.Label,
336387
)
337388
if err != nil {
338-
// TODO(roasbeef): hook has channel read event event is
339-
// hit?
340389
return fmt.Errorf("unable to broadcast txn: %w", err)
341390
}
342391

@@ -430,6 +479,8 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
430479
// any new emitted internal events to our event queue. This continues
431480
// until we reach a terminal state, or we run out of internal events to
432481
// process.
482+
//
483+
//nolint:lll
433484
for nextEvent := eventQueue.Dequeue(); nextEvent.IsSome(); nextEvent = eventQueue.Dequeue() {
434485
err := fn.MapOptionZ(nextEvent, func(event Event) error {
435486
// Apply the state transition function of the current
@@ -442,13 +493,17 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
442493
}
443494

444495
newEvents := transition.NewEvents
445-
err = fn.MapOptionZ(newEvents, func(events EmittedEvent[Event]) error {
496+
err = fn.MapOptionZ(newEvents, func(events EmittedEvent[Event]) error { //nolint:lll
446497
// With the event processed, we'll process any
447498
// new daemon events that were emitted as part
448499
// of this new state transition.
500+
//
501+
//nolint:lll
449502
err := fn.MapOptionZ(events.ExternalEvents, func(dEvents DaemonEventSet) error {
450503
for _, dEvent := range dEvents {
451-
err := s.executeDaemonEvent(dEvent)
504+
err := s.executeDaemonEvent(
505+
dEvent,
506+
)
452507
if err != nil {
453508
return err
454509
}
@@ -462,6 +517,8 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
462517

463518
// Next, we'll add any new emitted events to
464519
// our event queue.
520+
//
521+
//nolint:lll
465522
events.InternalEvent.WhenSome(func(inEvent Event) {
466523
eventQueue.Enqueue(inEvent)
467524
})
@@ -543,7 +600,10 @@ func (s *StateMachine[Event, Env]) driveMachine() {
543600
// An outside caller is querying our state, so we'll return the
544601
// latest state.
545602
case stateQuery := <-s.stateQuery:
546-
if !fn.SendOrQuit(stateQuery.CurrentState, currentState, s.quit) {
603+
if !fn.SendOrQuit(
604+
stateQuery.CurrentState, currentState, s.quit,
605+
) {
606+
547607
return
548608
}
549609

protofsm/state_machine_test.go

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,17 @@ func newDaemonAdapters() *dummyAdapters {
177177
}
178178
}
179179

180-
func (d *dummyAdapters) SendMessages(pub btcec.PublicKey, msgs []lnwire.Message) error {
180+
func (d *dummyAdapters) SendMessages(pub btcec.PublicKey,
181+
msgs []lnwire.Message) error {
182+
181183
args := d.Called(pub, msgs)
182184

183185
return args.Error(0)
184186
}
185187

186-
func (d *dummyAdapters) BroadcastTransaction(tx *wire.MsgTx, label string) error {
188+
func (d *dummyAdapters) BroadcastTransaction(tx *wire.MsgTx,
189+
label string) error {
190+
187191
args := d.Called(tx, label)
188192

189193
return args.Error(0)
@@ -203,6 +207,7 @@ func (d *dummyAdapters) RegisterConfirmationsNtfn(txid *chainhash.Hash,
203207
args := d.Called(txid, pkScript, numConfs)
204208

205209
err := args.Error(0)
210+
206211
return &chainntnfs.ConfirmationEvent{
207212
Confirmed: d.confChan,
208213
}, err
@@ -396,7 +401,9 @@ func TestStateMachineDaemonEvents(t *testing.T) {
396401
// As soon as we send in the daemon event, we expect the
397402
// disable+broadcast events to be processed, as they are unconditional.
398403
adapters.On("DisableChannel", mock.Anything).Return(nil)
399-
adapters.On("BroadcastTransaction", mock.Anything, mock.Anything).Return(nil)
404+
adapters.On(
405+
"BroadcastTransaction", mock.Anything, mock.Anything,
406+
).Return(nil)
400407
adapters.On("SendMessages", *pub2, mock.Anything).Return(nil)
401408

402409
// We'll start off by sending in the daemon event, which'll trigger the
@@ -428,3 +435,74 @@ func TestStateMachineDaemonEvents(t *testing.T) {
428435
adapters.AssertExpectations(t)
429436
env.AssertExpectations(t)
430437
}
438+
439+
type dummyMsgMapper struct {
440+
mock.Mock
441+
}
442+
443+
func (d *dummyMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[dummyEvents] {
444+
args := d.Called(wireMsg)
445+
446+
//nolint:forcetypeassert
447+
return args.Get(0).(fn.Option[dummyEvents])
448+
}
449+
450+
// TestStateMachineMsgMapper tests that given a message mapper, we can properly
451+
// send in wire messages get mapped to FSM events.
452+
func TestStateMachineMsgMapper(t *testing.T) {
453+
// First, we'll create our state machine given the env, and our
454+
// starting state.
455+
env := &dummyEnv{}
456+
startingState := &dummyStateStart{}
457+
adapters := newDaemonAdapters()
458+
459+
// We'll also provide a message mapper that only knows how to map a
460+
// single wire message (error).
461+
dummyMapper := &dummyMsgMapper{}
462+
463+
// The only thing we know how to map is the error message, which'll
464+
// terminate the state machine.
465+
wireError := &lnwire.Error{}
466+
initMsg := &lnwire.Init{}
467+
dummyMapper.On("MapMsg", wireError).Return(
468+
fn.Some(dummyEvents(&goToFin{})),
469+
)
470+
dummyMapper.On("MapMsg", initMsg).Return(fn.None[dummyEvents]())
471+
472+
cfg := StateMachineCfg[dummyEvents, *dummyEnv]{
473+
Daemon: adapters,
474+
InitialState: startingState,
475+
Env: env,
476+
MsgMapper: fn.Some[MsgMapper[dummyEvents]](dummyMapper),
477+
}
478+
stateMachine := NewStateMachine(cfg)
479+
stateMachine.Start()
480+
defer stateMachine.Stop()
481+
482+
// As we're triggering internal events, we'll also subscribe to the set
483+
// of new states so we can assert as we go.
484+
stateSub := stateMachine.RegisterStateEvents()
485+
defer stateMachine.RemoveStateSub(stateSub)
486+
487+
// We'll still be going to a terminal state, so we expect that the
488+
// clean up method will be called.
489+
env.On("CleanUp").Return(nil)
490+
491+
// First, we'll verify that the CanHandle method works as expected.
492+
require.True(t, stateMachine.CanHandle(wireError))
493+
require.False(t, stateMachine.CanHandle(&lnwire.Init{}))
494+
495+
// Next, we'll attempt to send the wire message into the state machine.
496+
// We should transition to the final state.
497+
require.True(t, stateMachine.SendMessage(wireError))
498+
499+
// We should transition to the final state.
500+
expectedStates := []State[dummyEvents, *dummyEnv]{
501+
&dummyStateStart{}, &dummyStateFin{},
502+
}
503+
assertStateTransitions(t, stateSub, expectedStates)
504+
505+
dummyMapper.AssertExpectations(t)
506+
adapters.AssertExpectations(t)
507+
env.AssertExpectations(t)
508+
}

0 commit comments

Comments
 (0)