From 87bef069e3e3c4a16deb937ae26c5cdff9b380ab Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 20 Apr 2025 10:55:34 +0200 Subject: [PATCH 1/2] session: add session ID to grpc metadata via context Add grpc interceptors that inject an LNC session's ID into the context as gRPC metadata. By injecting it as such, it will be transported over the wire in any outgoing gRPC calls. This lets us be sure that any session call sent to the RPCMiddleware interceptor in LND will continue to be grouped along with the appropriate session ID. This gives LND a way to send the metadata we include back to LiT meaning that we will later on be able to extract the session ID again. --- session/context.go | 57 ++++++++++++++++++++++++++++++++++ session/server.go | 5 +-- session_rpcserver.go | 74 ++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 132 insertions(+), 4 deletions(-) create mode 100644 session/context.go diff --git a/session/context.go b/session/context.go new file mode 100644 index 000000000..bb3b89526 --- /dev/null +++ b/session/context.go @@ -0,0 +1,57 @@ +package session + +import ( + "encoding/hex" + "fmt" + + "github.com/lightningnetwork/lnd/fn" + "google.golang.org/grpc/metadata" +) + +// contextKey is a struct that is used as a key for storing session IDs +// in a context. Using this unexported type prevents collisions with other +// context keys that may be used in the same context. However, this only +// applies if the context is passed around in the same binary and not if the +// value is converted to grpc metadata and sent over the wire. In that case, +// we need to use a string key to avoid collisions with other metadata keys. +type contextKey struct { + name string +} + +// sessionIDCtxKey is the context key used to store the session ID in +// a context. The key is a string to avoid collisions with other context values +// that may also be included in grpc metadata which is why we add the 'lit' +// prefix. +var sessionIDCtxKey = contextKey{"lit_session_id"} + +// FromGRPCMetadata extracts the session ID from the given gRPC metadata kv +// pairs if one is found. +func FromGRPCMetadata(md metadata.MD) (fn.Option[ID], error) { + val := md.Get(sessionIDCtxKey.name) + if len(val) == 0 { + return fn.None[ID](), nil + } + + if len(val) != 1 { + return fn.None[ID](), fmt.Errorf("more than one session ID "+ + "found in gRPC metadata: %v", val) + } + + b, err := hex.DecodeString(val[0]) + if err != nil { + return fn.None[ID](), err + } + + sessID, err := IDFromBytes(b) + if err != nil { + return fn.None[ID](), err + } + + return fn.Some(sessID), nil +} + +// AddToGRPCMetadata adds the session ID to the given gRPC metadata kv pairs. +// The session ID is encoded as a hex string. +func AddToGRPCMetadata(md metadata.MD, id ID) { + md.Set(sessionIDCtxKey.name, hex.EncodeToString(id[:])) +} diff --git a/session/server.go b/session/server.go index 75c0e3edf..8dd75b44f 100644 --- a/session/server.go +++ b/session/server.go @@ -18,7 +18,8 @@ import ( type sessionID [33]byte -type GRPCServerCreator func(opts ...grpc.ServerOption) *grpc.Server +type GRPCServerCreator func(sessionID ID, + opts ...grpc.ServerOption) *grpc.Server type mailboxSession struct { server *grpc.Server @@ -70,7 +71,7 @@ func (m *mailboxSession) start(session *Session, } noiseConn := mailbox.NewNoiseGrpcConn(keys) - m.server = serverCreator(grpc.Creds(noiseConn)) + m.server = serverCreator(session.ID, grpc.Creds(noiseConn)) m.wg.Add(1) go m.run(mailboxServer) diff --git a/session_rpcserver.go b/session_rpcserver.go index 59ebfdb29..9baf4aa5c 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -26,6 +26,7 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/macaroons" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon-bakery.v2/bakery/checkers" "gopkg.in/macaroon.v2" @@ -77,10 +78,23 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer, // actual mailbox server that spins up the Terminal Connect server // interface. server := session.NewServer( - func(opts ...grpc.ServerOption) *grpc.Server { - allOpts := append(cfg.grpcOptions, opts...) + func(id session.ID, opts ...grpc.ServerOption) *grpc.Server { + // Add the session ID injector interceptors first so + // that the session ID is available in the context of + // all interceptors that come after. + allOpts := []grpc.ServerOption{ + addSessionIDToStreamCtx(id), + addSessionIDToUnaryCtx(id), + } + + allOpts = append(allOpts, cfg.grpcOptions...) + allOpts = append(allOpts, opts...) + + // Construct the gRPC server with the options. grpcServer := grpc.NewServer(allOpts...) + // Register various grpc servers with the LNC session + // server. cfg.registerGrpcServers(grpcServer) return grpcServer @@ -94,6 +108,62 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer, }, nil } +// wrappedServerStream is a wrapper around the grpc.ServerStream that allows us +// to set a custom context. This is needed since the stream handler function +// doesn't take a context as an argument, but rather has a Context method on the +// handler itself. So we use this custom wrapper to override this method. +type wrappedServerStream struct { + grpc.ServerStream + ctx context.Context +} + +// Context returns the context of the stream. +// +// NOTE: This implements the grpc.ServerStream Context method. +func (w *wrappedServerStream) Context() context.Context { + return w.ctx +} + +// addSessionIDToStreamCtx is a gRPC stream interceptor that adds the given +// session ID to the context of the stream. This allows us to access the +// session ID later on for any gRPC calls made through this stream. +func addSessionIDToStreamCtx(id session.ID) grpc.ServerOption { + return grpc.StreamInterceptor(func(srv any, ss grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler) error { + + md, _ := metadata.FromIncomingContext(ss.Context()) + mdCopy := md.Copy() + session.AddToGRPCMetadata(mdCopy, id) + + // Wrap the original stream with our custom context. + wrapped := &wrappedServerStream{ + ServerStream: ss, + ctx: metadata.NewIncomingContext( + ss.Context(), mdCopy, + ), + } + + return handler(srv, wrapped) + }) +} + +// addSessionIDToUnaryCtx is a gRPC unary interceptor that adds the given +// session ID to the context of the unary call. This allows us to access the +// session ID later on for any gRPC calls made through this context. +func addSessionIDToUnaryCtx(id session.ID) grpc.ServerOption { + return grpc.UnaryInterceptor(func(ctx context.Context, req any, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (resp any, err error) { + + md, _ := metadata.FromIncomingContext(ctx) + mdCopy := md.Copy() + session.AddToGRPCMetadata(mdCopy, id) + + return handler(metadata.NewIncomingContext(ctx, mdCopy), req) + }) +} + // start all the components necessary for the sessionRpcServer to start serving // requests. This includes resuming all non-revoked sessions. func (s *sessionRpcServer) start(ctx context.Context) error { From a89b3502e4903fe9d475cc7977a654cb803bde50 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 13 May 2025 08:53:03 +0200 Subject: [PATCH 2/2] firewall: extract SessionID from gRPC metadata In this commit, we update our various firewall interceptors so that they rely on the session ID passed via gRPC metadata to extract a session ID. For the PrivacyMapper and RuleEnforcer, these _MUST_ always contain a session ID and so we error out if one was not found. For the request logger, the session ID is optional and so we pass it to the new SessionID field in the AddActionReq - our bbolt actions DB will not make use of this field on persistence (but our incoming SQL version will). --- firewall/privacy_mapper.go | 6 ++++-- firewall/privacy_mapper_test.go | 9 +++++++++ firewall/request_info.go | 18 ++++++++++++++++++ firewall/request_logger.go | 1 + firewall/rule_enforcer.go | 6 ++++-- 5 files changed, 36 insertions(+), 4 deletions(-) diff --git a/firewall/privacy_mapper.go b/firewall/privacy_mapper.go index fed4ba531..49aaf20f2 100644 --- a/firewall/privacy_mapper.go +++ b/firewall/privacy_mapper.go @@ -106,9 +106,11 @@ func (p *PrivacyMapper) Intercept(ctx context.Context, "interception request: %v", err) } - sessionID, err := session.IDFromMacaroon(ri.Macaroon) + sessionID, err := ri.SessionID.UnwrapOrErr( + fmt.Errorf("no session ID found in request info"), + ) if err != nil { - return nil, fmt.Errorf("could not extract ID from macaroon") + return nil, err } log.Tracef("PrivacyMapper: Intercepting %v", ri) diff --git a/firewall/privacy_mapper_test.go b/firewall/privacy_mapper_test.go index 9dcc814b2..61b24cd4c 100644 --- a/firewall/privacy_mapper_test.go +++ b/firewall/privacy_mapper_test.go @@ -13,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/rpcperms" "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon.v2" @@ -907,6 +908,9 @@ func TestPrivacyMapper(t *testing.T) { rawMsg, err := proto.Marshal(test.msg) require.NoError(t, err) + md := make(metadata.MD) + session.AddToGRPCMetadata(md, sessionID) + interceptReq := &rpcperms.InterceptionRequest{ Type: test.msgType, Macaroon: mac, @@ -916,6 +920,7 @@ func TestPrivacyMapper(t *testing.T) { ProtoTypeName: string( proto.MessageName(test.msg), ), + CtxMetadataPairs: md, } mwReq, err := interceptReq.ToRPC(1, 2) @@ -1006,6 +1011,9 @@ func TestPrivacyMapper(t *testing.T) { amounts := make([]uint64, numSamples) timestamps := make([]uint64, numSamples) + md := make(metadata.MD) + session.AddToGRPCMetadata(md, sessionID) + for i := 0; i < numSamples; i++ { interceptReq := &rpcperms.InterceptionRequest{ Type: rpcperms.TypeResponse, @@ -1016,6 +1024,7 @@ func TestPrivacyMapper(t *testing.T) { ProtoTypeName: string( proto.MessageName(msg), ), + CtxMetadataPairs: md, } mwReq, err := interceptReq.ToRPC(1, 2) diff --git a/firewall/request_info.go b/firewall/request_info.go index 10a524934..fd312b71f 100644 --- a/firewall/request_info.go +++ b/firewall/request_info.go @@ -4,7 +4,10 @@ import ( "fmt" "strings" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnrpc" + "google.golang.org/grpc/metadata" "gopkg.in/macaroon.v2" ) @@ -25,6 +28,7 @@ const ( // RequestInfo stores the parsed representation of an incoming RPC middleware // request. type RequestInfo struct { + SessionID fn.Option[session.ID] MsgID uint64 RequestID uint64 MWRequestType string @@ -76,8 +80,22 @@ func NewInfoFromRequest(req *lnrpc.RPCMiddlewareRequest) (*RequestInfo, error) { return nil, fmt.Errorf("invalid request type: %T", t) } + md := make(metadata.MD) + for k, vs := range req.MetadataPairs { + for _, v := range vs.Values { + md.Append(k, v) + } + } + + sessionID, err := session.FromGRPCMetadata(md) + if err != nil { + return nil, fmt.Errorf("error extracting session ID "+ + "from request: %v", err) + } + ri.MsgID = req.MsgId ri.RequestID = req.RequestId + ri.SessionID = sessionID // If there is no macaroon in the request, then there is nothing left // to parse. diff --git a/firewall/request_logger.go b/firewall/request_logger.go index 3463dff2a..0a98e6458 100644 --- a/firewall/request_logger.go +++ b/firewall/request_logger.go @@ -194,6 +194,7 @@ func (r *RequestLogger) addNewAction(ctx context.Context, ri *RequestInfo, } actionReq := &firewalldb.AddActionReq{ + SessionID: ri.SessionID, MacaroonIdentifier: macaroonID, RPCMethod: ri.URI, } diff --git a/firewall/rule_enforcer.go b/firewall/rule_enforcer.go index 35f92c534..54d2b3a61 100644 --- a/firewall/rule_enforcer.go +++ b/firewall/rule_enforcer.go @@ -237,9 +237,11 @@ func (r *RuleEnforcer) Intercept(ctx context.Context, func (r *RuleEnforcer) handleRequest(ctx context.Context, ri *RequestInfo) (proto.Message, error) { - sessionID, err := session.IDFromMacaroon(ri.Macaroon) + sessionID, err := ri.SessionID.UnwrapOrErr( + fmt.Errorf("no session ID found in request info"), + ) if err != nil { - return nil, fmt.Errorf("could not extract ID from macaroon") + return nil, err } rules, err := r.collectEnforcers(ctx, ri, sessionID)