diff --git a/db/migrations.go b/db/migrations.go index 1b4e7a6c0..f52ec942a 100644 --- a/db/migrations.go +++ b/db/migrations.go @@ -22,7 +22,7 @@ const ( // daemon. // // NOTE: This MUST be updated when a new migration is added. - LatestMigrationVersion = 3 + LatestMigrationVersion = 4 ) // MigrationTarget is a functional option that can be passed to applyMigrations diff --git a/db/sqlc/migrations/000004_privacy_pairs.down.sql b/db/sqlc/migrations/000004_privacy_pairs.down.sql new file mode 100644 index 000000000..cb1d73ae4 --- /dev/null +++ b/db/sqlc/migrations/000004_privacy_pairs.down.sql @@ -0,0 +1,4 @@ +DROP INDEX IF EXISTS privacy_pairs_group_id_idx; +DROP INDEX IF EXISTS privacy_pairs_unique_real; +DROP INDEX IF EXISTS privacy_pairs_unique_pseudo; +DROP TABLE IF EXISTS privacy_pairs; diff --git a/db/sqlc/migrations/000004_privacy_pairs.up.sql b/db/sqlc/migrations/000004_privacy_pairs.up.sql new file mode 100644 index 000000000..c7c2e44f5 --- /dev/null +++ b/db/sqlc/migrations/000004_privacy_pairs.up.sql @@ -0,0 +1,23 @@ +-- privacy_pairs stores the privacy map pairs for a given session group. +CREATE TABLE IF NOT EXISTS privacy_pairs ( + -- The group ID of the session that this privacy pair is associated + -- with. + group_id BIGINT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, + + -- The real value of the privacy pair. + real_val TEXT NOT NULL, + + -- The pseudo value of the privacy pair. + pseudo_val TEXT NOT NULL +); + +-- There should be no duplicate real values for a given group ID. +CREATE UNIQUE INDEX privacy_pairs_unique_real ON privacy_pairs ( + group_id, real_val +); + +-- There should be no duplicate pseudo values for a given group ID. +CREATE UNIQUE INDEX privacy_pairs_unique_pseudo ON privacy_pairs ( + group_id, pseudo_val +); + diff --git a/db/sqlc/models.go b/db/sqlc/models.go index f879ce998..ea9242bed 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -52,6 +52,12 @@ type Kvstore struct { Value []byte } +type PrivacyPair struct { + GroupID int64 + RealVal string + PseudoVal string +} + type Rule struct { ID int64 Name string diff --git a/db/sqlc/privacy_paris.sql.go b/db/sqlc/privacy_paris.sql.go new file mode 100644 index 000000000..5cc924a7f --- /dev/null +++ b/db/sqlc/privacy_paris.sql.go @@ -0,0 +1,96 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: privacy_paris.sql + +package sqlc + +import ( + "context" +) + +const getAllPrivacyPairs = `-- name: GetAllPrivacyPairs :many +SELECT real_val, pseudo_val +FROM privacy_pairs +WHERE group_id = $1 +` + +type GetAllPrivacyPairsRow struct { + RealVal string + PseudoVal string +} + +func (q *Queries) GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]GetAllPrivacyPairsRow, error) { + rows, err := q.db.QueryContext(ctx, getAllPrivacyPairs, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAllPrivacyPairsRow + for rows.Next() { + var i GetAllPrivacyPairsRow + if err := rows.Scan(&i.RealVal, &i.PseudoVal); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getPseudoForReal = `-- name: GetPseudoForReal :one +SELECT pseudo_val +FROM privacy_pairs +WHERE group_id = $1 AND real_val = $2 +` + +type GetPseudoForRealParams struct { + GroupID int64 + RealVal string +} + +func (q *Queries) GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) { + row := q.db.QueryRowContext(ctx, getPseudoForReal, arg.GroupID, arg.RealVal) + var pseudo_val string + err := row.Scan(&pseudo_val) + return pseudo_val, err +} + +const getRealForPseudo = `-- name: GetRealForPseudo :one +SELECT real_val +FROM privacy_pairs +WHERE group_id = $1 AND pseudo_val = $2 +` + +type GetRealForPseudoParams struct { + GroupID int64 + PseudoVal string +} + +func (q *Queries) GetRealForPseudo(ctx context.Context, arg GetRealForPseudoParams) (string, error) { + row := q.db.QueryRowContext(ctx, getRealForPseudo, arg.GroupID, arg.PseudoVal) + var real_val string + err := row.Scan(&real_val) + return real_val, err +} + +const insertPrivacyPair = `-- name: InsertPrivacyPair :exec +INSERT INTO privacy_pairs (group_id, real_val, pseudo_val) +VALUES ($1, $2, $3) +` + +type InsertPrivacyPairParams struct { + GroupID int64 + RealVal string + PseudoVal string +} + +func (q *Queries) InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error { + _, err := q.db.ExecContext(ctx, insertPrivacyPair, arg.GroupID, arg.RealVal, arg.PseudoVal) + return err +} diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index 7564856c7..a0a9d122d 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -25,11 +25,14 @@ type Querier interface { GetAccountInvoice(ctx context.Context, arg GetAccountInvoiceParams) (AccountInvoice, error) GetAccountPayment(ctx context.Context, arg GetAccountPaymentParams) (AccountPayment, error) GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) + GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]GetAllPrivacyPairsRow, error) GetFeatureID(ctx context.Context, name string) (int64, error) GetFeatureKVStoreRecord(ctx context.Context, arg GetFeatureKVStoreRecordParams) ([]byte, error) GetGlobalKVStoreRecord(ctx context.Context, arg GetGlobalKVStoreRecordParams) ([]byte, error) GetOrInsertFeatureID(ctx context.Context, name string) (int64, error) GetOrInsertRuleID(ctx context.Context, name string) (int64, error) + GetPseudoForReal(ctx context.Context, arg GetPseudoForRealParams) (string, error) + GetRealForPseudo(ctx context.Context, arg GetRealForPseudoParams) (string, error) GetRuleID(ctx context.Context, name string) (int64, error) GetSessionAliasesInGroup(ctx context.Context, groupID sql.NullInt64) ([][]byte, error) GetSessionByAlias(ctx context.Context, alias []byte) (Session, error) @@ -44,6 +47,7 @@ type Querier interface { GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]Session, error) InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreRecordParams) error + InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error InsertSession(ctx context.Context, arg InsertSessionParams) (int64, error) InsertSessionFeatureConfig(ctx context.Context, arg InsertSessionFeatureConfigParams) error InsertSessionMacaroonCaveat(ctx context.Context, arg InsertSessionMacaroonCaveatParams) error diff --git a/db/sqlc/queries/privacy_paris.sql b/db/sqlc/queries/privacy_paris.sql new file mode 100644 index 000000000..28e103372 --- /dev/null +++ b/db/sqlc/queries/privacy_paris.sql @@ -0,0 +1,18 @@ +-- name: InsertPrivacyPair :exec +INSERT INTO privacy_pairs (group_id, real_val, pseudo_val) +VALUES ($1, $2, $3); + +-- name: GetRealForPseudo :one +SELECT real_val +FROM privacy_pairs +WHERE group_id = $1 AND pseudo_val = $2; + +-- name: GetPseudoForReal :one +SELECT pseudo_val +FROM privacy_pairs +WHERE group_id = $1 AND real_val = $2; + +-- name: GetAllPrivacyPairs :many +SELECT real_val, pseudo_val +FROM privacy_pairs +WHERE group_id = $1; diff --git a/firewall/privacy_mapper.go b/firewall/privacy_mapper.go index cbd8a8da4..fed4ba531 100644 --- a/firewall/privacy_mapper.go +++ b/firewall/privacy_mapper.go @@ -60,19 +60,19 @@ var _ mid.RequestInterceptor = (*PrivacyMapper)(nil) // PrivacyMapper is a RequestInterceptor that maps any pseudo names in certain // requests to their real values and vice versa for responses. type PrivacyMapper struct { - newDB firewalldb.NewPrivacyMapDB + db firewalldb.PrivacyMapper randIntn func(int) (int, error) sessionDB firewalldb.SessionDB } // NewPrivacyMapper returns a new instance of PrivacyMapper. The randIntn // function is used to draw randomness for request field obfuscation. -func NewPrivacyMapper(newDB firewalldb.NewPrivacyMapDB, +func NewPrivacyMapper(newDB firewalldb.PrivacyMapper, randIntn func(int) (int, error), sessionDB firewalldb.SessionDB) *PrivacyMapper { return &PrivacyMapper{ - newDB: newDB, + db: newDB, randIntn: randIntn, sessionDB: sessionDB, } @@ -195,7 +195,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context, return nil, err } - db := p.newDB(session.GroupID) + db := p.db.PrivacyDB(session.GroupID) // If we don't have a handler for the URI, we don't allow the request // to go through. @@ -225,7 +225,7 @@ func (p *PrivacyMapper) replaceOutgoingResponse(ctx context.Context, uri string, return nil, err } - db := p.newDB(session.GroupID) + db := p.db.PrivacyDB(session.GroupID) // If we don't have a handler for the URI, we don't allow the response // to go to avoid accidental leaks. diff --git a/firewall/privacy_mapper_test.go b/firewall/privacy_mapper_test.go index 1998d1280..9dcc814b2 100644 --- a/firewall/privacy_mapper_test.go +++ b/firewall/privacy_mapper_test.go @@ -902,7 +902,7 @@ func TestPrivacyMapper(t *testing.T) { // randIntn is used for deterministic testing. randIntn := func(n int) (int, error) { return 100, nil } - p := NewPrivacyMapper(db.NewSessionDB, randIntn, pd) + p := NewPrivacyMapper(db, randIntn, pd) rawMsg, err := proto.Marshal(test.msg) require.NoError(t, err) @@ -978,7 +978,7 @@ func TestPrivacyMapper(t *testing.T) { rawMsg, err := proto.Marshal(msg) require.NoError(t, err) - p := NewPrivacyMapper(db.NewSessionDB, CryptoRandIntn, pd) + p := NewPrivacyMapper(db, CryptoRandIntn, pd) require.NoError(t, err) // We test the independent outgoing amount (incoming amount @@ -1071,7 +1071,7 @@ func newMockDB(t *testing.T, preloadRealToPseudo map[string]string, sessID session.ID) mockDB { db := mockDB{privDB: make(map[string]*mockPrivacyMapDB)} - sessDB := db.NewSessionDB(sessID) + sessDB := db.PrivacyDB(sessID) _ = sessDB.Update(context.Background(), func(ctx context.Context, tx firewalldb.PrivacyMapTx) error { @@ -1085,14 +1085,14 @@ func newMockDB(t *testing.T, preloadRealToPseudo map[string]string, return db } -func (m mockDB) NewSessionDB(sessionID session.ID) firewalldb.PrivacyMapDB { - db, ok := m.privDB[string(sessionID[:])] +func (m mockDB) PrivacyDB(groupID session.ID) firewalldb.PrivacyMapDB { + db, ok := m.privDB[string(groupID[:])] if ok { return db } newDB := newMockPrivacyMapDB() - m.privDB[string(sessionID[:])] = newDB + m.privDB[string(groupID[:])] = newDB return newDB } diff --git a/firewall/rule_enforcer.go b/firewall/rule_enforcer.go index 7914965ed..472143f05 100644 --- a/firewall/rule_enforcer.go +++ b/firewall/rule_enforcer.go @@ -33,7 +33,7 @@ type RuleEnforcer struct { actionsDB firewalldb.ActionReadDBGetter sessionDB firewalldb.SessionDB markActionErrored func(reqID uint64, reason string) error - newPrivMap firewalldb.NewPrivacyMapDB + privMapDB firewalldb.PrivacyMapper permsMgr *perms.Manager getFeaturePerms featurePerms @@ -64,7 +64,7 @@ func NewRuleEnforcer(ruleDB firewalldb.RulesDB, lndClient lndclient.LightningClient, lndConnID string, ruleMgrs rules.ManagerSet, markActionErrored func(reqID uint64, reason string) error, - privMap firewalldb.NewPrivacyMapDB) *RuleEnforcer { + privMap firewalldb.PrivacyMapper) *RuleEnforcer { return &RuleEnforcer{ ruleDB: ruleDB, @@ -76,7 +76,7 @@ func NewRuleEnforcer(ruleDB firewalldb.RulesDB, lndClient: lndClient, ruleMgrs: ruleMgrs, markActionErrored: markActionErrored, - newPrivMap: privMap, + privMapDB: privMap, sessionDB: sessionIDIndex, lndConnID: lndConnID, } @@ -392,7 +392,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string, } if privacy { - privMap := r.newPrivMap(session.GroupID) + privMap := r.privMapDB.PrivacyDB(session.GroupID) ruleValues, err = ruleValues.PseudoToReal( ctx, privMap, session.PrivacyFlags, diff --git a/firewalldb/db.go b/firewalldb/db.go index fe18cbb70..8b913b69d 100644 --- a/firewalldb/db.go +++ b/firewalldb/db.go @@ -14,21 +14,28 @@ var ( ErrNoSuchKeyFound = fmt.Errorf("no such key found") ) +// firewallDBs is an interface that groups the RulesDB and PrivacyMapper +// interfaces. +type firewallDBs interface { + RulesDB + PrivacyMapper +} + // DB manages the firewall rules database. type DB struct { started sync.Once stopped sync.Once - RulesDB + firewallDBs cancel fn.Option[context.CancelFunc] } // NewDB creates a new firewall database. For now, it only contains the -// underlying rules' database. -func NewDB(kvdb RulesDB) *DB { +// underlying rules' and privacy mapper databases. +func NewDB(dbs firewallDBs) *DB { return &DB{ - RulesDB: kvdb, + firewallDBs: dbs, } } diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 3a0c4ddca..401b3b8d6 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -92,3 +92,11 @@ type RulesDB interface { // DeleteTempKVStores deletes all temporary kv stores. DeleteTempKVStores(ctx context.Context) error } + +// PrivacyMapper is an interface that abstracts access to the privacy mapper +// database. +type PrivacyMapper interface { + // PrivacyDB constructs a PrivacyMapDB that will be indexed under the + // given group ID key. + PrivacyDB(groupID session.ID) PrivacyMapDB +} diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index 8d3642c3a..cde91bfe3 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -11,8 +11,6 @@ import ( "strconv" "strings" "sync" - - "github.com/lightninglabs/lightning-terminal/session" ) var ( @@ -29,10 +27,6 @@ var ( "value already exists") ) -// NewPrivacyMapDB is a function type that takes a group ID and uses it to -// construct a new PrivacyMapDB. -type NewPrivacyMapDB func(groupID session.ID) PrivacyMapDB - // PrivacyMapDB provides an Update and View method that will allow the caller // to perform atomic read and write transactions defined by PrivacyMapTx on the // underlying DB. diff --git a/firewalldb/privacy_mapper_kvdb.go b/firewalldb/privacy_mapper_kvdb.go index 793a8342b..ec745de26 100644 --- a/firewalldb/privacy_mapper_kvdb.go +++ b/firewalldb/privacy_mapper_kvdb.go @@ -30,6 +30,8 @@ var ( // PrivacyDB constructs a PrivacyMapDB that will be indexed under the given // group ID key. +// +// NOTE: this is part of the PrivacyMapper interface. func (db *BoltDB) PrivacyDB(groupID session.ID) PrivacyMapDB { return &kvdbExecutor[PrivacyMapTx]{ db: db.DB, diff --git a/firewalldb/privacy_mapper_sql.go b/firewalldb/privacy_mapper_sql.go new file mode 100644 index 000000000..8a4863a6c --- /dev/null +++ b/firewalldb/privacy_mapper_sql.go @@ -0,0 +1,179 @@ +package firewalldb + +import ( + "context" + "database/sql" + "errors" + + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/session" +) + +// SQLPrivacyPairQueries is a subset of the sqlc.Queries interface that can be +// used to interact with the privacy map table. +// +//nolint:lll +type SQLPrivacyPairQueries interface { + SQLSessionQueries + + InsertPrivacyPair(ctx context.Context, arg sqlc.InsertPrivacyPairParams) error + GetAllPrivacyPairs(ctx context.Context, groupID int64) ([]sqlc.GetAllPrivacyPairsRow, error) + GetPseudoForReal(ctx context.Context, arg sqlc.GetPseudoForRealParams) (string, error) + GetRealForPseudo(ctx context.Context, arg sqlc.GetRealForPseudoParams) (string, error) +} + +// PrivacyDB constructs a PrivacyMapDB that will be indexed under the given +// group ID key. +// +// NOTE: this is part of the PrivacyMapper interface. +func (s *SQLDB) PrivacyDB(groupID session.ID) PrivacyMapDB { + return &sqlExecutor[PrivacyMapTx]{ + db: s.db, + wrapTx: func(queries SQLQueries) PrivacyMapTx { + return &privacyMapSQLTx{ + queries: queries, + groupID: groupID, + } + }, + } +} + +// privacyMapSQLTx is an implementation of PrivacyMapTx. +type privacyMapSQLTx struct { + queries SQLQueries + groupID session.ID +} + +// NewPair inserts a new real-pseudo pair into the db. +// +// NOTE: this is part of the PrivacyMapTx interface. +func (p *privacyMapSQLTx) NewPair(ctx context.Context, real, + pseudo string) error { + + groupID, err := p.getGroupID(ctx) + if err != nil { + return err + } + + _, err = p.queries.GetPseudoForReal(ctx, sqlc.GetPseudoForRealParams{ + GroupID: groupID, + RealVal: real, + }) + if err == nil { + return ErrDuplicateRealValue + } else if !errors.Is(err, sql.ErrNoRows) { + return err + } + + _, err = p.queries.GetRealForPseudo(ctx, sqlc.GetRealForPseudoParams{ + GroupID: groupID, + PseudoVal: pseudo, + }) + if err == nil { + return ErrDuplicatePseudoValue + } else if !errors.Is(err, sql.ErrNoRows) { + return err + } + + return p.queries.InsertPrivacyPair(ctx, sqlc.InsertPrivacyPairParams{ + GroupID: groupID, + RealVal: real, + PseudoVal: pseudo, + }) +} + +// PseudoToReal will check the db to see if the given pseudo key exists. If +// it does then the real value is returned, else an error is returned. +// +// NOTE: this is part of the PrivacyMapTx interface. +func (p *privacyMapSQLTx) PseudoToReal(ctx context.Context, + pseudo string) (string, error) { + + groupID, err := p.getGroupID(ctx) + if err != nil { + return "", err + } + + realVal, err := p.queries.GetRealForPseudo( + ctx, sqlc.GetRealForPseudoParams{ + GroupID: groupID, + PseudoVal: pseudo, + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return "", ErrNoSuchKeyFound + } else if err != nil { + return "", err + } + + return realVal, nil +} + +// RealToPseudo will check the db to see if the given real key exists. If it +// does then the pseudo value is returned, else an error is returned. +// +// NOTE: this is part of the PrivacyMapTx interface. +func (p *privacyMapSQLTx) RealToPseudo(ctx context.Context, + real string) (string, error) { + + groupID, err := p.getGroupID(ctx) + if err != nil { + return "", err + } + + pseudo, err := p.queries.GetPseudoForReal( + ctx, sqlc.GetPseudoForRealParams{ + GroupID: groupID, + RealVal: real, + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return "", ErrNoSuchKeyFound + } else if err != nil { + return "", err + } + + return pseudo, nil +} + +// FetchAllPairs loads and returns the real-to-pseudo pairs. +// +// NOTE: this is part of the PrivacyMapTx interface. +func (p *privacyMapSQLTx) FetchAllPairs(ctx context.Context) (*PrivacyMapPairs, + error) { + + groupID, err := p.getGroupID(ctx) + if err != nil { + return nil, err + } + + pairs, err := p.queries.GetAllPrivacyPairs(ctx, groupID) + if err != nil { + return nil, err + } + + privacyPairs := make(map[string]string, len(pairs)) + for _, pair := range pairs { + privacyPairs[pair.RealVal] = pair.PseudoVal + } + + return NewPrivacyMapPairs(privacyPairs), nil +} + +// getGroupID is a helper that can be used to get the DB ID for a session group +// given the group ID alias. If such a group is not found, then +// session.ErrUnknownGroup is returned. +func (p *privacyMapSQLTx) getGroupID(ctx context.Context) (int64, error) { + groupID, err := p.queries.GetSessionIDByAlias(ctx, p.groupID[:]) + if errors.Is(err, sql.ErrNoRows) { + return 0, session.ErrUnknownGroup + } else if err != nil { + return 0, err + } + + return groupID, nil +} + +// A compile-time constraint to ensure that the privacyMapSQLTx type implements +// the PrivacyMapTx interface. +var _ PrivacyMapTx = (*privacyMapSQLTx)(nil) diff --git a/firewalldb/privacy_mapper_test.go b/firewalldb/privacy_mapper_test.go index d4f235485..fbdf880ff 100644 --- a/firewalldb/privacy_mapper_test.go +++ b/firewalldb/privacy_mapper_test.go @@ -17,17 +17,13 @@ func TestPrivacyMapStorage(t *testing.T) { ctx := context.Background() sessions := session.NewTestDB(t, clock.NewDefaultClock()) - db, err := NewBoltDB(t.TempDir(), "test.db", sessions) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDBWithSessions(t, sessions) // First up, let's test that the correct error is returned if an // attempt is made to write to a privacy map that is not linked to // an existing session group. pdb := db.PrivacyDB(session.ID{1, 2, 3, 4}) - err = pdb.Update(ctx, + err := pdb.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { _, err := tx.RealToPseudo(ctx, "real") require.ErrorIs(t, err, session.ErrUnknownGroup) @@ -54,7 +50,7 @@ func TestPrivacyMapStorage(t *testing.T) { pdb1 := db.PrivacyDB(sess.GroupID) _ = pdb1.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { - _, err = tx.RealToPseudo(ctx, "real") + _, err := tx.RealToPseudo(ctx, "real") require.ErrorIs(t, err, ErrNoSuchKeyFound) _, err = tx.PseudoToReal(ctx, "pseudo") @@ -89,7 +85,7 @@ func TestPrivacyMapStorage(t *testing.T) { pdb2 := db.PrivacyDB(sess2.GroupID) _ = pdb2.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { - _, err = tx.RealToPseudo(ctx, "real") + _, err := tx.RealToPseudo(ctx, "real") require.ErrorIs(t, err, ErrNoSuchKeyFound) _, err = tx.PseudoToReal(ctx, "pseudo") @@ -227,11 +223,7 @@ func TestPrivacyMapTxs(t *testing.T) { ctx := context.Background() sessions := session.NewTestDB(t, clock.NewDefaultClock()) - db, err := NewBoltDB(t.TempDir(), "test.db", sessions) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDBWithSessions(t, sessions) sess, err := sessions.NewSession( ctx, "test", session.TypeAutopilot, time.Unix(1000, 0), "", diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index 01f1d5b9c..acca60ce1 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -11,6 +11,7 @@ import ( // interact with various firewalldb tables. type SQLQueries interface { SQLKVStoreQueries + SQLPrivacyPairQueries } // BatchedSQLQueries is a version of the SQLQueries that's capable of batched diff --git a/session_rpcserver.go b/session_rpcserver.go index b20700948..ffb8e6b49 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -66,7 +66,7 @@ type sessionRpcServerConfig struct { actionsDB *firewalldb.BoltDB autopilot autopilotserver.Autopilot ruleMgrs rules.ManagerSet - privMap firewalldb.NewPrivacyMapDB + privMap firewalldb.PrivacyMapper } // newSessionRPCServer creates a new sessionRpcServer using the passed config. @@ -628,7 +628,7 @@ func (s *sessionRpcServer) PrivacyMapConversion(ctx context.Context, } var res string - privMap := s.cfg.privMap(groupID) + privMap := s.cfg.privMap.PrivacyDB(groupID) err = privMap.View(ctx, func(ctx context.Context, tx firewalldb.PrivacyMapTx) error { @@ -900,7 +900,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, linkedGroupID = &groupID linkedGroupSession = groupSess - privDB := s.cfg.privMap(groupID) + privDB := s.cfg.privMap.PrivacyDB(groupID) err = privDB.View(ctx, func(ctx context.Context, tx firewalldb.PrivacyMapTx) error { @@ -1225,7 +1225,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, } // Register all the privacy map pairs for this session ID. - privDB := s.cfg.privMap(sess.GroupID) + privDB := s.cfg.privMap.PrivacyDB(sess.GroupID) err = privDB.Update(ctx, func(ctx context.Context, tx firewalldb.PrivacyMapTx) error { @@ -1487,7 +1487,7 @@ func (s *sessionRpcServer) marshalRPCSession(ctx context.Context, } if sess.WithPrivacyMapper { - db := s.cfg.privMap( + db := s.cfg.privMap.PrivacyDB( sess.GroupID, ) val, err = val.PseudoToReal( diff --git a/terminal.go b/terminal.go index c48c5fcfc..9b35e0a66 100644 --- a/terminal.go +++ b/terminal.go @@ -534,7 +534,7 @@ func (g *LightningTerminal) start(ctx context.Context) error { actionsDB: g.stores.firewallBolt, autopilot: g.autopilotClient, ruleMgrs: g.ruleMgrs, - privMap: g.stores.firewallBolt.PrivacyDB, + privMap: g.stores.firewall, }) if err != nil { return fmt.Errorf("could not create new session rpc "+ @@ -1100,7 +1100,7 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, } privacyMapper := firewall.NewPrivacyMapper( - g.stores.firewallBolt.PrivacyDB, firewall.CryptoRandIntn, + g.stores.firewall, firewall.CryptoRandIntn, g.stores.sessions, ) @@ -1123,7 +1123,7 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, reqID, firewalldb.ActionStateError, reason, ) - }, g.stores.firewallBolt.PrivacyDB, + }, g.stores.firewall, ) mw = append(mw, ruleEnforcer)