Skip to content

Commit b65e703

Browse files
authored
Merge pull request #989 from ellemouton/sql18Sessions10
[sql-18] sessions: tightly couple sessions & accounts
2 parents 64dba89 + bfb3c7b commit b65e703

File tree

10 files changed

+306
-25
lines changed

10 files changed

+306
-25
lines changed

accounts/interceptor.go

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@ import (
55
"encoding/hex"
66
"errors"
77
"fmt"
8+
"strings"
89

910
mid "github.com/lightninglabs/lightning-terminal/rpcmiddleware"
11+
"github.com/lightningnetwork/lnd/fn"
1012
"github.com/lightningnetwork/lnd/lnrpc"
1113
"github.com/lightningnetwork/lnd/macaroons"
1214
"google.golang.org/protobuf/proto"
15+
"gopkg.in/macaroon-bakery.v2/bakery/checkers"
1316
"gopkg.in/macaroon.v2"
1417
)
1518

@@ -23,6 +26,15 @@ const (
2326
accountMiddlewareName = "lit-account"
2427
)
2528

29+
var (
30+
// caveatPrefix is the prefix that is used for custom caveats that are
31+
// used by the account system. This prefix is used to identify the
32+
// custom caveat and extract the condition (the AccountID) from it.
33+
caveatPrefix = []byte(fmt.Sprintf(
34+
"%s %s ", macaroons.CondLndCustom, CondAccount,
35+
))
36+
)
37+
2638
// Name returns the name of the interceptor.
2739
func (s *InterceptorService) Name() string {
2840
return accountMiddlewareName
@@ -199,22 +211,64 @@ func parseRPCMessage(msg *lnrpc.RPCMessage) (proto.Message, error) {
199211
// accountFromMacaroon attempts to extract an account ID from the custom account
200212
// caveat in the macaroon.
201213
func accountFromMacaroon(mac *macaroon.Macaroon) (*AccountID, error) {
202-
// Extract the account caveat from the macaroon.
203-
macaroonAccount := macaroons.GetCustomCaveatCondition(mac, CondAccount)
204-
if macaroonAccount == "" {
205-
// There is no condition that locks the macaroon to an account,
206-
// so there is nothing to check.
214+
if mac == nil {
207215
return nil, nil
208216
}
209217

210-
// The macaroon is indeed locked to an account. Fetch the account and
211-
// validate its balance.
212-
accountIDBytes, err := hex.DecodeString(macaroonAccount)
218+
// Extract the account caveat from the macaroon.
219+
accountID, err := IDFromCaveats(mac.Caveats())
213220
if err != nil {
214221
return nil, err
215222
}
216223

224+
var id *AccountID
225+
accountID.WhenSome(func(aID AccountID) {
226+
id = &aID
227+
})
228+
229+
return id, nil
230+
}
231+
232+
// CaveatFromID creates a custom caveat that can be used to bind a macaroon to
233+
// a certain account.
234+
func CaveatFromID(id AccountID) macaroon.Caveat {
235+
condition := checkers.Condition(macaroons.CondLndCustom, fmt.Sprintf(
236+
"%s %x", CondAccount, id[:],
237+
))
238+
239+
return macaroon.Caveat{Id: []byte(condition)}
240+
}
241+
242+
// IDFromCaveats attempts to extract an AccountID from the given set of caveats
243+
// by looking for the custom caveat that binds a macaroon to a certain account.
244+
func IDFromCaveats(caveats []macaroon.Caveat) (fn.Option[AccountID], error) {
245+
var accountIDStr string
246+
for _, caveat := range caveats {
247+
// The caveat id has a format of
248+
// "lnd-custom [custom-caveat-name] [custom-caveat-condition]"
249+
// and we only want the condition part. If we match the prefix
250+
// part we return the condition that comes after the prefix.
251+
_, after, found := strings.Cut(
252+
string(caveat.Id), string(caveatPrefix),
253+
)
254+
if !found {
255+
continue
256+
}
257+
258+
accountIDStr = after
259+
}
260+
261+
if accountIDStr == "" {
262+
return fn.None[AccountID](), nil
263+
}
264+
217265
var accountID AccountID
266+
accountIDBytes, err := hex.DecodeString(accountIDStr)
267+
if err != nil {
268+
return fn.None[AccountID](), err
269+
}
270+
218271
copy(accountID[:], accountIDBytes)
219-
return &accountID, nil
272+
273+
return fn.Some(accountID), nil
220274
}

accounts/interceptor_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package accounts
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/lightningnetwork/lnd/fn"
8+
"github.com/lightningnetwork/lnd/macaroons"
9+
"github.com/stretchr/testify/require"
10+
"gopkg.in/macaroon-bakery.v2/bakery/checkers"
11+
"gopkg.in/macaroon.v2"
12+
)
13+
14+
// TestAccountIDCaveatEmbedding tests that the account ID can be embedded in a
15+
// macaroon caveat and extracted from it.
16+
func TestAccountIDCaveatEmbedding(t *testing.T) {
17+
badCondition := checkers.Condition(macaroons.CondLndCustom, fmt.Sprintf(
18+
"%s %s", CondAccount, "invalid hex",
19+
))
20+
21+
tests := []struct {
22+
name string
23+
caveats []macaroon.Caveat
24+
expectedErr string
25+
expectedAcct fn.Option[AccountID]
26+
}{
27+
{
28+
name: "valid account ID, single caveat",
29+
caveats: []macaroon.Caveat{
30+
CaveatFromID(AccountID{1, 2, 3, 4, 5}),
31+
},
32+
expectedAcct: fn.Some(AccountID{1, 2, 3, 4, 5}),
33+
},
34+
{
35+
name: "valid account ID, single multiple caveats",
36+
caveats: []macaroon.Caveat{
37+
{Id: []byte("some other caveat")},
38+
CaveatFromID(AccountID{1, 2, 3, 4, 5}),
39+
{Id: []byte("another one")},
40+
},
41+
expectedAcct: fn.Some(AccountID{1, 2, 3, 4, 5}),
42+
},
43+
{
44+
name: "invalid account ID",
45+
caveats: []macaroon.Caveat{
46+
{Id: []byte(badCondition)},
47+
},
48+
expectedErr: "encoding/hex: invalid",
49+
},
50+
}
51+
52+
for _, test := range tests {
53+
t.Run(test.name, func(t *testing.T) {
54+
t.Parallel()
55+
56+
acct, err := IDFromCaveats(test.caveats)
57+
if test.expectedErr != "" {
58+
require.ErrorContains(t, err, test.expectedErr)
59+
60+
return
61+
}
62+
require.NoError(t, err)
63+
64+
if test.expectedAcct.IsNone() {
65+
require.True(t, acct.IsNone())
66+
67+
return
68+
}
69+
require.True(t, acct.IsSome())
70+
71+
test.expectedAcct.WhenSome(func(id AccountID) {
72+
acct.WhenSome(func(acct AccountID) {
73+
require.Equal(t, id, acct)
74+
})
75+
})
76+
})
77+
}
78+
}

itest/litd_accounts_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ func testAccountRestrictionsLNC(ctxm context.Context, t *harnessTest,
199199
AccountId: accountID,
200200
})
201201
require.NoError(t.t, err)
202+
require.Equal(t.t, accountID, sessResp.Session.AccountId)
202203

203204
// Try the LNC connection now.
204205
connectPhrase := strings.Split(

session/interface.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import (
77

88
"github.com/btcsuite/btcd/btcec/v2"
99
"github.com/lightninglabs/lightning-node-connect/mailbox"
10+
"github.com/lightninglabs/lightning-terminal/accounts"
1011
"github.com/lightninglabs/lightning-terminal/macaroons"
12+
"github.com/lightningnetwork/lnd/fn"
1113
"gopkg.in/macaroon-bakery.v2/bakery"
1214
"gopkg.in/macaroon.v2"
1315
)
@@ -117,6 +119,9 @@ type Session struct {
117119
// group of sessions. If this is the very first session in the group
118120
// then this will be the same as ID.
119121
GroupID ID
122+
123+
// AccountID is an optional account that the session has been linked to.
124+
AccountID fn.Option[accounts.AccountID]
120125
}
121126

122127
// buildSession creates a new session with the given user-defined parameters.
@@ -163,6 +168,7 @@ func buildSession(id ID, localPrivKey *btcec.PrivateKey, label string, typ Type,
163168
PrivacyFlags: opts.privacyFlags,
164169
GroupID: groupID,
165170
MacaroonRecipe: opts.macaroonRecipe,
171+
AccountID: opts.accountID,
166172
}
167173

168174
if len(opts.featureConfig) != 0 {
@@ -196,6 +202,9 @@ type sessionOptions struct {
196202
// macaroonRecipe holds the permissions and caveats that should be used
197203
// to bake the macaroon to be used with this session.
198204
macaroonRecipe *MacaroonRecipe
205+
206+
// accountID is an optional account that the session has been linked to.
207+
accountID fn.Option[accounts.AccountID]
199208
}
200209

201210
// defaultSessionOptions returns a new sessionOptions struct with default
@@ -258,6 +267,13 @@ func WithMacaroonRecipe(caveats []macaroon.Caveat, perms []bakery.Op) Option {
258267
}
259268
}
260269

270+
// WithAccount can be used to link the session to an account.
271+
func WithAccount(id accounts.AccountID) Option {
272+
return func(o *sessionOptions) {
273+
o.accountID = fn.Some(id)
274+
}
275+
}
276+
261277
// IDToGroupIndex defines an interface for the session ID to group ID index.
262278
type IDToGroupIndex interface {
263279
// GetGroupID will return the group ID for the given session ID.

session/kvdb_store.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"time"
1313

1414
"github.com/btcsuite/btcd/btcec/v2"
15+
"github.com/lightninglabs/lightning-terminal/accounts"
1516
"github.com/lightningnetwork/lnd/clock"
1617
"go.etcd.io/bbolt"
1718
)
@@ -82,13 +83,17 @@ type BoltStore struct {
8283
*bbolt.DB
8384

8485
clock clock.Clock
86+
87+
accounts accounts.Store
8588
}
8689

8790
// A compile-time check to ensure that BoltStore implements the Store interface.
8891
var _ Store = (*BoltStore)(nil)
8992

9093
// NewDB creates a new bolt database that can be found at the given directory.
91-
func NewDB(dir, fileName string, clock clock.Clock) (*BoltStore, error) {
94+
func NewDB(dir, fileName string, clock clock.Clock,
95+
store accounts.Store) (*BoltStore, error) {
96+
9297
firstInit := false
9398
path := filepath.Join(dir, fileName)
9499

@@ -112,8 +117,9 @@ func NewDB(dir, fileName string, clock clock.Clock) (*BoltStore, error) {
112117
}
113118

114119
return &BoltStore{
115-
DB: db,
116-
clock: clock,
120+
DB: db,
121+
clock: clock,
122+
accounts: store,
117123
}, nil
118124
}
119125

@@ -211,6 +217,15 @@ func (db *BoltStore) NewSession(ctx context.Context, label string, typ Type,
211217

212218
sessionKey := getSessionKey(session)
213219

220+
// If an account is being linked, we first need to check that
221+
// it exists.
222+
session.AccountID.WhenSome(func(account accounts.AccountID) {
223+
_, err = db.accounts.Account(ctx, account)
224+
})
225+
if err != nil {
226+
return err
227+
}
228+
214229
if len(sessionBucket.Get(sessionKey)) != 0 {
215230
return fmt.Errorf("session with local public key(%x) "+
216231
"already exists",

0 commit comments

Comments
 (0)