Skip to content

Commit dfcdc47

Browse files
committed
firewalldb: check that priv map session exists
Here, we adjust the bbolt impl of the privacy mapper such that it first checks that the referenced session group does in fact exist. We update our unit tests accordingly. We do this because once we plug in the SQL impl, the link will be explicit and the tests will error if the session group does not exist.
1 parent 0ea0b7c commit dfcdc47

File tree

2 files changed

+94
-16
lines changed

2 files changed

+94
-16
lines changed

firewalldb/privacy_mapper_kvdb.go

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,43 @@ func (db *BoltDB) PrivacyDB(groupID session.ID) PrivacyMapDB {
3535
db: db.DB,
3636
wrapTx: func(tx *bbolt.Tx) PrivacyMapTx {
3737
return &privacyMapTx{
38-
boltTx: tx,
39-
groupID: groupID,
38+
sessions: db.sessionIDIndex,
39+
boltTx: tx,
40+
groupID: groupID,
4041
}
4142
},
4243
}
4344
}
4445

4546
// privacyMapTx is an implementation of PrivacyMapTx.
4647
type privacyMapTx struct {
47-
groupID session.ID
48-
boltTx *bbolt.Tx
48+
sessions SessionDB
49+
groupID session.ID
50+
boltTx *bbolt.Tx
51+
}
52+
53+
// asserGroupExists checks that the session group that the privacy mapper is
54+
// pointing to exists.
55+
//
56+
// NOTE: this is technically a DB transaction within another DB transaction.
57+
// But this is ok because:
58+
// 1. We only do this for the bbolt backends in which case the transactions are
59+
// for _separate_ DB files.
60+
// 2. The aim is to completely remove this implementation in future.
61+
func (p *privacyMapTx) assertGroupExists(ctx context.Context) error {
62+
_, err := p.sessions.GetSessionIDs(ctx, p.groupID)
63+
64+
return err
4965
}
5066

5167
// NewPair inserts a new real-pseudo pair into the db.
5268
//
5369
// NOTE: this is part of the PrivacyMapTx interface.
54-
func (p *privacyMapTx) NewPair(_ context.Context, real, pseudo string) error {
70+
func (p *privacyMapTx) NewPair(ctx context.Context, real, pseudo string) error {
71+
if err := p.assertGroupExists(ctx); err != nil {
72+
return err
73+
}
74+
5575
privacyBucket, err := getBucket(p.boltTx, privacyBucketKey)
5676
if err != nil {
5777
return err
@@ -97,9 +117,13 @@ func (p *privacyMapTx) NewPair(_ context.Context, real, pseudo string) error {
97117
// it does then the real value is returned, else an error is returned.
98118
//
99119
// NOTE: this is part of the PrivacyMapTx interface.
100-
func (p *privacyMapTx) PseudoToReal(_ context.Context, pseudo string) (string,
120+
func (p *privacyMapTx) PseudoToReal(ctx context.Context, pseudo string) (string,
101121
error) {
102122

123+
if err := p.assertGroupExists(ctx); err != nil {
124+
return "", err
125+
}
126+
103127
privacyBucket, err := getBucket(p.boltTx, privacyBucketKey)
104128
if err != nil {
105129
return "", err
@@ -127,9 +151,13 @@ func (p *privacyMapTx) PseudoToReal(_ context.Context, pseudo string) (string,
127151
// it does then the pseudo value is returned, else an error is returned.
128152
//
129153
// NOTE: this is part of the PrivacyMapTx interface.
130-
func (p *privacyMapTx) RealToPseudo(_ context.Context, real string) (string,
154+
func (p *privacyMapTx) RealToPseudo(ctx context.Context, real string) (string,
131155
error) {
132156

157+
if err := p.assertGroupExists(ctx); err != nil {
158+
return "", err
159+
}
160+
133161
privacyBucket, err := getBucket(p.boltTx, privacyBucketKey)
134162
if err != nil {
135163
return "", err
@@ -156,9 +184,13 @@ func (p *privacyMapTx) RealToPseudo(_ context.Context, real string) (string,
156184
// FetchAllPairs loads and returns the real-to-pseudo pairs.
157185
//
158186
// NOTE: this is part of the PrivacyMapTx interface.
159-
func (p *privacyMapTx) FetchAllPairs(_ context.Context) (*PrivacyMapPairs,
187+
func (p *privacyMapTx) FetchAllPairs(ctx context.Context) (*PrivacyMapPairs,
160188
error) {
161189

190+
if err := p.assertGroupExists(ctx); err != nil {
191+
return nil, err
192+
}
193+
162194
privacyBucket, err := getBucket(p.boltTx, privacyBucketKey)
163195
if err != nil {
164196
return nil, err

firewalldb/privacy_mapper_test.go

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ import (
44
"context"
55
"fmt"
66
"testing"
7+
"time"
78

9+
"github.com/lightninglabs/lightning-terminal/session"
10+
"github.com/lightningnetwork/lnd/clock"
811
"github.com/stretchr/testify/require"
912
)
1013

@@ -13,14 +16,42 @@ func TestPrivacyMapStorage(t *testing.T) {
1316
t.Parallel()
1417
ctx := context.Background()
1518

16-
tmpDir := t.TempDir()
17-
db, err := NewBoltDB(tmpDir, "test.db", nil)
19+
sessions := session.NewTestDB(t, clock.NewDefaultClock())
20+
db, err := NewBoltDB(t.TempDir(), "test.db", sessions)
1821
require.NoError(t, err)
1922
t.Cleanup(func() {
2023
_ = db.Close()
2124
})
2225

23-
pdb1 := db.PrivacyDB([4]byte{1, 1, 1, 1})
26+
// First up, let's test that the correct error is returned if an
27+
// attempt is made to write to a privacy map that is not linked to
28+
// an existing session group.
29+
pdb := db.PrivacyDB(session.ID{1, 2, 3, 4})
30+
err = pdb.Update(ctx,
31+
func(ctx context.Context, tx PrivacyMapTx) error {
32+
_, err := tx.RealToPseudo(ctx, "real")
33+
require.ErrorIs(t, err, session.ErrUnknownGroup)
34+
35+
_, err = tx.PseudoToReal(ctx, "pseudo")
36+
require.ErrorIs(t, err, session.ErrUnknownGroup)
37+
38+
err = tx.NewPair(ctx, "real", "pseudo")
39+
require.ErrorIs(t, err, session.ErrUnknownGroup)
40+
41+
_, err = tx.FetchAllPairs(ctx)
42+
require.ErrorIs(t, err, session.ErrUnknownGroup)
43+
44+
return nil
45+
},
46+
)
47+
require.NoError(t, err)
48+
49+
sess, err := sessions.NewSession(
50+
ctx, "test", session.TypeAutopilot, time.Unix(1000, 0), "",
51+
)
52+
require.NoError(t, err)
53+
54+
pdb1 := db.PrivacyDB(sess.GroupID)
2455

2556
_ = pdb1.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error {
2657
_, err = tx.RealToPseudo(ctx, "real")
@@ -50,7 +81,12 @@ func TestPrivacyMapStorage(t *testing.T) {
5081
return nil
5182
})
5283

53-
pdb2 := db.PrivacyDB([4]byte{2, 2, 2, 2})
84+
sess2, err := sessions.NewSession(
85+
ctx, "test", session.TypeAutopilot, time.Unix(1000, 0), "",
86+
)
87+
require.NoError(t, err)
88+
89+
pdb2 := db.PrivacyDB(sess2.GroupID)
5490

5591
_ = pdb2.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error {
5692
_, err = tx.RealToPseudo(ctx, "real")
@@ -80,7 +116,12 @@ func TestPrivacyMapStorage(t *testing.T) {
80116
return nil
81117
})
82118

83-
pdb3 := db.PrivacyDB([4]byte{3, 3, 3, 3})
119+
sess3, err := sessions.NewSession(
120+
ctx, "test", session.TypeAutopilot, time.Unix(1000, 0), "",
121+
)
122+
require.NoError(t, err)
123+
124+
pdb3 := db.PrivacyDB(sess3.GroupID)
84125

85126
_ = pdb3.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error {
86127
// Check that calling FetchAllPairs returns an empty map if
@@ -185,14 +226,19 @@ func TestPrivacyMapTxs(t *testing.T) {
185226
t.Parallel()
186227
ctx := context.Background()
187228

188-
tmpDir := t.TempDir()
189-
db, err := NewBoltDB(tmpDir, "test.db", nil)
229+
sessions := session.NewTestDB(t, clock.NewDefaultClock())
230+
db, err := NewBoltDB(t.TempDir(), "test.db", sessions)
190231
require.NoError(t, err)
191232
t.Cleanup(func() {
192233
_ = db.Close()
193234
})
194235

195-
pdb1 := db.PrivacyDB([4]byte{1, 1, 1, 1})
236+
sess, err := sessions.NewSession(
237+
ctx, "test", session.TypeAutopilot, time.Unix(1000, 0), "",
238+
)
239+
require.NoError(t, err)
240+
241+
pdb1 := db.PrivacyDB(sess.GroupID)
196242

197243
// Test that if an action fails midway through the transaction, then
198244
// it is rolled back.

0 commit comments

Comments
 (0)