| 
 | 1 | +package msgmux  | 
 | 2 | + | 
 | 3 | +// For some reason golangci-lint has a false positive on the sort order of the  | 
 | 4 | +// imports for the new "maps" package... We need the nolint directive here to  | 
 | 5 | +// ignore that.  | 
 | 6 | +//  | 
 | 7 | +//nolint:gci  | 
 | 8 | +import (  | 
 | 9 | +	"fmt"  | 
 | 10 | +	"maps"  | 
 | 11 | +	"sync"  | 
 | 12 | + | 
 | 13 | +	"github.com/btcsuite/btcd/btcec/v2"  | 
 | 14 | +	"github.com/lightningnetwork/lnd/fn"  | 
 | 15 | +	"github.com/lightningnetwork/lnd/lnwire"  | 
 | 16 | +)  | 
 | 17 | + | 
 | 18 | +var (  | 
 | 19 | +	// ErrDuplicateEndpoint is returned when an endpoint is registered with  | 
 | 20 | +	// a name that already exists.  | 
 | 21 | +	ErrDuplicateEndpoint = fmt.Errorf("endpoint already registered")  | 
 | 22 | + | 
 | 23 | +	// ErrUnableToRouteMsg is returned when a message is unable to be  | 
 | 24 | +	// routed to any endpoints.  | 
 | 25 | +	ErrUnableToRouteMsg = fmt.Errorf("unable to route message")  | 
 | 26 | +)  | 
 | 27 | + | 
 | 28 | +// EndpointName is the name of a given endpoint. This MUST be unique across all  | 
 | 29 | +// registered endpoints.  | 
 | 30 | +type EndpointName = string  | 
 | 31 | + | 
 | 32 | +// PeerMsg is a wire message that includes the public key of the peer that sent  | 
 | 33 | +// it.  | 
 | 34 | +type PeerMsg struct {  | 
 | 35 | +	lnwire.Message  | 
 | 36 | + | 
 | 37 | +	// PeerPub is the public key of the peer that sent this message.  | 
 | 38 | +	PeerPub btcec.PublicKey  | 
 | 39 | +}  | 
 | 40 | + | 
 | 41 | +// Endpoint is an interface that represents a message endpoint, or the  | 
 | 42 | +// sub-system that will handle processing an incoming wire message.  | 
 | 43 | +type Endpoint interface {  | 
 | 44 | +	// Name returns the name of this endpoint. This MUST be unique across  | 
 | 45 | +	// all registered endpoints.  | 
 | 46 | +	Name() EndpointName  | 
 | 47 | + | 
 | 48 | +	// CanHandle returns true if the target message can be routed to this  | 
 | 49 | +	// endpoint.  | 
 | 50 | +	CanHandle(msg PeerMsg) bool  | 
 | 51 | + | 
 | 52 | +	// SendMessage handles the target message, and returns true if the  | 
 | 53 | +	// message was able being processed.  | 
 | 54 | +	SendMessage(msg PeerMsg) bool  | 
 | 55 | +}  | 
 | 56 | + | 
 | 57 | +// MsgRouter is an interface that represents a message router, which is generic  | 
 | 58 | +// sub-system capable of routing any incoming wire message to a set of  | 
 | 59 | +// registered endpoints.  | 
 | 60 | +type Router interface {  | 
 | 61 | +	// RegisterEndpoint registers a new endpoint with the router. If a  | 
 | 62 | +	// duplicate endpoint exists, an error is returned.  | 
 | 63 | +	RegisterEndpoint(Endpoint) error  | 
 | 64 | + | 
 | 65 | +	// UnregisterEndpoint unregisters the target endpoint from the router.  | 
 | 66 | +	UnregisterEndpoint(EndpointName) error  | 
 | 67 | + | 
 | 68 | +	// RouteMsg attempts to route the target message to a registered  | 
 | 69 | +	// endpoint. If ANY endpoint could handle the message, then nil is  | 
 | 70 | +	// returned. Otherwise, ErrUnableToRouteMsg is returned.  | 
 | 71 | +	RouteMsg(PeerMsg) error  | 
 | 72 | + | 
 | 73 | +	// Start starts the peer message router.  | 
 | 74 | +	Start()  | 
 | 75 | + | 
 | 76 | +	// Stop stops the peer message router.  | 
 | 77 | +	Stop()  | 
 | 78 | +}  | 
 | 79 | + | 
 | 80 | +// sendQuery sends a query to the main event loop, and returns the response.  | 
 | 81 | +func sendQuery[Q any, R any](sendChan chan fn.Req[Q, R], queryArg Q,  | 
 | 82 | +	quit chan struct{}) fn.Result[R] {  | 
 | 83 | + | 
 | 84 | +	query, respChan := fn.NewReq[Q, R](queryArg)  | 
 | 85 | + | 
 | 86 | +	if !fn.SendOrQuit(sendChan, query, quit) {  | 
 | 87 | +		return fn.Errf[R]("router shutting down")  | 
 | 88 | +	}  | 
 | 89 | + | 
 | 90 | +	return fn.NewResult(fn.RecvResp(respChan, nil, quit))  | 
 | 91 | +}  | 
 | 92 | + | 
 | 93 | +// sendQueryErr is a helper function based on sendQuery that can be used when  | 
 | 94 | +// the query only needs an error response.  | 
 | 95 | +func sendQueryErr[Q any](sendChan chan fn.Req[Q, error], queryArg Q,  | 
 | 96 | +	quitChan chan struct{}) error {  | 
 | 97 | + | 
 | 98 | +	return fn.ElimEither(  | 
 | 99 | +		fn.Iden, fn.Iden,  | 
 | 100 | +		sendQuery(sendChan, queryArg, quitChan).Either,  | 
 | 101 | +	)  | 
 | 102 | +}  | 
 | 103 | + | 
 | 104 | +// EndpointsMap is a map of all registered endpoints.  | 
 | 105 | +type EndpointsMap map[EndpointName]Endpoint  | 
 | 106 | + | 
 | 107 | +// MultiMsgRouter is a type of message router that is capable of routing new  | 
 | 108 | +// incoming messages, permitting a message to be routed to multiple registered  | 
 | 109 | +// endpoints.  | 
 | 110 | +type MultiMsgRouter struct {  | 
 | 111 | +	startOnce sync.Once  | 
 | 112 | +	stopOnce  sync.Once  | 
 | 113 | + | 
 | 114 | +	// registerChan is the channel that all new endpoints will be sent to.  | 
 | 115 | +	registerChan chan fn.Req[Endpoint, error]  | 
 | 116 | + | 
 | 117 | +	// unregisterChan is the channel that all endpoints that are to be  | 
 | 118 | +	// removed are sent to.  | 
 | 119 | +	unregisterChan chan fn.Req[EndpointName, error]  | 
 | 120 | + | 
 | 121 | +	// msgChan is the channel that all messages will be sent to for  | 
 | 122 | +	// processing.  | 
 | 123 | +	msgChan chan fn.Req[PeerMsg, error]  | 
 | 124 | + | 
 | 125 | +	// endpointsQueries is a channel that all queries to the endpoints map  | 
 | 126 | +	// will be sent to.  | 
 | 127 | +	endpointQueries chan fn.Req[Endpoint, EndpointsMap]  | 
 | 128 | + | 
 | 129 | +	wg   sync.WaitGroup  | 
 | 130 | +	quit chan struct{}  | 
 | 131 | +}  | 
 | 132 | + | 
 | 133 | +// NewMultiMsgRouter creates a new instance of a peer message router.  | 
 | 134 | +func NewMultiMsgRouter() *MultiMsgRouter {  | 
 | 135 | +	return &MultiMsgRouter{  | 
 | 136 | +		registerChan:    make(chan fn.Req[Endpoint, error]),  | 
 | 137 | +		unregisterChan:  make(chan fn.Req[EndpointName, error]),  | 
 | 138 | +		msgChan:         make(chan fn.Req[PeerMsg, error]),  | 
 | 139 | +		endpointQueries: make(chan fn.Req[Endpoint, EndpointsMap]),  | 
 | 140 | +		quit:            make(chan struct{}),  | 
 | 141 | +	}  | 
 | 142 | +}  | 
 | 143 | + | 
 | 144 | +// Start starts the peer message router.  | 
 | 145 | +func (p *MultiMsgRouter) Start() {  | 
 | 146 | +	log.Infof("Starting Router")  | 
 | 147 | + | 
 | 148 | +	p.startOnce.Do(func() {  | 
 | 149 | +		p.wg.Add(1)  | 
 | 150 | +		go p.msgRouter()  | 
 | 151 | +	})  | 
 | 152 | +}  | 
 | 153 | + | 
 | 154 | +// Stop stops the peer message router.  | 
 | 155 | +func (p *MultiMsgRouter) Stop() {  | 
 | 156 | +	log.Infof("Stopping Router")  | 
 | 157 | + | 
 | 158 | +	p.stopOnce.Do(func() {  | 
 | 159 | +		close(p.quit)  | 
 | 160 | +		p.wg.Wait()  | 
 | 161 | +	})  | 
 | 162 | +}  | 
 | 163 | + | 
 | 164 | +// RegisterEndpoint registers a new endpoint with the router. If a duplicate  | 
 | 165 | +// endpoint exists, an error is returned.  | 
 | 166 | +func (p *MultiMsgRouter) RegisterEndpoint(endpoint Endpoint) error {  | 
 | 167 | +	return sendQueryErr(p.registerChan, endpoint, p.quit)  | 
 | 168 | +}  | 
 | 169 | + | 
 | 170 | +// UnregisterEndpoint unregisters the target endpoint from the router.  | 
 | 171 | +func (p *MultiMsgRouter) UnregisterEndpoint(name EndpointName) error {  | 
 | 172 | +	return sendQueryErr(p.unregisterChan, name, p.quit)  | 
 | 173 | +}  | 
 | 174 | + | 
 | 175 | +// RouteMsg attempts to route the target message to a registered endpoint. If  | 
 | 176 | +// ANY endpoint could handle the message, then nil is returned.  | 
 | 177 | +func (p *MultiMsgRouter) RouteMsg(msg PeerMsg) error {  | 
 | 178 | +	return sendQueryErr(p.msgChan, msg, p.quit)  | 
 | 179 | +}  | 
 | 180 | + | 
 | 181 | +// Endpoints returns a list of all registered endpoints.  | 
 | 182 | +func (p *MultiMsgRouter) endpoints() fn.Result[EndpointsMap] {  | 
 | 183 | +	return sendQuery(p.endpointQueries, nil, p.quit)  | 
 | 184 | +}  | 
 | 185 | + | 
 | 186 | +// msgRouter is the main goroutine that handles all incoming messages.  | 
 | 187 | +func (p *MultiMsgRouter) msgRouter() {  | 
 | 188 | +	defer p.wg.Done()  | 
 | 189 | + | 
 | 190 | +	// endpoints is a map of all registered endpoints.  | 
 | 191 | +	endpoints := make(map[EndpointName]Endpoint)  | 
 | 192 | + | 
 | 193 | +	for {  | 
 | 194 | +		select {  | 
 | 195 | +		// A new endpoint was just sent in, so we'll add it to our set  | 
 | 196 | +		// of registered endpoints.  | 
 | 197 | +		case newEndpointMsg := <-p.registerChan:  | 
 | 198 | +			endpoint := newEndpointMsg.Request  | 
 | 199 | + | 
 | 200 | +			log.Infof("MsgRouter: registering new "+  | 
 | 201 | +				"Endpoint(%s)", endpoint.Name())  | 
 | 202 | + | 
 | 203 | +			// If this endpoint already exists, then we'll return  | 
 | 204 | +			// an error as we require unique names.  | 
 | 205 | +			if _, ok := endpoints[endpoint.Name()]; ok {  | 
 | 206 | +				log.Errorf("MsgRouter: rejecting "+  | 
 | 207 | +					"duplicate endpoint: %v",  | 
 | 208 | +					endpoint.Name())  | 
 | 209 | + | 
 | 210 | +				newEndpointMsg.Resolve(ErrDuplicateEndpoint)  | 
 | 211 | + | 
 | 212 | +				continue  | 
 | 213 | +			}  | 
 | 214 | + | 
 | 215 | +			endpoints[endpoint.Name()] = endpoint  | 
 | 216 | + | 
 | 217 | +			newEndpointMsg.Resolve(nil)  | 
 | 218 | + | 
 | 219 | +		// A request to unregister an endpoint was just sent in, so  | 
 | 220 | +		// we'll attempt to remove it.  | 
 | 221 | +		case endpointName := <-p.unregisterChan:  | 
 | 222 | +			delete(endpoints, endpointName.Request)  | 
 | 223 | + | 
 | 224 | +			log.Infof("MsgRouter: unregistering "+  | 
 | 225 | +				"Endpoint(%s)", endpointName.Request)  | 
 | 226 | + | 
 | 227 | +			endpointName.Resolve(nil)  | 
 | 228 | + | 
 | 229 | +		// A new message was just sent in. We'll attempt to route it to  | 
 | 230 | +		// all the endpoints that can handle it.  | 
 | 231 | +		case msgQuery := <-p.msgChan:  | 
 | 232 | +			msg := msgQuery.Request  | 
 | 233 | + | 
 | 234 | +			// Loop through all the endpoints and send the message  | 
 | 235 | +			// to those that can handle it the message.  | 
 | 236 | +			var couldSend bool  | 
 | 237 | +			for _, endpoint := range endpoints {  | 
 | 238 | +				if endpoint.CanHandle(msg) {  | 
 | 239 | +					log.Tracef("MsgRouter: sending "+  | 
 | 240 | +						"msg %T to endpoint %s", msg,  | 
 | 241 | +						endpoint.Name())  | 
 | 242 | + | 
 | 243 | +					sent := endpoint.SendMessage(msg)  | 
 | 244 | +					couldSend = couldSend || sent  | 
 | 245 | +				}  | 
 | 246 | +			}  | 
 | 247 | + | 
 | 248 | +			var err error  | 
 | 249 | +			if !couldSend {  | 
 | 250 | +				log.Tracef("MsgRouter: unable to route "+  | 
 | 251 | +					"msg %T", msg)  | 
 | 252 | + | 
 | 253 | +				err = ErrUnableToRouteMsg  | 
 | 254 | +			}  | 
 | 255 | + | 
 | 256 | +			msgQuery.Resolve(err)  | 
 | 257 | + | 
 | 258 | +		// A query for the endpoint state just came in, we'll send back  | 
 | 259 | +		// a copy of our current state.  | 
 | 260 | +		case endpointQuery := <-p.endpointQueries:  | 
 | 261 | +			endpointsCopy := make(EndpointsMap, len(endpoints))  | 
 | 262 | +			maps.Copy(endpointsCopy, endpoints)  | 
 | 263 | + | 
 | 264 | +			endpointQuery.Resolve(endpointsCopy)  | 
 | 265 | + | 
 | 266 | +		case <-p.quit:  | 
 | 267 | +			return  | 
 | 268 | +		}  | 
 | 269 | +	}  | 
 | 270 | +}  | 
 | 271 | + | 
 | 272 | +// A compile time check to ensure MultiMsgRouter implements the MsgRouter  | 
 | 273 | +// interface.  | 
 | 274 | +var _ Router = (*MultiMsgRouter)(nil)  | 
0 commit comments