diff --git a/firewall/privacy_mapper.go b/firewall/privacy_mapper.go index fed4ba531..49aaf20f2 100644 --- a/firewall/privacy_mapper.go +++ b/firewall/privacy_mapper.go @@ -106,9 +106,11 @@ func (p *PrivacyMapper) Intercept(ctx context.Context, "interception request: %v", err) } - sessionID, err := session.IDFromMacaroon(ri.Macaroon) + sessionID, err := ri.SessionID.UnwrapOrErr( + fmt.Errorf("no session ID found in request info"), + ) if err != nil { - return nil, fmt.Errorf("could not extract ID from macaroon") + return nil, err } log.Tracef("PrivacyMapper: Intercepting %v", ri) diff --git a/firewall/privacy_mapper_test.go b/firewall/privacy_mapper_test.go index 9dcc814b2..61b24cd4c 100644 --- a/firewall/privacy_mapper_test.go +++ b/firewall/privacy_mapper_test.go @@ -13,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/rpcperms" "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon.v2" @@ -907,6 +908,9 @@ func TestPrivacyMapper(t *testing.T) { rawMsg, err := proto.Marshal(test.msg) require.NoError(t, err) + md := make(metadata.MD) + session.AddToGRPCMetadata(md, sessionID) + interceptReq := &rpcperms.InterceptionRequest{ Type: test.msgType, Macaroon: mac, @@ -916,6 +920,7 @@ func TestPrivacyMapper(t *testing.T) { ProtoTypeName: string( proto.MessageName(test.msg), ), + CtxMetadataPairs: md, } mwReq, err := interceptReq.ToRPC(1, 2) @@ -1006,6 +1011,9 @@ func TestPrivacyMapper(t *testing.T) { amounts := make([]uint64, numSamples) timestamps := make([]uint64, numSamples) + md := make(metadata.MD) + session.AddToGRPCMetadata(md, sessionID) + for i := 0; i < numSamples; i++ { interceptReq := &rpcperms.InterceptionRequest{ Type: rpcperms.TypeResponse, @@ -1016,6 +1024,7 @@ func TestPrivacyMapper(t *testing.T) { ProtoTypeName: string( proto.MessageName(msg), ), + CtxMetadataPairs: md, } mwReq, err := interceptReq.ToRPC(1, 2) diff --git a/firewall/request_info.go b/firewall/request_info.go index 10a524934..fd312b71f 100644 --- a/firewall/request_info.go +++ b/firewall/request_info.go @@ -4,7 +4,10 @@ import ( "fmt" "strings" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnrpc" + "google.golang.org/grpc/metadata" "gopkg.in/macaroon.v2" ) @@ -25,6 +28,7 @@ const ( // RequestInfo stores the parsed representation of an incoming RPC middleware // request. type RequestInfo struct { + SessionID fn.Option[session.ID] MsgID uint64 RequestID uint64 MWRequestType string @@ -76,8 +80,22 @@ func NewInfoFromRequest(req *lnrpc.RPCMiddlewareRequest) (*RequestInfo, error) { return nil, fmt.Errorf("invalid request type: %T", t) } + md := make(metadata.MD) + for k, vs := range req.MetadataPairs { + for _, v := range vs.Values { + md.Append(k, v) + } + } + + sessionID, err := session.FromGRPCMetadata(md) + if err != nil { + return nil, fmt.Errorf("error extracting session ID "+ + "from request: %v", err) + } + ri.MsgID = req.MsgId ri.RequestID = req.RequestId + ri.SessionID = sessionID // If there is no macaroon in the request, then there is nothing left // to parse. diff --git a/firewall/request_logger.go b/firewall/request_logger.go index 3463dff2a..0a98e6458 100644 --- a/firewall/request_logger.go +++ b/firewall/request_logger.go @@ -194,6 +194,7 @@ func (r *RequestLogger) addNewAction(ctx context.Context, ri *RequestInfo, } actionReq := &firewalldb.AddActionReq{ + SessionID: ri.SessionID, MacaroonIdentifier: macaroonID, RPCMethod: ri.URI, } diff --git a/firewall/rule_enforcer.go b/firewall/rule_enforcer.go index 35f92c534..54d2b3a61 100644 --- a/firewall/rule_enforcer.go +++ b/firewall/rule_enforcer.go @@ -237,9 +237,11 @@ func (r *RuleEnforcer) Intercept(ctx context.Context, func (r *RuleEnforcer) handleRequest(ctx context.Context, ri *RequestInfo) (proto.Message, error) { - sessionID, err := session.IDFromMacaroon(ri.Macaroon) + sessionID, err := ri.SessionID.UnwrapOrErr( + fmt.Errorf("no session ID found in request info"), + ) if err != nil { - return nil, fmt.Errorf("could not extract ID from macaroon") + return nil, err } rules, err := r.collectEnforcers(ctx, ri, sessionID) diff --git a/session/context.go b/session/context.go new file mode 100644 index 000000000..bb3b89526 --- /dev/null +++ b/session/context.go @@ -0,0 +1,57 @@ +package session + +import ( + "encoding/hex" + "fmt" + + "github.com/lightningnetwork/lnd/fn" + "google.golang.org/grpc/metadata" +) + +// contextKey is a struct that is used as a key for storing session IDs +// in a context. Using this unexported type prevents collisions with other +// context keys that may be used in the same context. However, this only +// applies if the context is passed around in the same binary and not if the +// value is converted to grpc metadata and sent over the wire. In that case, +// we need to use a string key to avoid collisions with other metadata keys. +type contextKey struct { + name string +} + +// sessionIDCtxKey is the context key used to store the session ID in +// a context. The key is a string to avoid collisions with other context values +// that may also be included in grpc metadata which is why we add the 'lit' +// prefix. +var sessionIDCtxKey = contextKey{"lit_session_id"} + +// FromGRPCMetadata extracts the session ID from the given gRPC metadata kv +// pairs if one is found. +func FromGRPCMetadata(md metadata.MD) (fn.Option[ID], error) { + val := md.Get(sessionIDCtxKey.name) + if len(val) == 0 { + return fn.None[ID](), nil + } + + if len(val) != 1 { + return fn.None[ID](), fmt.Errorf("more than one session ID "+ + "found in gRPC metadata: %v", val) + } + + b, err := hex.DecodeString(val[0]) + if err != nil { + return fn.None[ID](), err + } + + sessID, err := IDFromBytes(b) + if err != nil { + return fn.None[ID](), err + } + + return fn.Some(sessID), nil +} + +// AddToGRPCMetadata adds the session ID to the given gRPC metadata kv pairs. +// The session ID is encoded as a hex string. +func AddToGRPCMetadata(md metadata.MD, id ID) { + md.Set(sessionIDCtxKey.name, hex.EncodeToString(id[:])) +} diff --git a/session/server.go b/session/server.go index 75c0e3edf..8dd75b44f 100644 --- a/session/server.go +++ b/session/server.go @@ -18,7 +18,8 @@ import ( type sessionID [33]byte -type GRPCServerCreator func(opts ...grpc.ServerOption) *grpc.Server +type GRPCServerCreator func(sessionID ID, + opts ...grpc.ServerOption) *grpc.Server type mailboxSession struct { server *grpc.Server @@ -70,7 +71,7 @@ func (m *mailboxSession) start(session *Session, } noiseConn := mailbox.NewNoiseGrpcConn(keys) - m.server = serverCreator(grpc.Creds(noiseConn)) + m.server = serverCreator(session.ID, grpc.Creds(noiseConn)) m.wg.Add(1) go m.run(mailboxServer) diff --git a/session_rpcserver.go b/session_rpcserver.go index 59ebfdb29..9baf4aa5c 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -26,6 +26,7 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/macaroons" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "gopkg.in/macaroon-bakery.v2/bakery" "gopkg.in/macaroon-bakery.v2/bakery/checkers" "gopkg.in/macaroon.v2" @@ -77,10 +78,23 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer, // actual mailbox server that spins up the Terminal Connect server // interface. server := session.NewServer( - func(opts ...grpc.ServerOption) *grpc.Server { - allOpts := append(cfg.grpcOptions, opts...) + func(id session.ID, opts ...grpc.ServerOption) *grpc.Server { + // Add the session ID injector interceptors first so + // that the session ID is available in the context of + // all interceptors that come after. + allOpts := []grpc.ServerOption{ + addSessionIDToStreamCtx(id), + addSessionIDToUnaryCtx(id), + } + + allOpts = append(allOpts, cfg.grpcOptions...) + allOpts = append(allOpts, opts...) + + // Construct the gRPC server with the options. grpcServer := grpc.NewServer(allOpts...) + // Register various grpc servers with the LNC session + // server. cfg.registerGrpcServers(grpcServer) return grpcServer @@ -94,6 +108,62 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer, }, nil } +// wrappedServerStream is a wrapper around the grpc.ServerStream that allows us +// to set a custom context. This is needed since the stream handler function +// doesn't take a context as an argument, but rather has a Context method on the +// handler itself. So we use this custom wrapper to override this method. +type wrappedServerStream struct { + grpc.ServerStream + ctx context.Context +} + +// Context returns the context of the stream. +// +// NOTE: This implements the grpc.ServerStream Context method. +func (w *wrappedServerStream) Context() context.Context { + return w.ctx +} + +// addSessionIDToStreamCtx is a gRPC stream interceptor that adds the given +// session ID to the context of the stream. This allows us to access the +// session ID later on for any gRPC calls made through this stream. +func addSessionIDToStreamCtx(id session.ID) grpc.ServerOption { + return grpc.StreamInterceptor(func(srv any, ss grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler) error { + + md, _ := metadata.FromIncomingContext(ss.Context()) + mdCopy := md.Copy() + session.AddToGRPCMetadata(mdCopy, id) + + // Wrap the original stream with our custom context. + wrapped := &wrappedServerStream{ + ServerStream: ss, + ctx: metadata.NewIncomingContext( + ss.Context(), mdCopy, + ), + } + + return handler(srv, wrapped) + }) +} + +// addSessionIDToUnaryCtx is a gRPC unary interceptor that adds the given +// session ID to the context of the unary call. This allows us to access the +// session ID later on for any gRPC calls made through this context. +func addSessionIDToUnaryCtx(id session.ID) grpc.ServerOption { + return grpc.UnaryInterceptor(func(ctx context.Context, req any, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (resp any, err error) { + + md, _ := metadata.FromIncomingContext(ctx) + mdCopy := md.Copy() + session.AddToGRPCMetadata(mdCopy, id) + + return handler(metadata.NewIncomingContext(ctx, mdCopy), req) + }) +} + // start all the components necessary for the sessionRpcServer to start serving // requests. This includes resuming all non-revoked sessions. func (s *sessionRpcServer) start(ctx context.Context) error {