| 
 | 1 | +package peer  | 
 | 2 | + | 
 | 3 | +import (  | 
 | 4 | +	"fmt"  | 
 | 5 | +	"maps"  | 
 | 6 | +	"sync"  | 
 | 7 | + | 
 | 8 | +	"github.com/lightningnetwork/lnd/fn"  | 
 | 9 | +	"github.com/lightningnetwork/lnd/lnwire"  | 
 | 10 | +)  | 
 | 11 | + | 
 | 12 | +var (  | 
 | 13 | +	// ErrDuplicateEndpoint is returned when an endpoint is registered with  | 
 | 14 | +	// a name that already exists.  | 
 | 15 | +	ErrDuplicateEndpoint = fmt.Errorf("endpoint already registered")  | 
 | 16 | + | 
 | 17 | +	// ErrUnableToRouteMsg is returned when a message is unable to be  | 
 | 18 | +	// routed to any endpoints.  | 
 | 19 | +	ErrUnableToRouteMsg = fmt.Errorf("unable to route message")  | 
 | 20 | +)  | 
 | 21 | + | 
 | 22 | +// EndPointName is the name of a given endpoint. This MUST be unique across all  | 
 | 23 | +// registered endpoints.  | 
 | 24 | +type EndPointName = string  | 
 | 25 | + | 
 | 26 | +// MsgEndpoint is an interface that represents a message endpoint, or the  | 
 | 27 | +// sub-system that will handle processing an incoming wire message.  | 
 | 28 | +type MsgEndpoint interface {  | 
 | 29 | +	// Name returns the name of this endpoint. This MUST be unique across  | 
 | 30 | +	// all registered endpoints.  | 
 | 31 | +	Name() EndPointName  | 
 | 32 | + | 
 | 33 | +	// CanHandle returns true if the target message can be routed to this  | 
 | 34 | +	// endpoint.  | 
 | 35 | +	CanHandle(msg lnwire.Message) bool  | 
 | 36 | + | 
 | 37 | +	// SendMessage handles the target message, and returns true if the  | 
 | 38 | +	// message was able to be processed.  | 
 | 39 | +	SendMessage(msg lnwire.Message) bool  | 
 | 40 | +}  | 
 | 41 | + | 
 | 42 | +// MsgRouter is an interface that represents a message router, which is generic  | 
 | 43 | +// sub-system capable of routing any incoming wire message to a set of  | 
 | 44 | +// registered endpoints.  | 
 | 45 | +//  | 
 | 46 | +// TODO(roasbeef): move to diff sub-system?  | 
 | 47 | +type MsgRouter interface {  | 
 | 48 | +	// RegisterEndpoint registers a new endpoint with the router. If a  | 
 | 49 | +	// duplicate endpoint exists, an error is returned.  | 
 | 50 | +	RegisterEndpoint(MsgEndpoint) error  | 
 | 51 | + | 
 | 52 | +	// UnregisterEndpoint unregisters the target endpoint from the router.  | 
 | 53 | +	UnregisterEndpoint(EndPointName) error  | 
 | 54 | + | 
 | 55 | +	// RouteMsg attempts to route the target message to a registered  | 
 | 56 | +	// endpoint. If ANY endpoint could handle the message, then nil is  | 
 | 57 | +	// returned. Otherwise, ErrUnableToRouteMsg is returned.  | 
 | 58 | +	RouteMsg(lnwire.Message) error  | 
 | 59 | + | 
 | 60 | +	// Start starts the peer message router.  | 
 | 61 | +	Start()  | 
 | 62 | + | 
 | 63 | +	// Stop stops the peer message router.  | 
 | 64 | +	Stop()  | 
 | 65 | +}  | 
 | 66 | + | 
 | 67 | +// queryMsg is a message sent into the main event loop to query or modify the  | 
 | 68 | +// internal state.  | 
 | 69 | +type queryMsg[Q any, R any] struct {  | 
 | 70 | +	query Q  | 
 | 71 | + | 
 | 72 | +	respChan chan fn.Either[R, error]  | 
 | 73 | +}  | 
 | 74 | + | 
 | 75 | +// sendQuery sends a query to the main event loop, and returns the response.  | 
 | 76 | +func sendQuery[Q any, R any](sendChan chan queryMsg[Q, R], queryArg Q,  | 
 | 77 | +	quit chan struct{}) fn.Either[R, error] {  | 
 | 78 | + | 
 | 79 | +	query := queryMsg[Q, R]{  | 
 | 80 | +		query:    queryArg,  | 
 | 81 | +		respChan: make(chan fn.Either[R, error], 1),  | 
 | 82 | +	}  | 
 | 83 | + | 
 | 84 | +	if !fn.SendOrQuit(sendChan, query, quit) {  | 
 | 85 | +		return fn.NewRight[R](fmt.Errorf("router shutting down"))  | 
 | 86 | +	}  | 
 | 87 | + | 
 | 88 | +	resp, err := fn.RecvResp(query.respChan, nil, quit)  | 
 | 89 | +	if err != nil {  | 
 | 90 | +		return fn.NewRight[R](err)  | 
 | 91 | +	}  | 
 | 92 | + | 
 | 93 | +	return resp  | 
 | 94 | +}  | 
 | 95 | + | 
 | 96 | +// sendQueryErr is a helper function based on sendQuery that can be used when  | 
 | 97 | +// the query only needs an error response.  | 
 | 98 | +func sendQueryErr[Q any](sendChan chan queryMsg[Q, error], queryArg Q,  | 
 | 99 | +	quitChan chan struct{}) error {  | 
 | 100 | + | 
 | 101 | +	var err error  | 
 | 102 | +	resp := sendQuery(sendChan, queryArg, quitChan)  | 
 | 103 | +	resp.WhenRight(func(e error) {  | 
 | 104 | +		err = e  | 
 | 105 | +	})  | 
 | 106 | +	resp.WhenLeft(func(e error) {  | 
 | 107 | +		err = e  | 
 | 108 | +	})  | 
 | 109 | + | 
 | 110 | +	return err  | 
 | 111 | +}  | 
 | 112 | + | 
 | 113 | +// EndpointsMap is a map of all registered endpoints.  | 
 | 114 | +type EndpointsMap map[EndPointName]MsgEndpoint  | 
 | 115 | + | 
 | 116 | +// MultiMsgRouter is a type of message router that is capable of routing new  | 
 | 117 | +// incoming messages, permitting a message to be routed to multiple registered  | 
 | 118 | +// endpoints.  | 
 | 119 | +type MultiMsgRouter struct {  | 
 | 120 | +	startOnce sync.Once  | 
 | 121 | +	stopOnce  sync.Once  | 
 | 122 | + | 
 | 123 | +	// registerChan is the channel that all new endpoints will be sent to.  | 
 | 124 | +	registerChan chan queryMsg[MsgEndpoint, error]  | 
 | 125 | + | 
 | 126 | +	// unregisterChan is the channel that all endpoints that are to be  | 
 | 127 | +	// removed are sent to.  | 
 | 128 | +	unregisterChan chan queryMsg[EndPointName, error]  | 
 | 129 | + | 
 | 130 | +	// msgChan is the channel that all messages will be sent to for  | 
 | 131 | +	// processing.  | 
 | 132 | +	msgChan chan queryMsg[lnwire.Message, error]  | 
 | 133 | + | 
 | 134 | +	// endpointsQueries is a channel that all queries to the endpoints map  | 
 | 135 | +	// will be sent to.  | 
 | 136 | +	endpointQueries chan queryMsg[MsgEndpoint, EndpointsMap]  | 
 | 137 | + | 
 | 138 | +	wg   sync.WaitGroup  | 
 | 139 | +	quit chan struct{}  | 
 | 140 | +}  | 
 | 141 | + | 
 | 142 | +// NewMultiMsgRouter creates a new instance of a peer message router.  | 
 | 143 | +func NewMultiMsgRouter() *MultiMsgRouter {  | 
 | 144 | +	return &MultiMsgRouter{  | 
 | 145 | +		registerChan:    make(chan queryMsg[MsgEndpoint, error]),  | 
 | 146 | +		unregisterChan:  make(chan queryMsg[EndPointName, error]),  | 
 | 147 | +		msgChan:         make(chan queryMsg[lnwire.Message, error]),  | 
 | 148 | +		endpointQueries: make(chan queryMsg[MsgEndpoint, EndpointsMap]),  | 
 | 149 | +		quit:            make(chan struct{}),  | 
 | 150 | +	}  | 
 | 151 | +}  | 
 | 152 | + | 
 | 153 | +// Start starts the peer message router.  | 
 | 154 | +func (p *MultiMsgRouter) Start() {  | 
 | 155 | +	peerLog.Infof("Starting MsgRouter")  | 
 | 156 | + | 
 | 157 | +	p.startOnce.Do(func() {  | 
 | 158 | +		p.wg.Add(1)  | 
 | 159 | +		go p.msgRouter()  | 
 | 160 | +	})  | 
 | 161 | +}  | 
 | 162 | + | 
 | 163 | +// Stop stops the peer message router.  | 
 | 164 | +func (p *MultiMsgRouter) Stop() {  | 
 | 165 | +	peerLog.Infof("Stopping MsgRouter")  | 
 | 166 | + | 
 | 167 | +	p.stopOnce.Do(func() {  | 
 | 168 | +		close(p.quit)  | 
 | 169 | +		p.wg.Wait()  | 
 | 170 | +	})  | 
 | 171 | +}  | 
 | 172 | + | 
 | 173 | +// RegisterEndpoint registers a new endpoint with the router. If a duplicate  | 
 | 174 | +// endpoint exists, an error is returned.  | 
 | 175 | +func (p *MultiMsgRouter) RegisterEndpoint(endpoint MsgEndpoint) error {  | 
 | 176 | +	return sendQueryErr(p.registerChan, endpoint, p.quit)  | 
 | 177 | +}  | 
 | 178 | + | 
 | 179 | +// UnregisterEndpoint unregisters the target endpoint from the router.  | 
 | 180 | +func (p *MultiMsgRouter) UnregisterEndpoint(name EndPointName) error {  | 
 | 181 | +	return sendQueryErr(p.unregisterChan, name, p.quit)  | 
 | 182 | +}  | 
 | 183 | + | 
 | 184 | +// RouteMsg attempts to route the target message to a registered endpoint. If  | 
 | 185 | +// ANY endpoint could handle the message, then nil is returned.  | 
 | 186 | +func (p *MultiMsgRouter) RouteMsg(msg lnwire.Message) error {  | 
 | 187 | +	return sendQueryErr(p.msgChan, msg, p.quit)  | 
 | 188 | +}  | 
 | 189 | + | 
 | 190 | +// Endpoints returns a list of all registered endpoints.  | 
 | 191 | +func (p *MultiMsgRouter) Endpoints() EndpointsMap {  | 
 | 192 | +	resp := sendQuery(p.endpointQueries, nil, p.quit)  | 
 | 193 | + | 
 | 194 | +	var endpoints EndpointsMap  | 
 | 195 | +	resp.WhenLeft(func(e EndpointsMap) {  | 
 | 196 | +		endpoints = e  | 
 | 197 | +	})  | 
 | 198 | + | 
 | 199 | +	return endpoints  | 
 | 200 | +}  | 
 | 201 | + | 
 | 202 | +// msgRouter is the main goroutine that handles all incoming messages.  | 
 | 203 | +func (p *MultiMsgRouter) msgRouter() {  | 
 | 204 | +	defer p.wg.Done()  | 
 | 205 | + | 
 | 206 | +	// endpoints is a map of all registered endpoints.  | 
 | 207 | +	endpoints := make(map[EndPointName]MsgEndpoint)  | 
 | 208 | + | 
 | 209 | +	for {  | 
 | 210 | +		select {  | 
 | 211 | +		// A new endpoint was just sent in, so we'll add it to our set  | 
 | 212 | +		// of registered endpoints.  | 
 | 213 | +		case newEndpointMsg := <-p.registerChan:  | 
 | 214 | +			endpoint := newEndpointMsg.query  | 
 | 215 | + | 
 | 216 | +			peerLog.Infof("MsgRouter: registering new "+  | 
 | 217 | +				"MsgEndpoint(%s)", endpoint.Name())  | 
 | 218 | + | 
 | 219 | +			// If this endpoint already exists, then we'll return  | 
 | 220 | +			// an error as we require unique names.  | 
 | 221 | +			if _, ok := endpoints[endpoint.Name()]; ok {  | 
 | 222 | +				peerLog.Errorf("MsgRouter: rejecting "+  | 
 | 223 | +					"duplicate endpoint: %v",  | 
 | 224 | +					endpoint.Name())  | 
 | 225 | + | 
 | 226 | +				newEndpointMsg.respChan <- fn.NewRight[error](  | 
 | 227 | +					ErrDuplicateEndpoint,  | 
 | 228 | +				)  | 
 | 229 | + | 
 | 230 | +				continue  | 
 | 231 | +			}  | 
 | 232 | + | 
 | 233 | +			endpoints[endpoint.Name()] = endpoint  | 
 | 234 | + | 
 | 235 | +			// TODO(roasbeef): put in method?  | 
 | 236 | +			newEndpointMsg.respChan <- fn.NewRight[error, error](  | 
 | 237 | +				nil,  | 
 | 238 | +			)  | 
 | 239 | + | 
 | 240 | +		// A request to unregister an endpoint was just sent in, so  | 
 | 241 | +		// we'll attempt to remove it.  | 
 | 242 | +		case endpointName := <-p.unregisterChan:  | 
 | 243 | +			delete(endpoints, endpointName.query)  | 
 | 244 | + | 
 | 245 | +			peerLog.Infof("MsgRouter: unregistering "+  | 
 | 246 | +				"MsgEndpoint(%s)", endpointName.query)  | 
 | 247 | + | 
 | 248 | +			endpointName.respChan <- fn.NewRight[error, error](  | 
 | 249 | +				nil,  | 
 | 250 | +			)  | 
 | 251 | + | 
 | 252 | +		// A new message was just sent in. We'll attempt to route it to  | 
 | 253 | +		// all the endpoints that can handle it.  | 
 | 254 | +		case msgQuery := <-p.msgChan:  | 
 | 255 | +			msg := msgQuery.query  | 
 | 256 | + | 
 | 257 | +			// Loop through all the endpoints and send the message  | 
 | 258 | +			// to those that can handle it the message.  | 
 | 259 | +			var couldSend bool  | 
 | 260 | +			for _, endpoint := range endpoints {  | 
 | 261 | +				if endpoint.CanHandle(msg) {  | 
 | 262 | +					peerLog.Tracef("MsgRouter: sending "+  | 
 | 263 | +						"msg %T to endpoint %s", msg,  | 
 | 264 | +						endpoint.Name())  | 
 | 265 | + | 
 | 266 | +					sent := endpoint.SendMessage(msg)  | 
 | 267 | +					couldSend = couldSend || sent  | 
 | 268 | +				}  | 
 | 269 | +			}  | 
 | 270 | + | 
 | 271 | +			var err error  | 
 | 272 | +			if !couldSend {  | 
 | 273 | +				peerLog.Tracef("MsgRouter: unable to route "+  | 
 | 274 | +					"msg %T", msg)  | 
 | 275 | + | 
 | 276 | +				err = ErrUnableToRouteMsg  | 
 | 277 | +			}  | 
 | 278 | + | 
 | 279 | +			msgQuery.respChan <- fn.NewRight[error](err)  | 
 | 280 | + | 
 | 281 | +		// A query for the endpoint state just came in, we'll send back  | 
 | 282 | +		// a copy of our current state.  | 
 | 283 | +		case endpointQuery := <-p.endpointQueries:  | 
 | 284 | +			endpointsCopy := make(EndpointsMap, len(endpoints))  | 
 | 285 | +			maps.Copy(endpointsCopy, endpoints)  | 
 | 286 | + | 
 | 287 | +			//nolint:lll  | 
 | 288 | +			endpointQuery.respChan <- fn.NewLeft[EndpointsMap, error](  | 
 | 289 | +				endpointsCopy,  | 
 | 290 | +			)  | 
 | 291 | + | 
 | 292 | +		case <-p.quit:  | 
 | 293 | +			return  | 
 | 294 | +		}  | 
 | 295 | +	}  | 
 | 296 | +}  | 
 | 297 | + | 
 | 298 | +// A compile time check to ensure MultiMsgRouter implements the MsgRouter  | 
 | 299 | +// interface.  | 
 | 300 | +var _ MsgRouter = (*MultiMsgRouter)(nil)  | 
0 commit comments